Address PR comments

release/4.3a0
Fan Jiang 2022-05-22 21:29:12 -07:00
parent 74af969f68
commit b215d3a377
10 changed files with 61 additions and 44 deletions

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file GaussianMixture.cpp * @file GaussianMixtureConditional.cpp
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme * @brief A hybrid conditional in the Conditional Linear Gaussian scheme
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
@ -20,38 +20,40 @@
#include <gtsam/base/utilities.h> #include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixtureConditional.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam { namespace gtsam {
GaussianMixture::GaussianMixture( GaussianMixtureConditional::GaussianMixtureConditional(
const KeyVector &continuousFrontals, const KeyVector &continuousParents, const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const GaussianMixture::Conditionals &conditionals) const GaussianMixtureConditional::Conditionals &conditionals)
: BaseFactor(CollectKeys(continuousFrontals, continuousParents), : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
discreteParents), discreteParents),
BaseConditional(continuousFrontals.size()), BaseConditional(continuousFrontals.size()),
conditionals_(conditionals) {} conditionals_(conditionals) {}
const GaussianMixture::Conditionals &GaussianMixture::conditionals() { /* *******************************************************************************/
const GaussianMixtureConditional::Conditionals &GaussianMixtureConditional::conditionals() {
return conditionals_; return conditionals_;
} }
GaussianMixture GaussianMixture::FromConditionalList( /* *******************************************************************************/
GaussianMixtureConditional GaussianMixtureConditional::FromConditionalList(
const KeyVector &continuousFrontals, const KeyVector &continuousParents, const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) { const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
Conditionals dt(discreteParents, conditionalsList); Conditionals dt(discreteParents, conditionalsList);
return GaussianMixture(continuousFrontals, continuousParents, discreteParents, return GaussianMixtureConditional(continuousFrontals, continuousParents, discreteParents,
dt); dt);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add( GaussianMixtureConditional::Sum GaussianMixtureConditional::add(
const GaussianMixture::Sum &sum) const { const GaussianMixtureConditional::Sum &sum) const {
using Y = GaussianFactorGraph; using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1; auto result = graph1;
@ -63,7 +65,7 @@ GaussianMixture::Sum GaussianMixture::add(
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGraph() const { GaussianMixtureConditional::Sum GaussianMixtureConditional::asGraph() const {
auto lambda = [](const GaussianFactor::shared_ptr &factor) { auto lambda = [](const GaussianFactor::shared_ptr &factor) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
@ -73,11 +75,12 @@ GaussianMixture::Sum GaussianMixture::asGraph() const {
} }
/* TODO(fan): this (for Testable) is not implemented! */ /* TODO(fan): this (for Testable) is not implemented! */
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { bool GaussianMixtureConditional::equals(const HybridFactor &lf, double tol) const {
return false; return false;
} }
void GaussianMixture::print(const std::string &s, /* *******************************************************************************/
void GaussianMixtureConditional::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << s << ": "; std::cout << s << ": ";
if (isContinuous_) std::cout << "Cont. "; if (isContinuous_) std::cout << "Cont. ";

View File

@ -10,7 +10,7 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file GaussianMixture.h * @file GaussianMixtureConditional.h
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme * @brief A hybrid conditional in the Conditional Linear Gaussian scheme
* @author Fan Jiang * @author Fan Jiang
* @author Varun Agrawal * @author Varun Agrawal
@ -25,13 +25,13 @@
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
namespace gtsam { namespace gtsam {
class GaussianMixture : public HybridFactor, class GaussianMixtureConditional : public HybridFactor,
public Conditional<HybridFactor, GaussianMixture> { public Conditional<HybridFactor, GaussianMixtureConditional> {
public: public:
using This = GaussianMixture; using This = GaussianMixtureConditional;
using shared_ptr = boost::shared_ptr<GaussianMixture>; using shared_ptr = boost::shared_ptr<GaussianMixtureConditional>;
using BaseFactor = HybridFactor; using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixture>; using BaseConditional = Conditional<HybridFactor, GaussianMixtureConditional>;
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
@ -46,7 +46,7 @@ class GaussianMixture : public HybridFactor,
* @param discreteParents the discrete parents. Will be placed last. * @param discreteParents the discrete parents. Will be placed last.
* @param conditionals a decision tree of GaussianConditionals. * @param conditionals a decision tree of GaussianConditionals.
*/ */
GaussianMixture(const KeyVector &continuousFrontals, GaussianMixtureConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const Conditionals &conditionals); const Conditionals &conditionals);
@ -55,21 +55,35 @@ class GaussianMixture : public HybridFactor,
const Conditionals &conditionals(); const Conditionals &conditionals();
/* *******************************************************************************/ /**
* @brief Combine Decision Trees
*/
Sum add(const Sum &sum) const; Sum add(const Sum &sum) const;
/* *******************************************************************************/ /**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/
Sum asGraph() const; Sum asGraph() const;
/**
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
* @param discreteParents Discrete parents variables
* @param conditionals List of conditionals
*/
static This FromConditionalList( static This FromConditionalList(
const KeyVector &continuousFrontals, const KeyVector &continuousParents, const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals); const std::vector<GaussianConditional::shared_ptr> &conditionals);
/* TODO: this is only a stub */
bool equals(const HybridFactor &lf, double tol = 1e-9) const override; bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
/* print utility */
void print( void print(
const std::string &s = "GaussianMixture\n", const std::string &s = "GaussianMixtureConditional\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -34,6 +34,7 @@ bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
return false; return false;
} }
/* *******************************************************************************/
GaussianMixtureFactor GaussianMixtureFactor::FromFactorList( GaussianMixtureFactor GaussianMixtureFactor::FromFactorList(
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factorsList) { const std::vector<GaussianFactor::shared_ptr> &factorsList) {
@ -42,6 +43,8 @@ GaussianMixtureFactor GaussianMixtureFactor::FromFactorList(
return GaussianMixtureFactor(continuousKeys, discreteKeys, dt); return GaussianMixtureFactor(continuousKeys, discreteKeys, dt);
} }
/* *******************************************************************************/
void GaussianMixtureFactor::print(const std::string &s, void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter); HybridFactor::print(s, formatter);
@ -57,6 +60,7 @@ void GaussianMixtureFactor::print(const std::string &s,
}); });
} }
/* *******************************************************************************/
const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() {
return factors_; return factors_;
} }

View File

@ -55,10 +55,8 @@ class GaussianMixtureFactor : public HybridFactor {
const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactor::shared_ptr> &factors); const std::vector<GaussianFactor::shared_ptr> &factors);
/* *******************************************************************************/
Sum add(const Sum &sum) const; Sum add(const Sum &sum) const;
/* *******************************************************************************/
Sum wrappedFactors() const; Sum wrappedFactors() const;
bool equals(const HybridFactor &lf, double tol = 1e-9) const override; bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

View File

@ -20,8 +20,6 @@
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <iostream> // TODO!
namespace gtsam { namespace gtsam {
/** /**

View File

@ -50,7 +50,7 @@ HybridConditional::HybridConditional(
} }
HybridConditional::HybridConditional( HybridConditional::HybridConditional(
boost::shared_ptr<GaussianMixture> gaussianMixture) boost::shared_ptr<GaussianMixtureConditional> gaussianMixture)
: BaseFactor(KeyVector(gaussianMixture->keys().begin(), : BaseFactor(KeyVector(gaussianMixture->keys().begin(),
gaussianMixture->keys().begin() + gaussianMixture->keys().begin() +
gaussianMixture->nrContinuous), gaussianMixture->nrContinuous),

View File

@ -18,7 +18,7 @@
#pragma once #pragma once
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixtureConditional.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
@ -42,7 +42,7 @@ class HybridFactorGraph;
* As a type-erased variant of: * As a type-erased variant of:
* - DiscreteConditional * - DiscreteConditional
* - GaussianConditional * - GaussianConditional
* - GaussianMixture * - GaussianMixtureConditional
* *
* The reason why this is important is that `Conditional<T>` is a CRTP class. * The reason why this is important is that `Conditional<T>` is a CRTP class.
* CRTP is static polymorphism such that all CRTP classes, while bearing the * CRTP is static polymorphism such that all CRTP classes, while bearing the
@ -93,11 +93,11 @@ class GTSAM_EXPORT HybridConditional
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional); HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture); HybridConditional(boost::shared_ptr<GaussianMixtureConditional> gaussianMixture);
GaussianMixture::shared_ptr asMixture() { GaussianMixtureConditional::shared_ptr asMixture() {
if (!isHybrid_) throw std::invalid_argument("Not a mixture"); if (!isHybrid_) throw std::invalid_argument("Not a mixture");
return boost::static_pointer_cast<GaussianMixture>(inner); return boost::static_pointer_cast<GaussianMixtureConditional>(inner);
} }
DiscreteConditional::shared_ptr asDiscreteConditional() { DiscreteConditional::shared_ptr asDiscreteConditional() {

View File

@ -22,7 +22,7 @@
#include <gtsam/discrete/Assignment.h> #include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteEliminationTree.h> #include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixtureConditional.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h> #include <gtsam/hybrid/HybridDiscreteFactor.h>
@ -104,7 +104,7 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
// Because of all these reasons, we need to think very carefully about how to // Because of all these reasons, we need to think very carefully about how to
// implement the hybrid factors so that we do not get poor performance. // implement the hybrid factors so that we do not get poor performance.
// //
// The first thing is how to represent the GaussianMixture. A very possible // The first thing is how to represent the GaussianMixtureConditional. A very possible
// scenario is that the incoming factors will have different levels of // scenario is that the incoming factors will have different levels of
// discrete keys. For example, imagine we are going to eliminate the fragment: // discrete keys. For example, imagine we are going to eliminate the fragment:
// $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid. Now we will // $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid. Now we will
@ -358,11 +358,11 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
auto pair = unzip(eliminationResults); auto pair = unzip(eliminationResults);
const GaussianMixture::Conditionals &conditionals = pair.first; const GaussianMixtureConditional::Conditionals &conditionals = pair.first;
const GaussianMixtureFactor::Factors &separatorFactors = pair.second; const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
// Create the GaussianMixture from the conditionals // Create the GaussianMixtureConditional from the conditionals
auto conditional = boost::make_shared<GaussianMixture>( auto conditional = boost::make_shared<GaussianMixtureConditional>(
frontalKeys, keysOfSeparator, discreteSeparator, conditionals); frontalKeys, keysOfSeparator, discreteSeparator, conditionals);
if (DEBUG) { if (DEBUG) {

View File

@ -38,16 +38,16 @@ class GaussianMixtureFactor : gtsam::HybridFactor {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
}; };
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixtureConditional.h>
class GaussianMixture : gtsam::HybridFactor { class GaussianMixtureConditional : gtsam::HybridFactor {
static GaussianMixture FromConditionalList( static GaussianMixtureConditional FromConditionalList(
const gtsam::KeyVector& continuousFrontals, const gtsam::KeyVector& continuousFrontals,
const gtsam::KeyVector& continuousParents, const gtsam::KeyVector& continuousParents,
const gtsam::DiscreteKeys& discreteParents, const gtsam::DiscreteKeys& discreteParents,
const std::vector<gtsam::GaussianConditional::shared_ptr>& const std::vector<gtsam::GaussianConditional::shared_ptr>&
conditionalsList); conditionalsList);
void print(string s = "GaussianMixture\n", void print(string s = "GaussianMixtureConditional\n",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
}; };

View File

@ -20,7 +20,7 @@
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixtureConditional.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
@ -79,8 +79,8 @@ TEST(HybridFactorGraph, creation) {
hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1))); hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1)));
GaussianMixture clgc({X(0)}, {X(1)}, DiscreteKeys(DiscreteKey{C(0), 2}), GaussianMixtureConditional clgc({X(0)}, {X(1)}, DiscreteKeys(DiscreteKey{C(0), 2}),
GaussianMixture::Conditionals( GaussianMixtureConditional::Conditionals(
C(0), C(0),
boost::make_shared<GaussianConditional>( boost::make_shared<GaussianConditional>(
X(0), Z_3x1, I_3x3, X(1), I_3x3), X(0), Z_3x1, I_3x3, X(1), I_3x3),