From 1a3b343537c8d0858ca2cfeee9859a4633563973 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 8 Nov 2022 14:00:44 -0500 Subject: [PATCH] minor clean up and get tests to pass --- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 80 ++++++++++++++----- gtsam/hybrid/HybridGaussianFactorGraph.h | 21 +++-- .../tests/testHybridNonlinearFactorGraph.cpp | 12 ++- 3 files changed, 87 insertions(+), 26 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 6024a59bc..62d681665 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -512,9 +512,17 @@ HybridGaussianFactorGraph::continuousDelta( /* ************************************************************************ */ AlgebraicDecisionTree HybridGaussianFactorGraph::continuousProbPrimes( - const DiscreteKeys &discrete_keys, - const boost::shared_ptr &continuousBayesNet, - const std::vector &assignments) const { + const DiscreteKeys &orig_discrete_keys, + const boost::shared_ptr &continuousBayesNet) const { + // Generate all possible assignments. + const std::vector assignments = + DiscreteValues::CartesianProduct(orig_discrete_keys); + + // Save a copy of the original discrete key ordering + DiscreteKeys discrete_keys(orig_discrete_keys); + // Reverse discrete keys order for correct tree construction + std::reverse(discrete_keys.begin(), discrete_keys.end()); + // Create a decision tree of all the different VectorValues DecisionTree delta_tree = this->continuousDelta(discrete_keys, continuousBayesNet, assignments); @@ -532,7 +540,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::continuousProbPrimes( } double error = 0.0; - + // Compute the error given the delta and the assignment. for (size_t idx = 0; idx < size(); idx++) { auto factor = factors_.at(idx); @@ -563,15 +571,21 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::continuousProbPrimes( /* ************************************************************************ */ boost::shared_ptr -HybridGaussianFactorGraph::eliminateHybridSequential(const boost::optional continuous, const boost::optional discrete) const { - Ordering continuous_ordering(this->continuousKeys()), - discrete_ordering(this->discreteKeys()); +HybridGaussianFactorGraph::eliminateHybridSequential( + const boost::optional continuous, + const boost::optional discrete, const Eliminate &function, + OptionalVariableIndex variableIndex) const { + Ordering continuous_ordering = + continuous ? *continuous : Ordering(this->continuousKeys()); + Ordering discrete_ordering = + discrete ? *discrete : Ordering(this->discreteKeys()); // Eliminate continuous HybridBayesNet::shared_ptr bayesNet; HybridGaussianFactorGraph::shared_ptr discreteGraph; std::tie(bayesNet, discreteGraph) = - BaseEliminateable::eliminatePartialSequential(continuous_ordering); + BaseEliminateable::eliminatePartialSequential(continuous_ordering, + function, variableIndex); // Get the last continuous conditional which will have all the discrete keys auto last_conditional = bayesNet->at(bayesNet->size() - 1); @@ -582,26 +596,54 @@ HybridGaussianFactorGraph::eliminateHybridSequential(const boost::optional assignments = - DiscreteValues::CartesianProduct(discrete_keys); - - // Save a copy of the original discrete key ordering - DiscreteKeys orig_discrete_keys(discrete_keys); - // Reverse discrete keys order for correct tree construction - std::reverse(discrete_keys.begin(), discrete_keys.end()); - AlgebraicDecisionTree probPrimeTree = - continuousProbPrimes(discrete_keys, bayesNet, assignments); + this->continuousProbPrimes(discrete_keys, bayesNet); - discreteGraph->add(DecisionTreeFactor(orig_discrete_keys, probPrimeTree)); + discreteGraph->add(DecisionTreeFactor(discrete_keys, probPrimeTree)); // Perform discrete elimination HybridBayesNet::shared_ptr discreteBayesNet = - discreteGraph->eliminateSequential(discrete_ordering); + discreteGraph->BaseEliminateable::eliminateSequential( + discrete_ordering, function, variableIndex); bayesNet->add(*discreteBayesNet); return bayesNet; } +/* ************************************************************************ */ +boost::shared_ptr +HybridGaussianFactorGraph::eliminateSequential( + OptionalOrderingType orderingType, const Eliminate &function, + OptionalVariableIndex variableIndex) const { + return BaseEliminateable::eliminateSequential(orderingType, function, + variableIndex); +} + +/* ************************************************************************ */ +boost::shared_ptr +HybridGaussianFactorGraph::eliminateSequential( + const Ordering &ordering, const Eliminate &function, + OptionalVariableIndex variableIndex) const { + KeySet all_continuous_keys = this->continuousKeys(); + KeySet all_discrete_keys = this->discreteKeys(); + Ordering continuous_ordering, discrete_ordering; + + // Segregate the continuous and the discrete keys + for (auto &&key : ordering) { + if (std::find(all_continuous_keys.begin(), all_continuous_keys.end(), + key) != all_continuous_keys.end()) { + continuous_ordering.push_back(key); + } else if (std::find(all_discrete_keys.begin(), all_discrete_keys.end(), + key) != all_discrete_keys.end()) { + discrete_ordering.push_back(key); + } else { + throw std::runtime_error("Key in ordering not present in factors."); + } + } + + return this->eliminateHybridSequential(continuous_ordering, + discrete_ordering); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 31a707579..1198cc8bc 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -214,14 +214,11 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @param discrete_keys The discrete keys which form all the modes. * @param continuousBayesNet The Bayes Net representing the continuous * eliminated variables. - * @param assignments List of all discrete assignments to create the final - * decision tree. * @return AlgebraicDecisionTree */ AlgebraicDecisionTree continuousProbPrimes( const DiscreteKeys& discrete_keys, - const boost::shared_ptr& continuousBayesNet, - const std::vector& assignments) const; + const boost::shared_ptr& continuousBayesNet) const; /** * @brief Custom elimination function which computes the correct @@ -232,8 +229,20 @@ class GTSAM_EXPORT HybridGaussianFactorGraph * @return boost::shared_ptr */ boost::shared_ptr eliminateHybridSequential( - const boost::optional continuous, - const boost::optional discrete) const; + const boost::optional continuous = boost::none, + const boost::optional discrete = boost::none, + const Eliminate& function = EliminationTraitsType::DefaultEliminate, + OptionalVariableIndex variableIndex = boost::none) const; + + boost::shared_ptr eliminateSequential( + OptionalOrderingType orderingType = boost::none, + const Eliminate& function = EliminationTraitsType::DefaultEliminate, + OptionalVariableIndex variableIndex = boost::none) const; + + boost::shared_ptr eliminateSequential( + const Ordering& ordering, + const Eliminate& function = EliminationTraitsType::DefaultEliminate, + OptionalVariableIndex variableIndex = boost::none) const; /** * @brief Return a Colamd constrained ordering where the discrete keys are diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index f6889f132..f8c61baf7 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -372,7 +372,8 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) { dynamic_pointer_cast(hybridDiscreteFactor->inner()); CHECK(discreteFactor); EXPECT_LONGS_EQUAL(1, discreteFactor->discreteKeys().size()); - EXPECT(discreteFactor->root_->isLeaf() == false); + // All leaves should be probability 1 since this is not P*(X|M,Z) + EXPECT(discreteFactor->root_->isLeaf()); // TODO(Varun) Test emplace_discrete } @@ -439,6 +440,15 @@ TEST(HybridFactorGraph, Full_Elimination) { auto df = dynamic_pointer_cast(factor); discrete_fg.push_back(df->inner()); } + + // Get the probabilit P*(X | M, Z) + DiscreteKeys discrete_keys = + remainingFactorGraph_partial->at(2)->discreteKeys(); + AlgebraicDecisionTree probPrimeTree = + linearizedFactorGraph.continuousProbPrimes(discrete_keys, + hybridBayesNet_partial); + discrete_fg.add(DecisionTreeFactor(discrete_keys, probPrimeTree)); + ordering.clear(); for (size_t k = 0; k < self.K - 1; k++) ordering += M(k); discreteBayesNet =