diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index e84103a50..f28224d37 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -73,6 +73,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { HybridConditional(boost::make_shared(key, table))); } + using Base::push_back; + /// Get a specific Gaussian mixture by index `i`. GaussianMixture::shared_ptr atMixture(size_t i) const; diff --git a/gtsam/hybrid/HybridFactorGraph.h b/gtsam/hybrid/HybridFactorGraph.h index fc730f0c9..05a17b000 100644 --- a/gtsam/hybrid/HybridFactorGraph.h +++ b/gtsam/hybrid/HybridFactorGraph.h @@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph { push_hybrid(p); } } + + /// Get all the discrete keys in the factor graph. + const KeySet discreteKeys() const { + KeySet discrete_keys; + for (auto& factor : factors_) { + for (const DiscreteKey& k : factor->discreteKeys()) { + discrete_keys.insert(k.first); + } + } + return discrete_keys; + } + + /// Get all the continuous keys in the factor graph. + const KeySet continuousKeys() const { + KeySet keys; + for (auto& factor : factors_) { + for (const Key& key : factor->continuousKeys()) { + keys.insert(key); + } + } + return keys; + } }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 09b592bd6..ddb776ff4 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -404,31 +404,9 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { FactorGraph::add(boost::make_shared(factor)); } -/* ************************************************************************ */ -const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const { - KeySet discrete_keys; - for (auto &factor : factors_) { - for (const DiscreteKey &k : factor->discreteKeys()) { - discrete_keys.insert(k.first); - } - } - return discrete_keys; -} - -/* ************************************************************************ */ -const KeySet HybridGaussianFactorGraph::getContinuousKeys() const { - KeySet keys; - for (auto &factor : factors_) { - for (const Key &key : factor->continuousKeys()) { - keys.insert(key); - } - } - return keys; -} - /* ************************************************************************ */ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { - KeySet discrete_keys = getDiscreteKeys(); + KeySet discrete_keys = discreteKeys(); for (auto &factor : factors_) { for (const DiscreteKey &k : factor->discreteKeys()) { discrete_keys.insert(k.first); diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index ad5cde09b..bd24cdeaa 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -161,12 +161,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph } } - /// Get all the discrete keys in the factor graph. - const KeySet getDiscreteKeys() const; - - /// Get all the continuous keys in the factor graph. - const KeySet getContinuousKeys() const; - /** * @brief Return a Colamd constrained ordering where the discrete keys are * eliminated after the continuous keys. diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp index 6946775b9..57e50104d 100644 --- a/gtsam/hybrid/HybridGaussianISAM.cpp +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -62,23 +62,24 @@ void HybridGaussianISAM::updateInternal( for (const sharedClique& orphan : *orphans) factors += boost::make_shared >(orphan); - KeySet allDiscrete; - for (auto& factor : factors) { - for (auto& k : factor->discreteKeys()) { - allDiscrete.insert(k.first); - } - } + // Get all the discrete keys from the factors + KeySet allDiscrete = factors.discreteKeys(); + + // Create KeyVector with continuous keys followed by discrete keys. KeyVector newKeysDiscreteLast; + // Insert continuous keys first. for (auto& k : newFactorKeys) { if (!allDiscrete.exists(k)) { newKeysDiscreteLast.push_back(k); } } + // Insert discrete keys at the end std::copy(allDiscrete.begin(), allDiscrete.end(), std::back_inserter(newKeysDiscreteLast)); // Get an ordering where the new keys are eliminated last const VariableIndex index(factors); + Ordering elimination_ordering; if (ordering) { elimination_ordering = *ordering; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 0c15ee83d..cc2ab1759 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -52,6 +52,21 @@ TEST(HybridBayesNet, Creation) { EXPECT(df.equals(expected)); } +/* ****************************************************************************/ +// Test adding a bayes net to another one. +TEST(HybridBayesNet, Add) { + HybridBayesNet bayesNet; + + bayesNet.add(Asia, "99/1"); + + DiscreteConditional expected(Asia, "99/1"); + + HybridBayesNet other; + other.push_back(bayesNet); + EXPECT(bayesNet.equals(other)); +} + + /* ****************************************************************************/ // Test choosing an assignment of conditionals TEST(HybridBayesNet, Choose) { diff --git a/gtsam/hybrid/tests/testHybridIncremental.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp similarity index 100% rename from gtsam/hybrid/tests/testHybridIncremental.cpp rename to gtsam/hybrid/tests/testHybridGaussianISAM.cpp