diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 96a6dfd63..cc27600f0 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -22,14 +22,6 @@ namespace gtsam { -/* ************************************************************************* */ -/// Return the DiscreteKey vector as a set. -static std::set DiscreteKeysAsSet(const DiscreteKeys &dkeys) { - std::set s; - s.insert(dkeys.begin(), dkeys.end()); - return s; -} - /* ************************************************************************* */ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { AlgebraicDecisionTree decisionTree; @@ -49,63 +41,6 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { return boost::make_shared(dtFactor); } -/** - * @brief Helper function to get the pruner functional. - * - * @param probDecisionTree The probability decision tree of only discrete keys. - * @param discreteFactorKeySet Set of DiscreteKeys in probDecisionTree. - * Pre-computed for efficiency. - * @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture. - * @return std::function &, const GaussianConditional::shared_ptr &)> - */ -std::function &, const GaussianConditional::shared_ptr &)> -PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree, - const std::set &discreteFactorKeySet, - const std::set &gaussianMixtureKeySet) { - auto pruner = [&](const Assignment &choices, - const GaussianConditional::shared_ptr &conditional) - -> GaussianConditional::shared_ptr { - // typecast so we can use this to get probability value - DiscreteValues values(choices); - - // Case where the gaussian mixture has the same - // discrete keys as the decision tree. - if (gaussianMixtureKeySet == discreteFactorKeySet) { - if ((*probDecisionTree)(values) == 0.0) { - // empty aka null pointer - boost::shared_ptr null; - return null; - } else { - return conditional; - } - } else { - std::vector set_diff; - std::set_difference( - discreteFactorKeySet.begin(), discreteFactorKeySet.end(), - gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(), - std::back_inserter(set_diff)); - const std::vector assignments = - DiscreteValues::CartesianProduct(set_diff); - for (const DiscreteValues &assignment : assignments) { - DiscreteValues augmented_values(values); - augmented_values.insert(assignment.begin(), assignment.end()); - - // If any one of the sub-branches are non-zero, - // we need this conditional. - if ((*probDecisionTree)(augmented_values) > 0.0) { - return conditional; - } - } - // If we are here, it means that all the sub-branches are 0, - // so we prune. - return nullptr; - } - }; - return pruner; -} - /* ************************************************************************* */ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { // Get the decision tree of only the discrete keys @@ -114,8 +49,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { boost::make_shared( discreteConditionals->prune(maxNrLeaves)); - auto discreteFactorKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys()); - /* To Prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. @@ -130,35 +63,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { for (size_t i = 0; i < this->size(); i++) { HybridConditional::shared_ptr conditional = this->at(i); - GaussianMixture::shared_ptr gaussianMixture = - boost::dynamic_pointer_cast(conditional->inner()); + if (conditional->isHybrid()) { + GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); - if (gaussianMixture) { - // We may have mixtures with less discrete keys than discreteFactor so - // we skip those since the label assignment does not exist. - auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys()); - - // Get the pruner function. - auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet); - - // Run the pruning to get a new, pruned tree - GaussianMixture::Conditionals prunedTree = - gaussianMixture->conditionals().apply(pruner); - - DiscreteKeys discreteKeys = gaussianMixture->discreteKeys(); - // reverse keys to get a natural ordering - std::reverse(discreteKeys.begin(), discreteKeys.end()); - - // Convert from boost::iterator_range to KeyVector - // so we can pass it to constructor. - KeyVector frontals(gaussianMixture->frontals().begin(), - gaussianMixture->frontals().end()), - parents(gaussianMixture->parents().begin(), - gaussianMixture->parents().end()); - - // Create the new gaussian mixture and add it to the bayes net. - auto prunedGaussianMixture = boost::make_shared( - frontals, parents, discreteKeys, prunedTree); + // Make a copy of the gaussian mixture and prune it! + auto prunedGaussianMixture = + boost::make_shared(*gaussianMixture); + prunedGaussianMixture->prune(*discreteFactor); // Type-erase and add to the pruned Bayes Net fragment. prunedBayesNetFragment.push_back( diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 8fb487ae2..266b295dd 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -149,16 +149,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { auto decisionTree = boost::dynamic_pointer_cast( this->roots_.at(0)->conditional()->inner()); - DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); - decisionTree->root_ = prunedDiscreteFactor.root_; + DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); + decisionTree->root_ = prunedDecisionTree.root_; /// Helper struct for pruning the hybrid bayes tree. struct HybridPrunerData { /// The discrete decision tree after pruning. - DecisionTreeFactor prunedDiscreteFactor; - HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, + DecisionTreeFactor prunedDecisionTree; + HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree, const HybridBayesTree::sharedNode& parentClique) - : prunedDiscreteFactor(prunedDiscreteFactor) {} + : prunedDecisionTree(prunedDecisionTree) {} /** * @brief A function used during tree traversal that operates on each node @@ -178,19 +178,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { if (conditional->isHybrid()) { auto gaussianMixture = conditional->asMixture(); - // Check if the number of discrete keys match, - // else we get an assignment error. - // TODO(Varun) Update prune method to handle assignment subset? - if (gaussianMixture->discreteKeys() == - parentData.prunedDiscreteFactor.discreteKeys()) { - gaussianMixture->prune(parentData.prunedDiscreteFactor); - } + gaussianMixture->prune(parentData.prunedDecisionTree); } return parentData; } }; - HybridPrunerData rootData(prunedDiscreteFactor, 0); + HybridPrunerData rootData(prunedDecisionTree, 0); { treeTraversal::no_op visitorPost; // Limits OpenMP threads since we're mixing TBB and OpenMP diff --git a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp index a0a87933f..a5e3903d9 100644 --- a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp @@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { EXPECT_LONGS_EQUAL( 2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( - 4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); + 3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( 5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp index 4e1710c42..93a8a1e00 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -363,7 +363,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) { EXPECT_LONGS_EQUAL( 2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( - 4, bayesTree[X(2)]->conditional()->asMixture()->nrComponents()); + 3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL( 5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents()); EXPECT_LONGS_EQUAL(