diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixtureConditional.cpp similarity index 72% rename from gtsam/hybrid/GaussianMixture.cpp rename to gtsam/hybrid/GaussianMixtureConditional.cpp index 66971b69f..5fc3b4f83 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixtureConditional.cpp @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file GaussianMixture.cpp + * @file GaussianMixtureConditional.cpp * @brief A hybrid conditional in the Conditional Linear Gaussian scheme * @author Fan Jiang * @author Varun Agrawal @@ -20,38 +20,40 @@ #include #include -#include +#include #include #include namespace gtsam { -GaussianMixture::GaussianMixture( +GaussianMixtureConditional::GaussianMixtureConditional( const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents, - const GaussianMixture::Conditionals &conditionals) + const GaussianMixtureConditional::Conditionals &conditionals) : BaseFactor(CollectKeys(continuousFrontals, continuousParents), discreteParents), BaseConditional(continuousFrontals.size()), conditionals_(conditionals) {} -const GaussianMixture::Conditionals &GaussianMixture::conditionals() { +/* *******************************************************************************/ +const GaussianMixtureConditional::Conditionals &GaussianMixtureConditional::conditionals() { return conditionals_; } -GaussianMixture GaussianMixture::FromConditionalList( +/* *******************************************************************************/ +GaussianMixtureConditional GaussianMixtureConditional::FromConditionalList( const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents, const std::vector &conditionalsList) { Conditionals dt(discreteParents, conditionalsList); - return GaussianMixture(continuousFrontals, continuousParents, discreteParents, + return GaussianMixtureConditional(continuousFrontals, continuousParents, discreteParents, dt); } /* *******************************************************************************/ -GaussianMixture::Sum GaussianMixture::add( - const GaussianMixture::Sum &sum) const { +GaussianMixtureConditional::Sum GaussianMixtureConditional::add( + const GaussianMixtureConditional::Sum &sum) const { using Y = GaussianFactorGraph; auto add = [](const Y &graph1, const Y &graph2) { auto result = graph1; @@ -63,7 +65,7 @@ GaussianMixture::Sum GaussianMixture::add( } /* *******************************************************************************/ -GaussianMixture::Sum GaussianMixture::asGraph() const { +GaussianMixtureConditional::Sum GaussianMixtureConditional::asGraph() const { auto lambda = [](const GaussianFactor::shared_ptr &factor) { GaussianFactorGraph result; result.push_back(factor); @@ -73,11 +75,12 @@ GaussianMixture::Sum GaussianMixture::asGraph() const { } /* TODO(fan): this (for Testable) is not implemented! */ -bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { +bool GaussianMixtureConditional::equals(const HybridFactor &lf, double tol) const { return false; } -void GaussianMixture::print(const std::string &s, +/* *******************************************************************************/ +void GaussianMixtureConditional::print(const std::string &s, const KeyFormatter &formatter) const { std::cout << s << ": "; if (isContinuous_) std::cout << "Cont. "; diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixtureConditional.h similarity index 70% rename from gtsam/hybrid/GaussianMixture.h rename to gtsam/hybrid/GaussianMixtureConditional.h index 4379ea1ca..e0cf7c050 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixtureConditional.h @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file GaussianMixture.h + * @file GaussianMixtureConditional.h * @brief A hybrid conditional in the Conditional Linear Gaussian scheme * @author Fan Jiang * @author Varun Agrawal @@ -25,13 +25,13 @@ #include namespace gtsam { -class GaussianMixture : public HybridFactor, - public Conditional { +class GaussianMixtureConditional : public HybridFactor, + public Conditional { public: - using This = GaussianMixture; - using shared_ptr = boost::shared_ptr; + using This = GaussianMixtureConditional; + using shared_ptr = boost::shared_ptr; using BaseFactor = HybridFactor; - using BaseConditional = Conditional; + using BaseConditional = Conditional; using Conditionals = DecisionTree; @@ -46,7 +46,7 @@ class GaussianMixture : public HybridFactor, * @param discreteParents the discrete parents. Will be placed last. * @param conditionals a decision tree of GaussianConditionals. */ - GaussianMixture(const KeyVector &continuousFrontals, + GaussianMixtureConditional(const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents, const Conditionals &conditionals); @@ -55,21 +55,35 @@ class GaussianMixture : public HybridFactor, const Conditionals &conditionals(); - /* *******************************************************************************/ + /** + * @brief Combine Decision Trees + */ Sum add(const Sum &sum) const; - /* *******************************************************************************/ + /** + * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. + */ Sum asGraph() const; + /** + * @brief Make a Gaussian Mixture from a list of Gaussian conditionals + * + * @param continuousFrontals The continuous frontal variables + * @param continuousParents The continuous parent variables + * @param discreteParents Discrete parents variables + * @param conditionals List of conditionals + */ static This FromConditionalList( const KeyVector &continuousFrontals, const KeyVector &continuousParents, const DiscreteKeys &discreteParents, const std::vector &conditionals); + /* TODO: this is only a stub */ bool equals(const HybridFactor &lf, double tol = 1e-9) const override; + /* print utility */ void print( - const std::string &s = "GaussianMixture\n", + const std::string &s = "GaussianMixtureConditional\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override; }; } // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp index c85383322..65c5c7001 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.cpp +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -34,6 +34,7 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { return false; } +/* *******************************************************************************/ GaussianMixtureFactor GaussianMixtureFactor::FromFactorList( const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const std::vector &factorsList) { @@ -42,6 +43,8 @@ GaussianMixtureFactor GaussianMixtureFactor::FromFactorList( return GaussianMixtureFactor(continuousKeys, discreteKeys, dt); } + +/* *******************************************************************************/ void GaussianMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const { HybridFactor::print(s, formatter); @@ -57,6 +60,7 @@ void GaussianMixtureFactor::print(const std::string &s, }); } +/* *******************************************************************************/ const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { return factors_; } diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h index 1a3c582ae..f0f55911a 100644 --- a/gtsam/hybrid/GaussianMixtureFactor.h +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -55,10 +55,8 @@ class GaussianMixtureFactor : public HybridFactor { const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const std::vector &factors); - /* *******************************************************************************/ Sum add(const Sum &sum) const; - /* *******************************************************************************/ Sum wrappedFactors() const; bool equals(const HybridFactor &lf, double tol = 1e-9) const override; diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index d7e2f33af..4e411b781 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -20,8 +20,6 @@ #include #include -#include // TODO! - namespace gtsam { /** diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp index 48bee192c..73e7747c6 100644 --- a/gtsam/hybrid/HybridConditional.cpp +++ b/gtsam/hybrid/HybridConditional.cpp @@ -50,7 +50,7 @@ HybridConditional::HybridConditional( } HybridConditional::HybridConditional( - boost::shared_ptr gaussianMixture) + boost::shared_ptr gaussianMixture) : BaseFactor(KeyVector(gaussianMixture->keys().begin(), gaussianMixture->keys().begin() + gaussianMixture->nrContinuous), diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 5d7ee2351..76d5b4833 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -18,7 +18,7 @@ #pragma once #include -#include +#include #include #include #include @@ -42,7 +42,7 @@ class HybridFactorGraph; * As a type-erased variant of: * - DiscreteConditional * - GaussianConditional - * - GaussianMixture + * - GaussianMixtureConditional * * The reason why this is important is that `Conditional` is a CRTP class. * CRTP is static polymorphism such that all CRTP classes, while bearing the @@ -93,11 +93,11 @@ class GTSAM_EXPORT HybridConditional HybridConditional(boost::shared_ptr discreteConditional); - HybridConditional(boost::shared_ptr gaussianMixture); + HybridConditional(boost::shared_ptr gaussianMixture); - GaussianMixture::shared_ptr asMixture() { + GaussianMixtureConditional::shared_ptr asMixture() { if (!isHybrid_) throw std::invalid_argument("Not a mixture"); - return boost::static_pointer_cast(inner); + return boost::static_pointer_cast(inner); } DiscreteConditional::shared_ptr asDiscreteConditional() { diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index f44ad898b..699e6d2c6 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include #include #include @@ -104,7 +104,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { // Because of all these reasons, we need to think very carefully about how to // implement the hybrid factors so that we do not get poor performance. // - // The first thing is how to represent the GaussianMixture. A very possible + // The first thing is how to represent the GaussianMixtureConditional. A very possible // scenario is that the incoming factors will have different levels of // discrete keys. For example, imagine we are going to eliminate the fragment: // $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid. Now we will @@ -358,11 +358,11 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) { auto pair = unzip(eliminationResults); - const GaussianMixture::Conditionals &conditionals = pair.first; + const GaussianMixtureConditional::Conditionals &conditionals = pair.first; const GaussianMixtureFactor::Factors &separatorFactors = pair.second; - // Create the GaussianMixture from the conditionals - auto conditional = boost::make_shared( + // Create the GaussianMixtureConditional from the conditionals + auto conditional = boost::make_shared( frontalKeys, keysOfSeparator, discreteSeparator, conditionals); if (DEBUG) { diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i index 052575011..5a76aaf48 100644 --- a/gtsam/hybrid/hybrid.i +++ b/gtsam/hybrid/hybrid.i @@ -38,16 +38,16 @@ class GaussianMixtureFactor : gtsam::HybridFactor { gtsam::DefaultKeyFormatter) const; }; -#include -class GaussianMixture : gtsam::HybridFactor { - static GaussianMixture FromConditionalList( +#include +class GaussianMixtureConditional : gtsam::HybridFactor { + static GaussianMixtureConditional FromConditionalList( const gtsam::KeyVector& continuousFrontals, const gtsam::KeyVector& continuousParents, const gtsam::DiscreteKeys& discreteParents, const std::vector& conditionalsList); - void print(string s = "GaussianMixture\n", + void print(string s = "GaussianMixtureConditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; }; diff --git a/gtsam/hybrid/tests/testHybridFactorGraph.cpp b/gtsam/hybrid/tests/testHybridFactorGraph.cpp index 79c16d21a..4986cc2a7 100644 --- a/gtsam/hybrid/tests/testHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridFactorGraph.cpp @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include #include @@ -79,8 +79,8 @@ TEST(HybridFactorGraph, creation) { hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1))); - GaussianMixture clgc({X(0)}, {X(1)}, DiscreteKeys(DiscreteKey{C(0), 2}), - GaussianMixture::Conditionals( + GaussianMixtureConditional clgc({X(0)}, {X(1)}, DiscreteKeys(DiscreteKey{C(0), 2}), + GaussianMixtureConditional::Conditionals( C(0), boost::make_shared( X(0), Z_3x1, I_3x3, X(1), I_3x3),