diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp
index 546d0200b..659d44423 100644
--- a/gtsam/hybrid/GaussianMixture.cpp
+++ b/gtsam/hybrid/GaussianMixture.cpp
@@ -228,19 +228,19 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
 /**
  * @brief Helper function to get the pruner functional.
  *
- * @param decisionTree The probability decision tree of only discrete keys.
+ * @param discreteProbs The probabilities of only discrete keys.
  * @return std::function<GaussianConditional::shared_ptr(
  *     const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
  */
 std::function<GaussianConditional::shared_ptr(
     const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
-GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
+GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
   // Get the discrete keys as sets for the decision tree
   // and the gaussian mixture.
-  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
+  auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
   auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
 
-  auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
+  auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
                     const Assignment<Key> &choices,
                     const GaussianConditional::shared_ptr &conditional)
       -> GaussianConditional::shared_ptr {
@@ -249,8 +249,8 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
 
     // Case where the gaussian mixture has the same
     // discrete keys as the decision tree.
-    if (gaussianMixtureKeySet == decisionTreeKeySet) {
-      if (decisionTree(values) == 0.0) {
+    if (gaussianMixtureKeySet == discreteProbsKeySet) {
+      if (discreteProbs(values) == 0.0) {
         // empty aka null pointer
         std::shared_ptr<GaussianConditional> null;
         return null;
@@ -259,10 +259,10 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
       }
     } else {
      std::vector<DiscreteKey> set_diff;
-      std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
-                          gaussianMixtureKeySet.begin(),
-                          gaussianMixtureKeySet.end(),
-                          std::back_inserter(set_diff));
+      std::set_difference(
+          discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
+          gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
+          std::back_inserter(set_diff));
 
       const std::vector<DiscreteValues> assignments =
           DiscreteValues::CartesianProduct(set_diff);
@@ -272,7 +272,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
 
         // If any one of the sub-branches are non-zero,
         // we need this conditional.
-        if (decisionTree(augmented_values) > 0.0) {
+        if (discreteProbs(augmented_values) > 0.0) {
           return conditional;
         }
       }
@@ -285,12 +285,12 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
 }
 
 /* *******************************************************************************/
-void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
-  auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
+void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
+  auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
   auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
   // Functional which loops over all assignments and create a set of
   // GaussianConditionals
-  auto pruner = prunerFunc(decisionTree);
+  auto pruner = prunerFunc(discreteProbs);
 
   auto pruned_conditionals = conditionals_.apply(pruner);
   conditionals_.root_ = pruned_conditionals.root_;
diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h
index 2d715c6e3..0b68fcfd0 100644
--- a/gtsam/hybrid/GaussianMixture.h
+++ b/gtsam/hybrid/GaussianMixture.h
@@ -74,13 +74,13 @@ class GTSAM_EXPORT GaussianMixture
   /**
    * @brief Helper function to get the pruner functor.
   *
-   * @param decisionTree The pruned discrete probability decision tree.
+   * @param discreteProbs The pruned discrete probabilities.
    * @return std::function<GaussianConditional::shared_ptr(
    *     const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
    */
   std::function<GaussianConditional::shared_ptr(
       const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
-  prunerFunc(const DecisionTreeFactor &decisionTree);
+  prunerFunc(const DecisionTreeFactor &discreteProbs);
 
  public:
   /// @name Constructors
@@ -234,12 +234,11 @@ class GTSAM_EXPORT GaussianMixture
 
   /**
    * @brief Prune the decision tree of Gaussian factors as per the discrete
-   * `decisionTree`.
+   * `discreteProbs`.
    *
-   * @param decisionTree A pruned decision tree of discrete keys where the
-   * leaves are probabilities.
+   * @param discreteProbs A pruned set of probabilities for the discrete keys.
    */
-  void prune(const DecisionTreeFactor &decisionTree);
+  void prune(const DecisionTreeFactor &discreteProbs);
 
   /**
   * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp
index 68129bc27..ab68e170f 100644
--- a/gtsam/hybrid/HybridBayesNet.cpp
+++ b/gtsam/hybrid/HybridBayesNet.cpp
@@ -39,41 +39,41 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
 
 /* ************************************************************************* */
 DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
-  AlgebraicDecisionTree<Key> decisionTree;
+  AlgebraicDecisionTree<Key> discreteProbs;
 
   // The canonical decision tree factor which will get
   // the discrete conditionals added to it.
-  DecisionTreeFactor dtFactor;
+  DecisionTreeFactor discreteProbsFactor;
 
   for (auto &&conditional : *this) {
     if (conditional->isDiscrete()) {
       // Convert to a DecisionTreeFactor and add it to the main factor.
       DecisionTreeFactor f(*conditional->asDiscrete());
-      dtFactor = dtFactor * f;
+      discreteProbsFactor = discreteProbsFactor * f;
     }
   }
-  return std::make_shared<DecisionTreeFactor>(dtFactor);
+  return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
 }
 
 /* ************************************************************************* */
 /**
  * @brief Helper function to get the pruner functional.
  *
- * @param prunedDecisionTree The prob. decision tree of only discrete keys.
+ * @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
  * @param conditional Conditional to prune. Used to get full assignment.
 * @return std::function<double(const Assignment<Key> &, double)>
  */
 std::function<double(const Assignment<Key> &, double)> prunerFunc(
-    const DecisionTreeFactor &prunedDecisionTree,
+    const DecisionTreeFactor &prunedDiscreteProbs,
     const HybridConditional &conditional) {
   // Get the discrete keys as sets for the decision tree
   // and the Gaussian mixture.
-  std::set<DiscreteKey> decisionTreeKeySet =
-      DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
+  std::set<DiscreteKey> discreteProbsKeySet =
+      DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
   std::set<DiscreteKey> conditionalKeySet =
       DiscreteKeysAsSet(conditional.discreteKeys());
 
-  auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
+  auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
                     const Assignment<Key> &choices,
                     double probability) -> double {
     // This corresponds to 0 probability
@@ -83,8 +83,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
     DiscreteValues values(choices);
     // Case where the Gaussian mixture has the same
     // discrete keys as the decision tree.
-    if (conditionalKeySet == decisionTreeKeySet) {
-      if (prunedDecisionTree(values) == 0) {
+    if (conditionalKeySet == discreteProbsKeySet) {
+      if (prunedDiscreteProbs(values) == 0) {
         return pruned_prob;
       } else {
         return probability;
@@ -114,11 +114,12 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
     }
 
     // Now we generate the full assignment by enumerating
-    // over all keys in the prunedDecisionTree.
+    // over all keys in the prunedDiscreteProbs.
     // First we find the differing keys
     std::vector<DiscreteKey> set_diff;
-    std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
-                        conditionalKeySet.begin(), conditionalKeySet.end(),
+    std::set_difference(discreteProbsKeySet.begin(),
+                        discreteProbsKeySet.end(), conditionalKeySet.begin(),
+                        conditionalKeySet.end(),
                         std::back_inserter(set_diff));
 
     // Now enumerate over all assignments of the differing keys
@@ -130,7 +131,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
 
       // If any one of the sub-branches are non-zero,
      // we need this probability.
-      if (prunedDecisionTree(augmented_values) > 0.0) {
+      if (prunedDiscreteProbs(augmented_values) > 0.0) {
        return probability;
       }
     }
@@ -144,8 +145,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
 
 /* ************************************************************************* */
 void HybridBayesNet::updateDiscreteConditionals(
-    const DecisionTreeFactor &prunedDecisionTree) {
-  KeyVector prunedTreeKeys = prunedDecisionTree.keys();
+    const DecisionTreeFactor &prunedDiscreteProbs) {
+  KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
 
   // Loop with index since we need it later.
   for (size_t i = 0; i < this->size(); i++) {
@@ -157,7 +158,7 @@ void HybridBayesNet::updateDiscreteConditionals(
       auto discreteTree =
           std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
       DecisionTreeFactor::ADT prunedDiscreteTree =
-          discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
+          discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
 
       // Create the new (hybrid) conditional
       KeyVector frontals(discrete->frontals().begin(),
@@ -175,10 +176,12 @@ void HybridBayesNet::updateDiscreteConditionals(
 /* ************************************************************************* */
 HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   // Get the decision tree of only the discrete keys
-  auto discreteConditionals = this->discreteConditionals();
-  const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
+  DecisionTreeFactor::shared_ptr discreteConditionals =
+      this->discreteConditionals();
+  const DecisionTreeFactor prunedDiscreteProbs =
+      discreteConditionals->prune(maxNrLeaves);
 
-  this->updateDiscreteConditionals(decisionTree);
+  this->updateDiscreteConditionals(prunedDiscreteProbs);
 
   /* To Prune, we visitWith every leaf in the GaussianMixture.
    * For each leaf, using the assignment we can check the discrete decision tree
@@ -190,12 +193,12 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
   HybridBayesNet prunedBayesNetFragment;
 
   // Go through all the conditionals in the
-  // Bayes Net and prune them as per decisionTree.
+  // Bayes Net and prune them as per prunedDiscreteProbs.
   for (auto &&conditional : *this) {
     if (auto gm = conditional->asMixture()) {
       // Make a copy of the Gaussian mixture and prune it!
       auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
-      prunedGaussianMixture->prune(decisionTree);  // imperative :-(
+      prunedGaussianMixture->prune(prunedDiscreteProbs);  // imperative :-(
 
       // Type-erase and add to the pruned Bayes Net fragment.
       prunedBayesNetFragment.push_back(prunedGaussianMixture);
diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h
index 2b0042b8d..23fc4d5d3 100644
--- a/gtsam/hybrid/HybridBayesNet.h
+++ b/gtsam/hybrid/HybridBayesNet.h
@@ -224,9 +224,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
   /**
    * @brief Update the discrete conditionals with the pruned versions.
   *
-   * @param prunedDecisionTree
+   * @param prunedDiscreteProbs
    */
-  void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
+  void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
 
 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
   /** Serialization function */
diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp
index b252e613e..ae8fa0378 100644
--- a/gtsam/hybrid/HybridBayesTree.cpp
+++ b/gtsam/hybrid/HybridBayesTree.cpp
@@ -173,19 +173,18 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
 
 /* ************************************************************************* */
 void HybridBayesTree::prune(const size_t maxNrLeaves) {
-  auto decisionTree =
-      this->roots_.at(0)->conditional()->asDiscrete();
+  auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
 
-  DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
-  decisionTree->root_ = prunedDecisionTree.root_;
+  DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
+  discreteProbs->root_ = prunedDiscreteProbs.root_;
 
   /// Helper struct for pruning the hybrid bayes tree.
   struct HybridPrunerData {
     /// The discrete decision tree after pruning.
-    DecisionTreeFactor prunedDecisionTree;
-    HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
+    DecisionTreeFactor prunedDiscreteProbs;
+    HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
                      const HybridBayesTree::sharedNode& parentClique)
-        : prunedDecisionTree(prunedDecisionTree) {}
+        : prunedDiscreteProbs(prunedDiscreteProbs) {}
 
     /**
      * @brief A function used during tree traversal that operates on each node
@@ -205,13 +204,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
       if (conditional->isHybrid()) {
         auto gaussianMixture = conditional->asMixture();
 
-        gaussianMixture->prune(parentData.prunedDecisionTree);
+        gaussianMixture->prune(parentData.prunedDiscreteProbs);
       }
       return parentData;
     }
   };
 
-  HybridPrunerData rootData(prunedDecisionTree, 0);
+  HybridPrunerData rootData(prunedDiscreteProbs, 0);
   {
     treeTraversal::no_op visitorPost;
     // Limits OpenMP threads since we're mixing TBB and OpenMP
diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
index f0d28e9f5..8b9d62822 100644
--- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp
+++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -190,7 +190,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
 /* ************************************************************************ */
 // If any GaussianFactorGraph in the decision tree contains a nullptr, convert
 // that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
-// otherwise create a GFG with a single (null) factor, which doesn't register as null.
+// otherwise create a GFG with a single (null) factor,
+// which doesn't register as null.
 GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
   auto emptyGaussian = [](const GaussianFactorGraph &graph) {
     bool hasNull =
@@ -246,10 +247,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
   // Perform elimination!
   DecisionTree eliminationResults(factorGraphTree, eliminate);
 
-#ifdef HYBRID_TIMING
-  tictoc_print_();
-#endif
-
   // Separate out decision tree into conditionals and remaining factors.
   const auto [conditionals, newFactors] = unzip(eliminationResults);
 
diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h
index f911b135b..421e69aa0 100644
--- a/gtsam/hybrid/HybridGaussianFactorGraph.h
+++ b/gtsam/hybrid/HybridGaussianFactorGraph.h
@@ -112,8 +112,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
  public:
   using Base = HybridFactorGraph;
   using This = HybridGaussianFactorGraph;  ///< this class
-  using BaseEliminateable =
-      EliminateableFactorGraph<This>;  ///< for elimination
+  ///< for elimination
+  using BaseEliminateable = EliminateableFactorGraph<This>;
   using shared_ptr = std::shared_ptr<This>;  ///< shared_ptr to This
 
   using Values = gtsam::Values;  ///< backwards compatibility
@@ -148,7 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
   /// @name Standard Interface
   /// @{
 
-  using Base::error;  // Expose error(const HybridValues&) method..
+  /// Expose error(const HybridValues&) method.
+  using Base::error;
 
   /**
    * @brief Compute error for each discrete assignment,
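
Usage note, not part of the diff: a minimal sketch of how the renamed pruning entry point is typically exercised, assuming a HybridBayesNet obtained from hybrid elimination. The wrapper function name is illustrative; only HybridBayesNet::prune(maxNrLeaves) comes from the change above.

#include <cstddef>

#include <gtsam/hybrid/HybridBayesNet.h>

using namespace gtsam;

// Prune a copy so the caller's Bayes net stays untouched. prune() first
// updates the discrete conditionals in place (hence the non-const call),
// then prunes every GaussianMixture conditional against the same pruned
// discrete probabilities and returns the pruned fragment.
HybridBayesNet pruneToMaxLeaves(HybridBayesNet posterior, size_t maxNrLeaves) {
  return posterior.prune(maxNrLeaves);
}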