discreteConditionals returns DiscreteConditional

release/4.3a0
Varun Agrawal 2023-07-09 20:24:24 -04:00
parent f6b1872b13
commit 2940e69a73
2 changed files with 9 additions and 13 deletions

View File

@ -38,21 +38,17 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> discreteProbs;
// The canonical decision tree factor which will get // The canonical decision tree factor which will get
// the discrete conditionals added to it. // the discrete conditionals added to it.
DecisionTreeFactor discreteProbsFactor; DiscreteConditional discreteProbs;
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor. discreteProbs = discreteProbs * (*conditional->asDiscrete());
DecisionTreeFactor f(*conditional->asDiscrete());
discreteProbsFactor = discreteProbsFactor * f;
} }
} }
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor); return std::make_shared<DiscreteConditional>(discreteProbs);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -146,7 +142,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals( void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor &prunedDiscreteProbs) { const DecisionTreeFactor &prunedDiscreteProbs) {
//TODO(Varun) Should prune the joint conditional, maybe during elimination? // TODO(Varun) Should prune the joint conditional, maybe during elimination?
// Loop with index since we need it later. // Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i); HybridConditional::shared_ptr conditional = this->at(i);
@ -179,7 +175,7 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys // Get the decision tree of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals); gttic_(HybridBayesNet_PruneDiscreteConditionals);
DecisionTreeFactor::shared_ptr discreteConditionals = DiscreteConditional::shared_ptr discreteConditionals =
this->discreteConditionals(); this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs = const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves); discreteConditionals->prune(maxNrLeaves);

View File

@ -139,9 +139,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/** /**
* @brief Get all the discrete conditionals as a decision tree factor. * @brief Get all the discrete conditionals as a decision tree factor.
* *
* @return DecisionTreeFactor::shared_ptr * @return DiscreteConditional::shared_ptr
*/ */
DecisionTreeFactor::shared_ptr discreteConditionals() const; DiscreteConditional::shared_ptr discreteConditionals() const;
/** /**
* @brief Sample from an incomplete BayesNet, given missing variables. * @brief Sample from an incomplete BayesNet, given missing variables.