Merge pull request #1294 from borglab/hybrid/check-elimination

release/4.3a0
Varun Agrawal 2022-10-03 17:25:21 -04:00 committed by GitHub
commit 3407f9798b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 47 additions and 35 deletions

View File

@ -73,6 +73,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
HybridConditional(boost::make_shared<DiscreteConditional>(key, table))); HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
} }
using Base::push_back;
/// Get a specific Gaussian mixture by index `i`. /// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atMixture(size_t i) const; GaussianMixture::shared_ptr atMixture(size_t i) const;

View File

@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
push_hybrid(p); push_hybrid(p);
} }
} }
/// Get all the discrete keys in the factor graph.
const KeySet discreteKeys() const {
KeySet discrete_keys;
for (auto& factor : factors_) {
for (const DiscreteKey& k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
return discrete_keys;
}
/// Get all the continuous keys in the factor graph.
const KeySet continuousKeys() const {
KeySet keys;
for (auto& factor : factors_) {
for (const Key& key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -404,31 +404,9 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor)); FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
} }
/* ************************************************************************ */
const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const {
KeySet discrete_keys;
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
return discrete_keys;
}
/* ************************************************************************ */
const KeySet HybridGaussianFactorGraph::getContinuousKeys() const {
KeySet keys;
for (auto &factor : factors_) {
for (const Key &key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}
/* ************************************************************************ */ /* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const { const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = getDiscreteKeys(); KeySet discrete_keys = discreteKeys();
for (auto &factor : factors_) { for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) { for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first); discrete_keys.insert(k.first);

View File

@ -161,12 +161,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
} }
} }
/// Get all the discrete keys in the factor graph.
const KeySet getDiscreteKeys() const;
/// Get all the continuous keys in the factor graph.
const KeySet getContinuousKeys() const;
/** /**
* @brief Return a Colamd constrained ordering where the discrete keys are * @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys. * eliminated after the continuous keys.

View File

@ -62,23 +62,24 @@ void HybridGaussianISAM::updateInternal(
for (const sharedClique& orphan : *orphans) for (const sharedClique& orphan : *orphans)
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan); factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan);
KeySet allDiscrete; // Get all the discrete keys from the factors
for (auto& factor : factors) { KeySet allDiscrete = factors.discreteKeys();
for (auto& k : factor->discreteKeys()) {
allDiscrete.insert(k.first); // Create KeyVector with continuous keys followed by discrete keys.
}
}
KeyVector newKeysDiscreteLast; KeyVector newKeysDiscreteLast;
// Insert continuous keys first.
for (auto& k : newFactorKeys) { for (auto& k : newFactorKeys) {
if (!allDiscrete.exists(k)) { if (!allDiscrete.exists(k)) {
newKeysDiscreteLast.push_back(k); newKeysDiscreteLast.push_back(k);
} }
} }
// Insert discrete keys at the end
std::copy(allDiscrete.begin(), allDiscrete.end(), std::copy(allDiscrete.begin(), allDiscrete.end(),
std::back_inserter(newKeysDiscreteLast)); std::back_inserter(newKeysDiscreteLast));
// Get an ordering where the new keys are eliminated last // Get an ordering where the new keys are eliminated last
const VariableIndex index(factors); const VariableIndex index(factors);
Ordering elimination_ordering; Ordering elimination_ordering;
if (ordering) { if (ordering) {
elimination_ordering = *ordering; elimination_ordering = *ordering;

View File

@ -52,6 +52,21 @@ TEST(HybridBayesNet, Creation) {
EXPECT(df.equals(expected)); EXPECT(df.equals(expected));
} }
/* ****************************************************************************/
// Test adding a bayes net to another one.
TEST(HybridBayesNet, Add) {
HybridBayesNet bayesNet;
bayesNet.add(Asia, "99/1");
DiscreteConditional expected(Asia, "99/1");
HybridBayesNet other;
other.push_back(bayesNet);
EXPECT(bayesNet.equals(other));
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test choosing an assignment of conditionals // Test choosing an assignment of conditionals
TEST(HybridBayesNet, Choose) { TEST(HybridBayesNet, Choose) {