diff --git a/gtsam/hybrid/CGMixtureFactor.cpp b/gtsam/hybrid/CGMixtureFactor.cpp index 2ddf80ec2..16ead783e 100644 --- a/gtsam/hybrid/CGMixtureFactor.cpp +++ b/gtsam/hybrid/CGMixtureFactor.cpp @@ -4,6 +4,9 @@ #include +#include +#include + namespace gtsam { CGMixtureFactor::CGMixtureFactor(const KeyVector &continuousKeys, @@ -14,4 +17,18 @@ bool CGMixtureFactor::equals(const HybridFactor &lf, double tol) const { return false; } +void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { + HybridFactor::print(s, formatter); + factors_.print( + "mixture = ", + [&](Key k) { + return formatter(k); + }, [&](const GaussianFactor::shared_ptr &gf) -> std::string { + RedirectCout rd; + if (!gf->empty()) gf->print("", formatter); + else return {"nullptr"}; + return rd.str(); + }); +} + } \ No newline at end of file diff --git a/gtsam/hybrid/CGMixtureFactor.h b/gtsam/hybrid/CGMixtureFactor.h index 9c9e43ec2..7ff53b7ed 100644 --- a/gtsam/hybrid/CGMixtureFactor.h +++ b/gtsam/hybrid/CGMixtureFactor.h @@ -42,6 +42,9 @@ public: const Factors &factors); bool equals(const HybridFactor &lf, double tol = 1e-9) const override; + + void print(const std::string &s = "HybridFactor\n", + const KeyFormatter &formatter = DefaultKeyFormatter) const override; }; } diff --git a/gtsam/hybrid/CLGaussianConditional.cpp b/gtsam/hybrid/CLGaussianConditional.cpp index 09babc4e2..dbc9631c8 100644 --- a/gtsam/hybrid/CLGaussianConditional.cpp +++ b/gtsam/hybrid/CLGaussianConditional.cpp @@ -18,3 +18,34 @@ #include +#include + +namespace gtsam { + +CLGaussianConditional::CLGaussianConditional(const KeyVector &continuousFrontals, + const KeyVector &continuousParents, + const DiscreteKeys &discreteParents, + const CLGaussianConditional::Conditionals &factors) + : BaseFactor( + CollectKeys(continuousFrontals, continuousParents), discreteParents), + BaseConditional(continuousFrontals.size()) { + +} + +bool CLGaussianConditional::equals(const HybridFactor &lf, double tol) const { + return false; +} + +void CLGaussianConditional::print(const std::string &s, const KeyFormatter &formatter) const { + std::cout << s << ": "; + if (isContinuous_) std::cout << "Cont. "; + if (isDiscrete_) std::cout << "Disc. "; + if (isHybrid_) std::cout << "Hybr. "; + BaseConditional::print("", formatter); + std::cout << "Discrete Keys = "; + for (auto &dk : discreteKeys_) { + std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; + } + std::cout << "\n"; +} +} \ No newline at end of file diff --git a/gtsam/hybrid/CLGaussianConditional.h b/gtsam/hybrid/CLGaussianConditional.h index 03a2d99e7..14989df72 100644 --- a/gtsam/hybrid/CLGaussianConditional.h +++ b/gtsam/hybrid/CLGaussianConditional.h @@ -19,13 +19,30 @@ #include #include +#include +#include + namespace gtsam { class CLGaussianConditional : public HybridFactor, public Conditional { public: using This = CLGaussianConditional; using shared_ptr = boost::shared_ptr; using BaseFactor = HybridFactor; + using BaseConditional = Conditional; + using Conditionals = DecisionTree; +public: + + CLGaussianConditional(const KeyVector &continuousFrontals, + const KeyVector &continuousParents, + const DiscreteKeys &discreteParents, + const Conditionals &factors); + + bool equals(const HybridFactor &lf, double tol = 1e-9) const override; + + void print( + const std::string &s = "CLGaussianConditional\n", + const KeyFormatter &formatter = DefaultKeyFormatter) const override; }; } \ No newline at end of file diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp index 13766933b..1758e9025 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.cpp +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -21,6 +21,11 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf) bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { return false; +} + +void HybridDiscreteFactor::print(const std::string &s, const KeyFormatter &formatter) const { + HybridFactor::print(s, formatter); + inner->print("inner: ", formatter); }; } \ No newline at end of file diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h index 0395c9512..9d574b736 100644 --- a/gtsam/hybrid/HybridDiscreteFactor.h +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -37,5 +37,7 @@ class HybridDiscreteFactor : public HybridFactor { public: virtual bool equals(const HybridFactor& lf, double tol) const override; + + void print(const std::string &s = "HybridFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override; }; } diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp index 907350e83..3095136a4 100644 --- a/gtsam/hybrid/HybridFactor.cpp +++ b/gtsam/hybrid/HybridFactor.cpp @@ -16,3 +16,38 @@ */ #include + +namespace gtsam { + +KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) { + KeyVector allKeys; + std::copy(continuousKeys.begin(), continuousKeys.end(), std::back_inserter(allKeys)); + std::transform(discreteKeys.begin(), + discreteKeys.end(), + std::back_inserter(allKeys), + [](const DiscreteKey &k) { return k.first; }); + return allKeys; +} + +KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) { + KeyVector allKeys; + std::copy(keys1.begin(), keys1.end(), std::back_inserter(allKeys)); + std::copy(keys2.begin(), keys2.end(), std::back_inserter(allKeys)); + return allKeys; +} + +HybridFactor::HybridFactor() = default; + +HybridFactor::HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {} + +HybridFactor::HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) + : Base( + CollectKeys(continuousKeys, discreteKeys)), isHybrid_(true), discreteKeys_(discreteKeys) {} + +HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), + isDiscrete_(true), + discreteKeys_(discreteKeys) {} + +HybridFactor::~HybridFactor() = default; + +} \ No newline at end of file diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 64b49f605..619d16078 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -25,6 +25,9 @@ #include namespace gtsam { +KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys); +KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); + /** * Base class for hybrid probabilistic factors */ @@ -49,33 +52,21 @@ public: /// @{ /** Default constructor creates empty factor */ - HybridFactor() {} + HybridFactor(); /** Construct from container of keys. This constructor is used internally from derived factor * constructors, either from a container of keys or from a boost::assign::list_of. */ // template // HybridFactor(const CONTAINER &keys) : Base(keys) {} - HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {} + explicit HybridFactor(const KeyVector &keys); - static KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) { - KeyVector allKeys; - std::copy(continuousKeys.begin(), continuousKeys.end(), std::back_inserter(allKeys)); - std::transform(discreteKeys.begin(), - discreteKeys.end(), - std::back_inserter(allKeys), - [](const DiscreteKey &k) { return k.first; }); - return allKeys; - } + HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys); - HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) : Base( - CollectKeys(continuousKeys, discreteKeys)), isHybrid_(true) {} - - HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), isDiscrete_(true) {} + explicit HybridFactor(const DiscreteKeys &discreteKeys); /// Virtual destructor - virtual ~HybridFactor() { - } + virtual ~HybridFactor(); /// @} /// @name Testable diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index cd5bc651d..2dc54d75d 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -34,6 +34,7 @@ EliminateHybrid(const HybridFactorGraph &factors, // PREPROCESS: Identify the nature of the current elimination KeySet allKeys; + // TODO: we do a mock by just doing the correct key thing std::cout << "Begin Eliminate: "; frontalKeys.print(); @@ -43,6 +44,7 @@ EliminateHybrid(const HybridFactorGraph &factors, factor->print(); allKeys.insert(factor->begin(), factor->end()); } + for (auto &k : frontalKeys) { allKeys.erase(k); } @@ -51,13 +53,14 @@ EliminateHybrid(const HybridFactorGraph &factors, gttic(product); HybridConditional sum(allKeys.size(), Ordering(allKeys)); -// HybridDiscreteFactor product(DiscreteConditional()); -// for (auto&& factor : factors) product = (*factor) * product; + + // HybridDiscreteFactor product(DiscreteConditional()); + // for (auto&& factor : factors) product = (*factor) * product; gttoc(product); // sum out frontals, this is the factor on the separator gttic(sum); -// HybridFactor::shared_ptr sum = product.sum(frontalKeys); + // HybridFactor::shared_ptr sum = product.sum(frontalKeys); gttoc(sum); // Ordering keys for the conditional so that frontalKeys are really in front @@ -69,11 +72,11 @@ EliminateHybrid(const HybridFactorGraph &factors, // now divide product/sum to get conditional gttic(divide); -// auto conditional = -// boost::make_shared(product, *sum, orderedKeys); + // auto conditional = + // boost::make_shared(product, *sum, orderedKeys); gttoc(divide); -// return std::make_pair(conditional, sum); + // return std::make_pair(conditional, sum); return std::make_pair(boost::make_shared(frontalKeys.size(), orderedKeys), boost::make_shared(std::move(sum))); diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp index 1e87cbbc3..faa4ba998 100644 --- a/gtsam/hybrid/HybridGaussianFactor.cpp +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -18,6 +18,10 @@ HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf) : Base(jf.keys() bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const { return false; +} +void HybridGaussianFactor::print(const std::string &s, const KeyFormatter &formatter) const { + HybridFactor::print(s, formatter); + inner->print("inner: ", formatter); }; } \ No newline at end of file diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h index 34a7c0004..8562075b4 100644 --- a/gtsam/hybrid/HybridGaussianFactor.h +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -37,5 +37,7 @@ class HybridGaussianFactor : public HybridFactor { public: virtual bool equals(const HybridFactor& lf, double tol) const override; + + void print(const std::string &s = "HybridFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override; }; } diff --git a/gtsam/hybrid/tests/testHybridConditional.cpp b/gtsam/hybrid/tests/testHybridConditional.cpp index 46ec40475..4611026b3 100644 --- a/gtsam/hybrid/tests/testHybridConditional.cpp +++ b/gtsam/hybrid/tests/testHybridConditional.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -47,6 +48,13 @@ TEST_UNSAFE(HybridFactorGraph, creation) { HybridFactorGraph hfg; hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1))); + + CLGaussianConditional clgc( + {X(0)}, {X(1)}, + DiscreteKeys(DiscreteKey{C(0), 2}), + CLGaussianConditional::Conditionals() + ); + GTSAM_PRINT(clgc); } TEST_UNSAFE(HybridFactorGraph, eliminate) { @@ -84,12 +92,19 @@ TEST(HybridFactorGraph, eliminateFullMultifrontal) { hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); - DecisionTree dt; + DecisionTree dt(C(1), + boost::make_shared(X(1), + I_3x3, + Z_3x1), + boost::make_shared(X(1), + I_3x3, + Vector3::Ones())); - hfg.add(CGMixtureFactor({X(1)}, { x }, dt)); + hfg.add(CGMixtureFactor({X(1)}, {x}, dt)); hfg.add(HybridDiscreteFactor(DecisionTreeFactor(x, {2, 8}))); + hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); - auto result = hfg.eliminateMultifrontal(Ordering::ColamdConstrainedLast(hfg, {C(1)})); + auto result = hfg.eliminateMultifrontal(Ordering::ColamdConstrainedLast(hfg, {C(1), C(2)})); GTSAM_PRINT(*result); GTSAM_PRINT(*result->marginalFactor(C(1)));