From 93528c3d4f8ae5c87a063a3232fb2894e7899504 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 16 Sep 2022 12:19:24 -0400 Subject: [PATCH 1/6] Only eliminate variables that are in newFactors --- gtsam/hybrid/HybridGaussianISAM.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 6946775b9..c7811992e 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -62,23 +62,29 @@ void HybridGaussianISAM::updateInternal( for (const sharedClique& orphan : *orphans) factors += boost::make_shared >(orphan); + // Get all the discrete keys from the new factors KeySet allDiscrete; - for (auto& factor : factors) { + for (auto& factor : newFactors) { for (auto& k : factor->discreteKeys()) { allDiscrete.insert(k.first); } } + + // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; + // Insert continuous keys first. for (auto& k : newFactorKeys) { if (!allDiscrete.exists(k)) { newKeysDiscreteLast.push_back(k); } } + // Insert discrete keys at the end std::copy(allDiscrete.begin(), allDiscrete.end(), std::back_inserter(newKeysDiscreteLast)); // Get an ordering where the new keys are eliminated last - const VariableIndex index(factors); + const VariableIndex index(newFactors); + Ordering elimination_ordering; if (ordering) { elimination_ordering = *ordering; @@ -89,10 +95,14 @@ void HybridGaussianISAM::updateInternal( true); } + GTSAM_PRINT(elimination_ordering); + std::cout << "\n\n\n\neliminateMultifrontal" << std::endl; + GTSAM_PRINT(factors); // eliminate all factors (top, added, orphans) into a new Bayes tree HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal(elimination_ordering, function, index); + std::cout << "optionally prune" << std::endl; if (maxNrLeaves) { bayesTree->prune(*maxNrLeaves); } From aebcde99e2ae949a1bc20dac3d1fb5853991a576 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 16 Sep 2022 18:13:59 -0400 Subject: [PATCH 2/6] add push_back to HybridBayesNet --- gtsam/hybrid/HybridBayesNet.h | 2 ++ gtsam/hybrid/tests/testHybridBayesNet.cpp | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index e84103a50..f28224d37 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -73,6 +73,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { HybridConditional(boost::make_shared(key, table))); } + using Base::push_back; + /// Get a specific Gaussian mixture by index `i`. GaussianMixture::shared_ptr atMixture(size_t i) const; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 0c15ee83d..cc2ab1759 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -52,6 +52,21 @@ TEST(HybridBayesNet, Creation) { EXPECT(df.equals(expected)); } +/* ****************************************************************************/ +// Test adding a bayes net to another one. +TEST(HybridBayesNet, Add) { + HybridBayesNet bayesNet; + + bayesNet.add(Asia, "99/1"); + + DiscreteConditional expected(Asia, "99/1"); + + HybridBayesNet other; + other.push_back(bayesNet); + EXPECT(bayesNet.equals(other)); +} + + /* ****************************************************************************/ // Test choosing an assignment of conditionals TEST(HybridBayesNet, Choose) { From 9ef5c184ec379acb04c1bbd8619bda0e9ca25b5f Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 17 Sep 2022 08:04:55 -0400 Subject: [PATCH 3/6] move renamed allDiscreteKeys and allContinuousKeys to HybridFactorGraph --- gtsam/hybrid/HybridFactorGraph.h | 22 ++++++++++++++++++++ gtsam/hybrid/HybridGaussianFactorGraph.cpp | 24 +--------------------- gtsam/hybrid/HybridGaussianFactorGraph.h | 6 ------ 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index fc730f0c9..ea071a020 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph { push_hybrid(p); } } + + /// Get all the discrete keys in the factor graph. + const KeySet allDiscreteKeys() const { + KeySet discrete_keys; + for (auto& factor : factors_) { + for (const DiscreteKey& k : factor->discreteKeys()) { + discrete_keys.insert(k.first); + } + } + return discrete_keys; + } + + /// Get all the continuous keys in the factor graph. + const KeySet allContinuousKeys() const { + KeySet keys; + for (auto& factor : factors_) { + for (const Key& key : factor->continuousKeys()) { + keys.insert(key); + } + } + return keys; + } }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 09b592bd6..c08e774f2 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -404,31 +404,9 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { FactorGraph::add(boost::make_shared(factor)); } -/* ************************************************************************ */ -const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const { - KeySet discrete_keys; - for (auto &factor : factors_) { - for (const DiscreteKey &k : factor->discreteKeys()) { - discrete_keys.insert(k.first); - } - } - return discrete_keys; -} - -/* ************************************************************************ */ -const KeySet HybridGaussianFactorGraph::getContinuousKeys() const { - KeySet keys; - for (auto &factor : factors_) { - for (const Key &key : factor->continuousKeys()) { - keys.insert(key); - } - } - return keys; -} - /* ************************************************************************ */ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { - KeySet discrete_keys = getDiscreteKeys(); + KeySet discrete_keys = allDiscreteKeys(); for (auto &factor : factors_) { for (const DiscreteKey &k : factor->discreteKeys()) { discrete_keys.insert(k.first); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index ad5cde09b..bd24cdeaa 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -161,12 +161,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } - /// Get all the discrete keys in the factor graph. - const KeySet getDiscreteKeys() const; - - /// Get all the continuous keys in the factor graph. - const KeySet getContinuousKeys() const; - /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys. From 12db5dd947481bc1990dddae5f527ca4724be788 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 17 Sep 2022 08:05:07 -0400 Subject: [PATCH 4/6] undo changes --- gtsam/hybrid/HybridGaussianISAM.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index c7811992e..1a95c0c93 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -62,9 +62,9 @@ void HybridGaussianISAM::updateInternal( for (const sharedClique& orphan : *orphans) factors += boost::make_shared >(orphan); - // Get all the discrete keys from the new factors + // Get all the discrete keys from the factors KeySet allDiscrete; - for (auto& factor : newFactors) { + for (auto& factor : factors) { for (auto& k : factor->discreteKeys()) { allDiscrete.insert(k.first); } @@ -83,7 +83,7 @@ void HybridGaussianISAM::updateInternal( std::back_inserter(newKeysDiscreteLast)); // Get an ordering where the new keys are eliminated last - const VariableIndex index(newFactors); + const VariableIndex index(factors); Ordering elimination_ordering; if (ordering) { @@ -95,14 +95,10 @@ void HybridGaussianISAM::updateInternal( true); } - GTSAM_PRINT(elimination_ordering); - std::cout << "\n\n\n\neliminateMultifrontal" << std::endl; - GTSAM_PRINT(factors); // eliminate all factors (top, added, orphans) into a new Bayes tree HybridBayesTree::shared_ptr bayesTree = factors.eliminateMultifrontal(elimination_ordering, function, index); - std::cout << "optionally prune" << std::endl; if (maxNrLeaves) { bayesTree->prune(*maxNrLeaves); } From 2f8a0f82e0337e80dca1b7d9c608b1f7cbce792c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Mon, 19 Sep 2022 18:23:18 -0400 Subject: [PATCH 5/6] rename testHybridIncremental to testHybridGaussianISAM --- .../{testHybridIncremental.cpp => testHybridGaussianISAM.cpp} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename gtsam/hybrid/tests/{testHybridIncremental.cpp => testHybridGaussianISAM.cpp} (100%) diff --git a/gtsam/hybrid/tests/testHybridIncremental.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp similarity index 100% rename from gtsam/hybrid/tests/testHybridIncremental.cpp rename to gtsam/hybrid/tests/testHybridGaussianISAM.cpp From c2ca426acce0bf60eb14f66267b08b029abc28fb Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Tue, 20 Sep 2022 05:16:26 -0400 Subject: [PATCH 6/6] rename allDiscreteKeys and allContinuousKeys to discreteKeys and continuousKeys respectively --- gtsam/hybrid/HybridFactorGraph.h | 4 ++-- gtsam/hybrid/HybridGaussianFactorGraph.cpp | 2 +- gtsam/hybrid/HybridGaussianISAM.cpp | 7 +------ 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index ea071a020..05a17b000 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -137,7 +137,7 @@ class HybridFactorGraph : public FactorGraph { } /// Get all the discrete keys in the factor graph. - const KeySet allDiscreteKeys() const { + const KeySet discreteKeys() const { KeySet discrete_keys; for (auto& factor : factors_) { for (const DiscreteKey& k : factor->discreteKeys()) { @@ -148,7 +148,7 @@ class HybridFactorGraph : public FactorGraph { } /// Get all the continuous keys in the factor graph. - const KeySet allContinuousKeys() const { + const KeySet continuousKeys() const { KeySet keys; for (auto& factor : factors_) { for (const Key& key : factor->continuousKeys()) { diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index c08e774f2..ddb776ff4 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -406,7 +406,7 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { /* ************************************************************************ */ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { - KeySet discrete_keys = allDiscreteKeys(); + KeySet discrete_keys = discreteKeys(); for (auto &factor : factors_) { for (const DiscreteKey &k : factor->discreteKeys()) { discrete_keys.insert(k.first); diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 1a95c0c93..57e50104d 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -63,12 +63,7 @@ void HybridGaussianISAM::updateInternal( factors += boost::make_shared >(orphan); // Get all the discrete keys from the factors - KeySet allDiscrete; - for (auto& factor : factors) { - for (auto& k : factor->discreteKeys()) { - allDiscrete.insert(k.first); - } - } + KeySet allDiscrete = factors.discreteKeys(); // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast;