Merge pull request #1294 from borglab/hybrid/check-elimination
commit
3407f9798b
|
|
@ -73,6 +73,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
|
HybridConditional(boost::make_shared<DiscreteConditional>(key, table)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
using Base::push_back;
|
||||||
|
|
||||||
/// Get a specific Gaussian mixture by index `i`.
|
/// Get a specific Gaussian mixture by index `i`.
|
||||||
GaussianMixture::shared_ptr atMixture(size_t i) const;
|
GaussianMixture::shared_ptr atMixture(size_t i) const;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -135,6 +135,28 @@ class HybridFactorGraph : public FactorGraph<HybridFactor> {
|
||||||
push_hybrid(p);
|
push_hybrid(p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get all the discrete keys in the factor graph.
|
||||||
|
const KeySet discreteKeys() const {
|
||||||
|
KeySet discrete_keys;
|
||||||
|
for (auto& factor : factors_) {
|
||||||
|
for (const DiscreteKey& k : factor->discreteKeys()) {
|
||||||
|
discrete_keys.insert(k.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return discrete_keys;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all the continuous keys in the factor graph.
|
||||||
|
const KeySet continuousKeys() const {
|
||||||
|
KeySet keys;
|
||||||
|
for (auto& factor : factors_) {
|
||||||
|
for (const Key& key : factor->continuousKeys()) {
|
||||||
|
keys.insert(key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return keys;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -404,31 +404,9 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
|
||||||
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const {
|
|
||||||
KeySet discrete_keys;
|
|
||||||
for (auto &factor : factors_) {
|
|
||||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
|
||||||
discrete_keys.insert(k.first);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return discrete_keys;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
|
||||||
const KeySet HybridGaussianFactorGraph::getContinuousKeys() const {
|
|
||||||
KeySet keys;
|
|
||||||
for (auto &factor : factors_) {
|
|
||||||
for (const Key &key : factor->continuousKeys()) {
|
|
||||||
keys.insert(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return keys;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
||||||
KeySet discrete_keys = getDiscreteKeys();
|
KeySet discrete_keys = discreteKeys();
|
||||||
for (auto &factor : factors_) {
|
for (auto &factor : factors_) {
|
||||||
for (const DiscreteKey &k : factor->discreteKeys()) {
|
for (const DiscreteKey &k : factor->discreteKeys()) {
|
||||||
discrete_keys.insert(k.first);
|
discrete_keys.insert(k.first);
|
||||||
|
|
|
||||||
|
|
@ -161,12 +161,6 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get all the discrete keys in the factor graph.
|
|
||||||
const KeySet getDiscreteKeys() const;
|
|
||||||
|
|
||||||
/// Get all the continuous keys in the factor graph.
|
|
||||||
const KeySet getContinuousKeys() const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return a Colamd constrained ordering where the discrete keys are
|
* @brief Return a Colamd constrained ordering where the discrete keys are
|
||||||
* eliminated after the continuous keys.
|
* eliminated after the continuous keys.
|
||||||
|
|
|
||||||
|
|
@ -62,23 +62,24 @@ void HybridGaussianISAM::updateInternal(
|
||||||
for (const sharedClique& orphan : *orphans)
|
for (const sharedClique& orphan : *orphans)
|
||||||
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan);
|
factors += boost::make_shared<BayesTreeOrphanWrapper<Node> >(orphan);
|
||||||
|
|
||||||
KeySet allDiscrete;
|
// Get all the discrete keys from the factors
|
||||||
for (auto& factor : factors) {
|
KeySet allDiscrete = factors.discreteKeys();
|
||||||
for (auto& k : factor->discreteKeys()) {
|
|
||||||
allDiscrete.insert(k.first);
|
// Create KeyVector with continuous keys followed by discrete keys.
|
||||||
}
|
|
||||||
}
|
|
||||||
KeyVector newKeysDiscreteLast;
|
KeyVector newKeysDiscreteLast;
|
||||||
|
// Insert continuous keys first.
|
||||||
for (auto& k : newFactorKeys) {
|
for (auto& k : newFactorKeys) {
|
||||||
if (!allDiscrete.exists(k)) {
|
if (!allDiscrete.exists(k)) {
|
||||||
newKeysDiscreteLast.push_back(k);
|
newKeysDiscreteLast.push_back(k);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Insert discrete keys at the end
|
||||||
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
std::copy(allDiscrete.begin(), allDiscrete.end(),
|
||||||
std::back_inserter(newKeysDiscreteLast));
|
std::back_inserter(newKeysDiscreteLast));
|
||||||
|
|
||||||
// Get an ordering where the new keys are eliminated last
|
// Get an ordering where the new keys are eliminated last
|
||||||
const VariableIndex index(factors);
|
const VariableIndex index(factors);
|
||||||
|
|
||||||
Ordering elimination_ordering;
|
Ordering elimination_ordering;
|
||||||
if (ordering) {
|
if (ordering) {
|
||||||
elimination_ordering = *ordering;
|
elimination_ordering = *ordering;
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,21 @@ TEST(HybridBayesNet, Creation) {
|
||||||
EXPECT(df.equals(expected));
|
EXPECT(df.equals(expected));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test adding a bayes net to another one.
|
||||||
|
TEST(HybridBayesNet, Add) {
|
||||||
|
HybridBayesNet bayesNet;
|
||||||
|
|
||||||
|
bayesNet.add(Asia, "99/1");
|
||||||
|
|
||||||
|
DiscreteConditional expected(Asia, "99/1");
|
||||||
|
|
||||||
|
HybridBayesNet other;
|
||||||
|
other.push_back(bayesNet);
|
||||||
|
EXPECT(bayesNet.equals(other));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Test choosing an assignment of conditionals
|
// Test choosing an assignment of conditionals
|
||||||
TEST(HybridBayesNet, Choose) {
|
TEST(HybridBayesNet, Choose) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue