From a9c1d43dc530092f6fb4cbf35047a70069eb21de Mon Sep 17 00:00:00 2001 From: Fan Jiang Date: Fri, 1 Apr 2022 00:20:22 -0400 Subject: [PATCH] Working iSAM for Hybrid --- gtsam/hybrid/HybridBayesTree.h | 29 ++ gtsam/hybrid/HybridConditional.h | 10 +- gtsam/hybrid/HybridFactorGraph.cpp | 45 ++- gtsam/hybrid/HybridISAM.cpp | 99 ++++++ gtsam/hybrid/HybridISAM.h | 59 ++++ gtsam/hybrid/HybridJunctionTree.cpp | 6 +- gtsam/hybrid/tests/Switching.h | 68 ++++ gtsam/hybrid/tests/testHybridConditional.cpp | 313 +++++++++++++++++-- gtsam/inference/BayesTree.h | 24 +- 9 files changed, 595 insertions(+), 58 deletions(-) create mode 100644 gtsam/hybrid/HybridISAM.cpp create mode 100644 gtsam/hybrid/HybridISAM.h create mode 100644 gtsam/hybrid/tests/Switching.h diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index ee51bdd6c..dd8b7f022 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -70,4 +70,33 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { /// @} }; +/* This does special stuff for the hybrid case */ +template +class BayesTreeOrphanWrapper< + CLIQUE, + boost::enable_if_t::value> > + : public CLIQUE::ConditionalType { + public: + typedef CLIQUE CliqueType; + typedef typename CLIQUE::ConditionalType Base; + + boost::shared_ptr clique; + + BayesTreeOrphanWrapper(const boost::shared_ptr& clique) + : clique(clique) { + // Store parent keys in our base type factor so that eliminating those + // parent keys will pull this subtree into the elimination. + this->keys_.assign(clique->conditional()->beginParents(), + clique->conditional()->endParents()); + this->discreteKeys_.assign(clique->conditional()->discreteKeys_.begin(), + clique->conditional()->discreteKeys_.end()); + } + + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + clique->print(s + "stored clique", formatter); + } +}; + } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index b91d2b91a..5d7ee2351 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include #include @@ -31,8 +32,6 @@ #include #include -#include - namespace gtsam { class HybridFactorGraph; @@ -101,10 +100,13 @@ class GTSAM_EXPORT HybridConditional return boost::static_pointer_cast(inner); } - boost::shared_ptr getInner() { - return inner; + DiscreteConditional::shared_ptr asDiscreteConditional() { + if (!isDiscrete_) throw std::invalid_argument("Not a discrete conditional"); + return boost::static_pointer_cast(inner); } + boost::shared_ptr getInner() { return inner; } + /// @} /// @name Testable /// @{ diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index 7a26dfadb..ee6a5dd82 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -217,11 +217,17 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { GaussianFactorGraph gfg; for (auto &fp : factors) { auto ptr = boost::dynamic_pointer_cast(fp); - if (ptr) + if (ptr) { gfg.push_back(ptr->inner); - else - gfg.push_back(boost::static_pointer_cast( - boost::static_pointer_cast(fp)->inner)); + } else { + auto p = boost::static_pointer_cast(fp)->inner; + if (p) { + gfg.push_back(boost::static_pointer_cast(p)); + } else { + // It is an orphan wrapper + if (DEBUG) std::cout << "Got an orphan wrapper conditional\n"; + } + } } auto result = EliminatePreferCholesky(gfg, frontalKeys); @@ -239,11 +245,17 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { DiscreteFactorGraph dfg; for (auto &fp : factors) { auto ptr = boost::dynamic_pointer_cast(fp); - if (ptr) + if (ptr) { dfg.push_back(ptr->inner); - else - dfg.push_back(boost::static_pointer_cast( - boost::static_pointer_cast(fp)->inner)); + } else { + auto p = boost::static_pointer_cast(fp)->inner; + if (p) { + dfg.push_back(boost::static_pointer_cast(p)); + } else { + // It is an orphan wrapper + if (DEBUG) std::cout << "Got an orphan wrapper conditional\n"; + } + } } auto result = EliminateDiscrete(dfg, frontalKeys); @@ -286,8 +298,18 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { deferredFactors.push_back( boost::dynamic_pointer_cast(f)->inner); } else { - throw std::invalid_argument( - "factor is discrete in continuous elimination"); + // We need to handle the case where the object is actually an + // BayesTreeOrphanWrapper! + auto orphan = boost::dynamic_pointer_cast< + BayesTreeOrphanWrapper>(f); + if (orphan) { + if (DEBUG) std::cout << "Got an orphan wrapper conditional\n"; + } else { + auto &fr = *f; + throw std::invalid_argument( + std::string("factor is discrete in continuous elimination") + + typeid(fr).name()); + } } } @@ -373,7 +395,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { } else { // Create a resulting DCGaussianMixture on the separator. auto factor = boost::make_shared( - frontalKeys, discreteSeparator, separatorFactors); + KeyVector(continuousSeparator.begin(), continuousSeparator.end()), + discreteSeparator, separatorFactors); return {boost::make_shared(conditional), factor}; } } diff --git a/gtsam/hybrid/HybridISAM.cpp b/gtsam/hybrid/HybridISAM.cpp new file mode 100644 index 000000000..0db30f1f3 --- /dev/null +++ b/gtsam/hybrid/HybridISAM.cpp @@ -0,0 +1,99 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridISAM.h + * @date March 31, 2022 + * @author Fan Jiang + * @author Frank Dellaert + * @author Richard Roberts + */ + +#include +#include +#include +#include +#include + +#include + +namespace gtsam { + +// Instantiate base class +// template class ISAM; + +/* ************************************************************************* */ +HybridISAM::HybridISAM() {} + +/* ************************************************************************* */ +HybridISAM::HybridISAM(const HybridBayesTree& bayesTree) : Base(bayesTree) {} + +void HybridISAM::updateInternal(const HybridFactorGraph& newFactors, + HybridBayesTree::Cliques* orphans, + const HybridBayesTree::Eliminate& function) { + // Remove the contaminated part of the Bayes tree + BayesNetType bn; + const KeySet newFactorKeys = newFactors.keys(); + if (!this->empty()) { + KeyVector keyVector(newFactorKeys.begin(), newFactorKeys.end()); + this->removeTop(keyVector, &bn, orphans); + } + + // Add the removed top and the new factors + FactorGraphType factors; + factors += bn; + factors += newFactors; + + // Add the orphaned subtrees + for (const sharedClique& orphan : *orphans) + factors += boost::make_shared >(orphan); + + KeySet allDiscrete; + for (auto& factor : factors) { + for (auto& k : factor->discreteKeys_) { + allDiscrete.insert(k.first); + } + } + KeyVector newKeysDiscreteLast; + for (auto& k : newFactorKeys) { + if (!allDiscrete.exists(k)) { + newKeysDiscreteLast.push_back(k); + } + } + std::copy(allDiscrete.begin(), allDiscrete.end(), + std::back_inserter(newKeysDiscreteLast)); + + // KeyVector new + + // Get an ordering where the new keys are eliminated last + const VariableIndex index(factors); + const Ordering ordering = Ordering::ColamdConstrainedLast( + index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()), + true); + + ordering.print("ORD"); + + // eliminate all factors (top, added, orphans) into a new Bayes tree + auto bayesTree = factors.eliminateMultifrontal(ordering, function, index); + + // Re-add into Bayes tree data structures + this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(), + bayesTree->roots().end()); + this->nodes_.insert(bayesTree->nodes().begin(), bayesTree->nodes().end()); +} + +void HybridISAM::update(const HybridFactorGraph& newFactors, + const HybridBayesTree::Eliminate& function) { + Cliques orphans; + this->updateInternal(newFactors, &orphans, function); +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridISAM.h b/gtsam/hybrid/HybridISAM.h new file mode 100644 index 000000000..ac9e17e83 --- /dev/null +++ b/gtsam/hybrid/HybridISAM.h @@ -0,0 +1,59 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridISAM.h + * @date March 31, 2022 + * @author Fan Jiang + * @author Frank Dellaert + * @author Richard Roberts + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + +class GTSAM_EXPORT HybridISAM : public ISAM { + public: + typedef ISAM Base; + typedef HybridISAM This; + typedef boost::shared_ptr shared_ptr; + + /// @name Standard Constructors + /// @{ + + /** Create an empty Bayes Tree */ + HybridISAM(); + + /** Copy constructor */ + HybridISAM(const HybridBayesTree& bayesTree); + + void updateInternal( + const HybridFactorGraph& newFactors, HybridBayesTree::Cliques* orphans, + const HybridBayesTree::Eliminate& function = + HybridBayesTree::EliminationTraitsType::DefaultEliminate); + + void update(const HybridFactorGraph& newFactors, + const HybridBayesTree::Eliminate& function = + HybridBayesTree::EliminationTraitsType::DefaultEliminate); + /// @} +}; + +/// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 22b40ad05..77e623a41 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -16,14 +16,14 @@ */ #include +#include #include #include +#include #include -#include "gtsam/hybrid/HybridFactorGraph.h" -#include "gtsam/inference/Key.h" - +// #undef NDEBUG namespace gtsam { // Instantiate base classes diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h new file mode 100644 index 000000000..0816db8df --- /dev/null +++ b/gtsam/hybrid/tests/Switching.h @@ -0,0 +1,68 @@ +#include +#include +#include +#include +#include +#include + +#pragma once + +using gtsam::symbol_shorthand::C; +using gtsam::symbol_shorthand::X; + +namespace gtsam { +inline HybridFactorGraph::shared_ptr makeSwitchingChain( + size_t n, std::function keyFunc = X, + std::function dKeyFunc = C) { + HybridFactorGraph hfg; + + hfg.add(JacobianFactor(keyFunc(1), I_3x3, Z_3x1)); + + // keyFunc(1) to keyFunc(n+1) + for (size_t t = 1; t < n; t++) { + hfg.add(GaussianMixtureFactor::FromFactorList( + {keyFunc(t), keyFunc(t + 1)}, {{dKeyFunc(t), 2}}, + {boost::make_shared(keyFunc(t), I_3x3, keyFunc(t + 1), + I_3x3, Z_3x1), + boost::make_shared(keyFunc(t), I_3x3, keyFunc(t + 1), + I_3x3, Vector3::Ones())})); + + if (t > 1) { + hfg.add(DecisionTreeFactor({{dKeyFunc(t - 1), 2}, {dKeyFunc(t), 2}}, + "0 1 1 3")); + } + } + + return boost::make_shared(std::move(hfg)); +} + +inline std::pair, std::vector> makeBinaryOrdering( + std::vector &input) { + std::vector new_order; + std::vector levels(input.size()); + std::function::iterator, std::vector::iterator, + int)> + bsg = [&bsg, &new_order, &levels, &input]( + std::vector::iterator begin, + std::vector::iterator end, int lvl) { + if (std::distance(begin, end) > 1) { + std::vector::iterator pivot = + begin + std::distance(begin, end) / 2; + + new_order.push_back(*pivot); + levels[std::distance(input.begin(), pivot)] = lvl; + bsg(begin, pivot, lvl + 1); + bsg(pivot + 1, end, lvl + 1); + } else if (std::distance(begin, end) == 1) { + new_order.push_back(*begin); + levels[std::distance(input.begin(), begin)] = lvl; + } + }; + + bsg(input.begin(), input.end(), 0); + std::reverse(new_order.begin(), new_order.end()); + // std::reverse(levels.begin(), levels.end()); + return {new_order, levels}; +} + +} // namespace gtsam \ No newline at end of file diff --git a/gtsam/hybrid/tests/testHybridConditional.cpp b/gtsam/hybrid/tests/testHybridConditional.cpp index 703ec136e..f3664e619 100644 --- a/gtsam/hybrid/tests/testHybridConditional.cpp +++ b/gtsam/hybrid/tests/testHybridConditional.cpp @@ -17,6 +17,9 @@ #include #include +#include +#include +#include #include #include #include @@ -26,16 +29,23 @@ #include #include #include +#include #include +#include +#include +#include #include #include +#include #include #include +#include +#include +#include +#include -#include "gtsam/inference/DotWriter.h" -#include "gtsam/inference/Key.h" -#include "gtsam/inference/Ordering.h" +#include "Switching.h" using namespace boost::assign; @@ -43,7 +53,9 @@ using namespace std; using namespace gtsam; using gtsam::symbol_shorthand::C; +using gtsam::symbol_shorthand::D; using gtsam::symbol_shorthand::X; +using gtsam::symbol_shorthand::Y; #define BOOST_STACKTRACE_GNU_SOURCE_NOT_REQUIRED @@ -99,6 +111,29 @@ TEST_DISABLED(HybridFactorGraph, eliminateMultifrontal) { EXPECT_LONGS_EQUAL(result.second->size(), 1); } +TEST(HybridFactorGraph, eliminateFullSequentialEqualChance) { + HybridFactorGraph hfg; + + DiscreteKey c1(C(1), 2); + + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + DecisionTree dt( + C(1), boost::make_shared(X(1), I_3x3, Z_3x1), + boost::make_shared(X(1), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); + + auto result = + hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {C(1)})); + + auto dc = result->at(2)->asDiscreteConditional(); + DiscreteValues dv; + dv[C(1)] = 0; + EXPECT_DOUBLES_EQUAL(0.6225, dc->operator()(dv), 1e-3); +} + TEST_DISABLED(HybridFactorGraph, eliminateFullSequentialSimple) { std::cout << ">>>>>>>>>>>>>>\n"; @@ -268,27 +303,10 @@ TEST_DISABLED(HybridFactorGraph, eliminateFullMultifrontalTwoClique) { */ } -HybridFactorGraph::shared_ptr makeSwitchingChain(size_t n) { - HybridFactorGraph hfg; - - hfg.add(JacobianFactor(X(1), I_3x3, Z_3x1)); - - // X(1) to X(n+1) - for (size_t t = 1; t < n; t++) { - hfg.add(GaussianMixtureFactor::FromFactorList( - {X(t), X(t + 1)}, {{C(t), 2}}, - {boost::make_shared(X(t), I_3x3, X(t + 1), I_3x3, - Z_3x1), - boost::make_shared(X(t), I_3x3, X(t + 1), I_3x3, - Vector3::Ones())})); - } - - return boost::make_shared(std::move(hfg)); -} - // TODO(fan): make a graph like Varun's paper one TEST(HybridFactorGraph, Switching) { - auto hfg = makeSwitchingChain(9); + auto N = 12; + auto hfg = makeSwitchingChain(N); // X(5) will be the center, X(1-4), X(6-9) // X(3), X(7) @@ -297,20 +315,73 @@ TEST(HybridFactorGraph, Switching) { // C(5) will be the center, C(1-4), C(6-8) // C(3), C(7) // C(1), C(4), C(2), C(6), C(8) - auto ordering_full = - Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7), X(5), - C(1), C(4), C(2), C(6), C(8), C(3), C(7), C(5)}); + // auto ordering_full = + // Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7), + // X(5), + // C(1), C(4), C(2), C(6), C(8), C(3), C(7), C(5)}); + KeyVector ordering; - GTSAM_PRINT(*hfg); + { + std::vector naturalX(N); + std::iota(naturalX.begin(), naturalX.end(), 1); + std::vector ordX; + std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX), + [](int x) { return X(x); }); + + KeyVector ndX; + std::vector lvls; + std::tie(ndX, lvls) = makeBinaryOrdering(ordX); + std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering)); + for (auto &l : lvls) { + l = -l; + } + std::copy(lvls.begin(), lvls.end(), + std::ostream_iterator(std::cout, ",")); + std::cout << "\n"; + } + { + std::vector naturalC(N - 1); + std::iota(naturalC.begin(), naturalC.end(), 1); + std::vector ordC; + std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC), + [](int x) { return C(x); }); + KeyVector ndC; + std::vector lvls; + + // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering)); + std::tie(ndC, lvls) = makeBinaryOrdering(ordC); + std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering)); + std::copy(lvls.begin(), lvls.end(), + std::ostream_iterator(std::cout, ",")); + } + auto ordering_full = Ordering(ordering); + + // auto ordering_full = + // Ordering(); + + // for (int i = 1; i <= 9; i++) { + // ordering_full.push_back(X(i)); + // } + + // for (int i = 1; i < 9; i++) { + // ordering_full.push_back(C(i)); + // } + + // auto ordering_full = + // Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7), + // X(5), + // C(1), C(2), C(3), C(4), C(5), C(6), C(7), C(8)}); + + // GTSAM_PRINT(*hfg); GTSAM_PRINT(ordering_full); HybridBayesTree::shared_ptr hbt; HybridFactorGraph::shared_ptr remaining; std::tie(hbt, remaining) = hfg->eliminatePartialMultifrontal(ordering_full); - GTSAM_PRINT(*hbt); + // GTSAM_PRINT(*hbt); - GTSAM_PRINT(*remaining); + // GTSAM_PRINT(*remaining); { DotWriter dw; @@ -323,7 +394,8 @@ TEST(HybridFactorGraph, Switching) { { DotWriter dw; - dw.positionHints['x'] = 1; + // dw.positionHints['c'] = 2; + // dw.positionHints['x'] = 1; std::cout << "\n"; std::cout << hfg->eliminateSequential(ordering_full) ->dot(DefaultKeyFormatter, dw); @@ -335,6 +407,189 @@ TEST(HybridFactorGraph, Switching) { is 1. expensive and 2. inexact. neverless it is doable. And I believe that we should do this. */ + hbt->marginalFactor(C(11))->print("HBT: "); +} + +// TODO(fan): make a graph like Varun's paper one +TEST(HybridFactorGraph, SwitchingISAM) { + auto N = 11; + auto hfg = makeSwitchingChain(N); + + // X(5) will be the center, X(1-4), X(6-9) + // X(3), X(7) + // X(2), X(8) + // X(1), X(4), X(6), X(9) + // C(5) will be the center, C(1-4), C(6-8) + // C(3), C(7) + // C(1), C(4), C(2), C(6), C(8) + // auto ordering_full = + // Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7), + // X(5), + // C(1), C(4), C(2), C(6), C(8), C(3), C(7), C(5)}); + KeyVector ordering; + + { + std::vector naturalX(N); + std::iota(naturalX.begin(), naturalX.end(), 1); + std::vector ordX; + std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX), + [](int x) { return X(x); }); + + KeyVector ndX; + std::vector lvls; + std::tie(ndX, lvls) = makeBinaryOrdering(ordX); + std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering)); + for (auto &l : lvls) { + l = -l; + } + std::copy(lvls.begin(), lvls.end(), + std::ostream_iterator(std::cout, ",")); + std::cout << "\n"; + } + { + std::vector naturalC(N - 1); + std::iota(naturalC.begin(), naturalC.end(), 1); + std::vector ordC; + std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC), + [](int x) { return C(x); }); + KeyVector ndC; + std::vector lvls; + + // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering)); + std::tie(ndC, lvls) = makeBinaryOrdering(ordC); + std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering)); + std::copy(lvls.begin(), lvls.end(), + std::ostream_iterator(std::cout, ",")); + } + auto ordering_full = Ordering(ordering); + + // GTSAM_PRINT(*hfg); + GTSAM_PRINT(ordering_full); + + HybridBayesTree::shared_ptr hbt; + HybridFactorGraph::shared_ptr remaining; + std::tie(hbt, remaining) = hfg->eliminatePartialMultifrontal(ordering_full); + + // GTSAM_PRINT(*hbt); + + // GTSAM_PRINT(*remaining); + + { + DotWriter dw; + dw.positionHints['c'] = 2; + dw.positionHints['x'] = 1; + std::cout << hfg->dot(DefaultKeyFormatter, dw); + std::cout << "\n"; + hbt->dot(std::cout); + } + + { + DotWriter dw; + // dw.positionHints['c'] = 2; + // dw.positionHints['x'] = 1; + std::cout << "\n"; + std::cout << hfg->eliminateSequential(ordering_full) + ->dot(DefaultKeyFormatter, dw); + } + + auto new_fg = makeSwitchingChain(12); + auto isam = HybridISAM(*hbt); + + { + HybridFactorGraph factorGraph; + factorGraph.push_back(new_fg->at(new_fg->size() - 2)); + factorGraph.push_back(new_fg->at(new_fg->size() - 1)); + isam.update(factorGraph); + std::cout << isam.dot(); + isam.marginalFactor(C(11))->print(); + } +} + +TEST_DISABLED(HybridFactorGraph, SwitchingTwoVar) { + const int N = 7; + auto hfg = makeSwitchingChain(N, X); + hfg->push_back(*makeSwitchingChain(N, Y, D)); + + for (int t = 1; t <= N; t++) { + hfg->add(JacobianFactor(X(t), I_3x3, Y(t), -I_3x3, Vector3(1.0, 0.0, 0.0))); + } + + KeyVector ordering; + + KeyVector naturalX(N); + std::iota(naturalX.begin(), naturalX.end(), 1); + KeyVector ordX; + for (size_t i = 1; i <= N; i++) { + ordX.emplace_back(X(i)); + ordX.emplace_back(Y(i)); + } + + // { + // KeyVector ndX; + // std::vector lvls; + // std::tie(ndX, lvls) = makeBinaryOrdering(naturalX); + // std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering)); + // std::copy(lvls.begin(), lvls.end(), + // std::ostream_iterator(std::cout, ",")); + // std::cout << "\n"; + + // for (size_t i = 0; i < N; i++) { + // ordX.emplace_back(X(ndX[i])); + // ordX.emplace_back(Y(ndX[i])); + // } + // } + + for (size_t i = 1; i <= N - 1; i++) { + ordX.emplace_back(C(i)); + } + for (size_t i = 1; i <= N - 1; i++) { + ordX.emplace_back(D(i)); + } + + { + DotWriter dw; + dw.positionHints['x'] = 1; + dw.positionHints['c'] = 0; + dw.positionHints['d'] = 3; + dw.positionHints['y'] = 2; + std::cout << hfg->dot(DefaultKeyFormatter, dw); + std::cout << "\n"; + } + + { + DotWriter dw; + dw.positionHints['y'] = 9; + // dw.positionHints['c'] = 0; + // dw.positionHints['d'] = 3; + dw.positionHints['x'] = 1; + std::cout << "\n"; + // std::cout << hfg->eliminateSequential(Ordering(ordX)) + // ->dot(DefaultKeyFormatter, dw); + hfg->eliminateMultifrontal(Ordering(ordX))->dot(std::cout); + } + + Ordering ordering_partial; + for (size_t i = 1; i <= N; i++) { + ordering_partial.emplace_back(X(i)); + ordering_partial.emplace_back(Y(i)); + } + { + HybridBayesNet::shared_ptr hbn; + HybridFactorGraph::shared_ptr remaining; + std::tie(hbn, remaining) = + hfg->eliminatePartialSequential(ordering_partial); + + // remaining->print(); + { + DotWriter dw; + dw.positionHints['x'] = 1; + dw.positionHints['c'] = 0; + dw.positionHints['d'] = 3; + dw.positionHints['y'] = 2; + std::cout << remaining->dot(DefaultKeyFormatter, dw); + std::cout << "\n"; + } + } } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index a32b3ce22..1abc0a682 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -33,6 +33,7 @@ namespace gtsam { // Forward declarations template class FactorGraph; template class EliminatableClusterTree; + class HybridBayesTreeClique; /* ************************************************************************* */ /** clique statistics */ @@ -272,24 +273,25 @@ namespace gtsam { }; // BayesTree /* ************************************************************************* */ - template - class BayesTreeOrphanWrapper : public CLIQUE::ConditionalType - { - public: + template + class BayesTreeOrphanWrapper : public CLIQUE::ConditionalType { + public: typedef CLIQUE CliqueType; typedef typename CLIQUE::ConditionalType Base; boost::shared_ptr clique; - BayesTreeOrphanWrapper(const boost::shared_ptr& clique) : - clique(clique) - { - // Store parent keys in our base type factor so that eliminating those parent keys will pull - // this subtree into the elimination. - this->keys_.assign(clique->conditional()->beginParents(), clique->conditional()->endParents()); + BayesTreeOrphanWrapper(const boost::shared_ptr& clique) + : clique(clique) { + // Store parent keys in our base type factor so that eliminating those + // parent keys will pull this subtree into the elimination. + this->keys_.assign(clique->conditional()->beginParents(), + clique->conditional()->endParents()); } - void print(const std::string& s="", const KeyFormatter& formatter = DefaultKeyFormatter) const override { + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { clique->print(s + "stored clique", formatter); } };