Merge pull request #1349 from borglab/hybrid/two_ways
commit
cfcbddaa61
|
|
@ -99,6 +99,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
// TODO(dellaert): what is this non-const method used for? Abolish it?
|
||||||
void HybridBayesNet::updateDiscreteConditionals(
|
void HybridBayesNet::updateDiscreteConditionals(
|
||||||
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
|
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
|
||||||
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
|
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
|
||||||
|
|
@ -150,9 +151,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||||
|
|
||||||
// Go through all the conditionals in the
|
// Go through all the conditionals in the
|
||||||
// Bayes Net and prune them as per decisionTree.
|
// Bayes Net and prune them as per decisionTree.
|
||||||
for (size_t i = 0; i < this->size(); i++) {
|
for (auto &&conditional : *this) {
|
||||||
HybridConditional::shared_ptr conditional = this->at(i);
|
|
||||||
|
|
||||||
if (conditional->isHybrid()) {
|
if (conditional->isHybrid()) {
|
||||||
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
|
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -344,18 +344,20 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
// However this is also the case with iSAM2, so no pressure :)
|
// However this is also the case with iSAM2, so no pressure :)
|
||||||
|
|
||||||
// PREPROCESS: Identify the nature of the current elimination
|
// PREPROCESS: Identify the nature of the current elimination
|
||||||
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
|
|
||||||
std::set<DiscreteKey> discreteSeparatorSet;
|
|
||||||
std::set<DiscreteKey> discreteFrontals;
|
|
||||||
|
|
||||||
|
// First, identify the separator keys, i.e. all keys that are not frontal.
|
||||||
KeySet separatorKeys;
|
KeySet separatorKeys;
|
||||||
KeySet allContinuousKeys;
|
|
||||||
KeySet continuousFrontals;
|
|
||||||
KeySet continuousSeparator;
|
|
||||||
|
|
||||||
// This initializes separatorKeys and mapFromKeyToDiscreteKey
|
|
||||||
for (auto &&factor : factors) {
|
for (auto &&factor : factors) {
|
||||||
separatorKeys.insert(factor->begin(), factor->end());
|
separatorKeys.insert(factor->begin(), factor->end());
|
||||||
|
}
|
||||||
|
// remove frontals from separator
|
||||||
|
for (auto &k : frontalKeys) {
|
||||||
|
separatorKeys.erase(k);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a map from keys to DiscreteKeys
|
||||||
|
std::unordered_map<Key, DiscreteKey> mapFromKeyToDiscreteKey;
|
||||||
|
for (auto &&factor : factors) {
|
||||||
if (!factor->isContinuous()) {
|
if (!factor->isContinuous()) {
|
||||||
for (auto &k : factor->discreteKeys()) {
|
for (auto &k : factor->discreteKeys()) {
|
||||||
mapFromKeyToDiscreteKey[k.first] = k;
|
mapFromKeyToDiscreteKey[k.first] = k;
|
||||||
|
|
@ -363,49 +365,50 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove frontals from separator
|
// Fill in discrete frontals and continuous frontals.
|
||||||
for (auto &k : frontalKeys) {
|
std::set<DiscreteKey> discreteFrontals;
|
||||||
separatorKeys.erase(k);
|
KeySet continuousFrontals;
|
||||||
}
|
|
||||||
|
|
||||||
// Fill in discrete frontals and continuous frontals for the end result
|
|
||||||
for (auto &k : frontalKeys) {
|
for (auto &k : frontalKeys) {
|
||||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
||||||
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
|
discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k));
|
||||||
} else {
|
} else {
|
||||||
continuousFrontals.insert(k);
|
continuousFrontals.insert(k);
|
||||||
allContinuousKeys.insert(k);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fill in discrete frontals and continuous frontals for the end result
|
// Fill in discrete discrete separator keys and continuous separator keys.
|
||||||
|
std::set<DiscreteKey> discreteSeparatorSet;
|
||||||
|
KeySet continuousSeparator;
|
||||||
for (auto &k : separatorKeys) {
|
for (auto &k : separatorKeys) {
|
||||||
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
|
||||||
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
|
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
|
||||||
} else {
|
} else {
|
||||||
continuousSeparator.insert(k);
|
continuousSeparator.insert(k);
|
||||||
allContinuousKeys.insert(k);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if we have any continuous keys:
|
||||||
|
const bool discrete_only =
|
||||||
|
continuousFrontals.empty() && continuousSeparator.empty();
|
||||||
|
|
||||||
// NOTE: We should really defer the product here because of pruning
|
// NOTE: We should really defer the product here because of pruning
|
||||||
|
|
||||||
// Case 1: we are only dealing with continuous
|
if (discrete_only) {
|
||||||
if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) {
|
// Case 1: we are only dealing with discrete
|
||||||
return continuousElimination(factors, frontalKeys);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Case 2: we are only dealing with discrete
|
|
||||||
if (allContinuousKeys.empty()) {
|
|
||||||
return discreteElimination(factors, frontalKeys);
|
return discreteElimination(factors, frontalKeys);
|
||||||
}
|
} else {
|
||||||
|
// Case 2: we are only dealing with continuous
|
||||||
|
if (mapFromKeyToDiscreteKey.empty()) {
|
||||||
|
return continuousElimination(factors, frontalKeys);
|
||||||
|
} else {
|
||||||
|
// Case 3: We are now in the hybrid land!
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
tictoc_reset_();
|
tictoc_reset_();
|
||||||
#endif
|
#endif
|
||||||
// Case 3: We are now in the hybrid land!
|
|
||||||
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
return hybridElimination(factors, frontalKeys, continuousSeparator,
|
||||||
discreteSeparatorSet);
|
discreteSeparatorSet);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -432,8 +432,65 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
|
||||||
EXPECT(assert_equal(discrete_seq, hybrid_values.discrete()));
|
EXPECT(assert_equal(discrete_seq, hybrid_values.discrete()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/*********************************************************************************
|
||||||
/**
|
// Create a hybrid nonlinear factor graph f(x0, x1, m0; z0, z1)
|
||||||
|
********************************************************************************/
|
||||||
|
static HybridNonlinearFactorGraph createHybridNonlinearFactorGraph() {
|
||||||
|
HybridNonlinearFactorGraph nfg;
|
||||||
|
|
||||||
|
constexpr double sigma = 0.5; // measurement noise
|
||||||
|
const auto noise_model = noiseModel::Isotropic::Sigma(1, sigma);
|
||||||
|
|
||||||
|
// Add "measurement" factors:
|
||||||
|
nfg.emplace_nonlinear<PriorFactor<double>>(X(0), 0.0, noise_model);
|
||||||
|
nfg.emplace_nonlinear<PriorFactor<double>>(X(1), 1.0, noise_model);
|
||||||
|
|
||||||
|
// Add mixture factor:
|
||||||
|
DiscreteKey m(M(0), 2);
|
||||||
|
const auto zero_motion =
|
||||||
|
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
|
||||||
|
const auto one_motion =
|
||||||
|
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
|
||||||
|
nfg.emplace_hybrid<MixtureFactor>(
|
||||||
|
KeyVector{X(0), X(1)}, DiscreteKeys{m},
|
||||||
|
std::vector<NonlinearFactor::shared_ptr>{zero_motion, one_motion});
|
||||||
|
|
||||||
|
return nfg;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************
|
||||||
|
// Create a hybrid nonlinear factor graph f(x0, x1, m0; z0, z1)
|
||||||
|
********************************************************************************/
|
||||||
|
static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
|
||||||
|
HybridNonlinearFactorGraph nfg = createHybridNonlinearFactorGraph();
|
||||||
|
|
||||||
|
Values initial;
|
||||||
|
double z0 = 0.0, z1 = 1.0;
|
||||||
|
initial.insert<double>(X(0), z0);
|
||||||
|
initial.insert<double>(X(1), z1);
|
||||||
|
return nfg.linearize(initial);
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************
|
||||||
|
* Do hybrid elimination and do regression test on discrete conditional.
|
||||||
|
********************************************************************************/
|
||||||
|
TEST(HybridEstimation, eliminateSequentialRegression) {
|
||||||
|
// 1. Create the factor graph from the nonlinear factor graph.
|
||||||
|
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
|
||||||
|
|
||||||
|
// 2. Eliminate into BN
|
||||||
|
const Ordering ordering = fg->getHybridOrdering();
|
||||||
|
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
|
||||||
|
// GTSAM_PRINT(*bn);
|
||||||
|
|
||||||
|
// TODO(dellaert): dc should be discrete conditional on m0, but it is an unnormalized factor?
|
||||||
|
// DiscreteKey m(M(0), 2);
|
||||||
|
// DiscreteConditional expected(m % "0.51341712/1");
|
||||||
|
// auto dc = bn->back()->asDiscreteConditional();
|
||||||
|
// EXPECT(assert_equal(expected, *dc, 1e-9));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*********************************************************************************
|
||||||
* Test for correctness via sampling.
|
* Test for correctness via sampling.
|
||||||
*
|
*
|
||||||
* Compute the conditional P(x0, m0, x1| z0, z1)
|
* Compute the conditional P(x0, m0, x1| z0, z1)
|
||||||
|
|
@ -442,32 +499,10 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
|
||||||
* 2. Eliminate the factor graph into a Bayes Net `BN`.
|
* 2. Eliminate the factor graph into a Bayes Net `BN`.
|
||||||
* 3. Sample from the Bayes Net.
|
* 3. Sample from the Bayes Net.
|
||||||
* 4. Check that the ratio `BN(x)/FG(x) = constant` for all samples `x`.
|
* 4. Check that the ratio `BN(x)/FG(x) = constant` for all samples `x`.
|
||||||
*/
|
********************************************************************************/
|
||||||
TEST(HybridEstimation, CorrectnessViaSampling) {
|
TEST(HybridEstimation, CorrectnessViaSampling) {
|
||||||
HybridNonlinearFactorGraph nfg;
|
|
||||||
|
|
||||||
// First we create a hybrid nonlinear factor graph
|
|
||||||
// which represents f(x0, x1, m0; z0, z1).
|
|
||||||
// We linearize and eliminate this to get
|
|
||||||
// the required Factor Graph FG and Bayes Net BN.
|
|
||||||
const auto noise_model = noiseModel::Isotropic::Sigma(1, 1.0);
|
|
||||||
const auto zero_motion =
|
|
||||||
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
|
|
||||||
const auto one_motion =
|
|
||||||
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
|
|
||||||
|
|
||||||
nfg.emplace_nonlinear<PriorFactor<double>>(X(0), 0.0, noise_model);
|
|
||||||
nfg.emplace_hybrid<MixtureFactor>(
|
|
||||||
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)},
|
|
||||||
std::vector<NonlinearFactor::shared_ptr>{zero_motion, one_motion});
|
|
||||||
|
|
||||||
Values initial;
|
|
||||||
double z0 = 0.0, z1 = 1.0;
|
|
||||||
initial.insert<double>(X(0), z0);
|
|
||||||
initial.insert<double>(X(1), z1);
|
|
||||||
|
|
||||||
// 1. Create the factor graph from the nonlinear factor graph.
|
// 1. Create the factor graph from the nonlinear factor graph.
|
||||||
HybridGaussianFactorGraph::shared_ptr fg = nfg.linearize(initial);
|
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
|
||||||
|
|
||||||
// 2. Eliminate into BN
|
// 2. Eliminate into BN
|
||||||
const Ordering ordering = fg->getHybridOrdering();
|
const Ordering ordering = fg->getHybridOrdering();
|
||||||
|
|
|
||||||
|
|
@ -587,7 +587,7 @@ factor 6: Discrete [m1 m0]
|
||||||
// Expected output for hybridBayesNet.
|
// Expected output for hybridBayesNet.
|
||||||
string expected_hybridBayesNet = R"(
|
string expected_hybridBayesNet = R"(
|
||||||
size: 3
|
size: 3
|
||||||
factor 0: Hybrid P( x0 | x1 m0)
|
conditional 0: Hybrid P( x0 | x1 m0)
|
||||||
Discrete Keys = (m0, 2),
|
Discrete Keys = (m0, 2),
|
||||||
Choice(m0)
|
Choice(m0)
|
||||||
0 Leaf p(x0 | x1)
|
0 Leaf p(x0 | x1)
|
||||||
|
|
@ -602,7 +602,7 @@ factor 0: Hybrid P( x0 | x1 m0)
|
||||||
d = [ -9.95037 ]
|
d = [ -9.95037 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
factor 1: Hybrid P( x1 | x2 m0 m1)
|
conditional 1: Hybrid P( x1 | x2 m0 m1)
|
||||||
Discrete Keys = (m0, 2), (m1, 2),
|
Discrete Keys = (m0, 2), (m1, 2),
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
|
|
@ -631,7 +631,7 @@ factor 1: Hybrid P( x1 | x2 m0 m1)
|
||||||
d = [ -10 ]
|
d = [ -10 ]
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
factor 2: Hybrid P( x2 | m0 m1)
|
conditional 2: Hybrid P( x2 | m0 m1)
|
||||||
Discrete Keys = (m0, 2), (m1, 2),
|
Discrete Keys = (m0, 2), (m1, 2),
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Choice(m0)
|
0 Choice(m0)
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,14 @@ namespace gtsam {
|
||||||
template <class CONDITIONAL>
|
template <class CONDITIONAL>
|
||||||
void BayesNet<CONDITIONAL>::print(const std::string& s,
|
void BayesNet<CONDITIONAL>::print(const std::string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
Base::print(s, formatter);
|
std::cout << (s.empty() ? "" : s + " ") << std::endl;
|
||||||
|
std::cout << "size: " << this->size() << std::endl;
|
||||||
|
for (size_t i = 0; i < this->size(); i++) {
|
||||||
|
const auto& conditional = this->at(i);
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "conditional " << i << ": ";
|
||||||
|
if (conditional) conditional->print(ss.str(), formatter);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue