Added mixture factor functionality

release/4.3a0
Fan Jiang 2022-03-13 11:42:36 -04:00
parent 2bae2865d7
commit ee4f9d19f0
12 changed files with 151 additions and 26 deletions

View File

@ -4,6 +4,9 @@
#include <gtsam/hybrid/CGMixtureFactor.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/base/utilities.h>
namespace gtsam {
CGMixtureFactor::CGMixtureFactor(const KeyVector &continuousKeys,
@ -14,4 +17,18 @@ bool CGMixtureFactor::equals(const HybridFactor &lf, double tol) const {
return false;
}
void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
factors_.print(
"mixture = ",
[&](Key k) {
return formatter(k);
}, [&](const GaussianFactor::shared_ptr &gf) -> std::string {
RedirectCout rd;
if (!gf->empty()) gf->print("", formatter);
else return {"nullptr"};
return rd.str();
});
}
}

View File

@ -42,6 +42,9 @@ public:
const Factors &factors);
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
void print(const std::string &s = "HybridFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
};
}

View File

@ -18,3 +18,34 @@
#include <gtsam/hybrid/CLGaussianConditional.h>
#include <gtsam/inference/Conditional-inst.h>
namespace gtsam {
CLGaussianConditional::CLGaussianConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const CLGaussianConditional::Conditionals &factors)
: BaseFactor(
CollectKeys(continuousFrontals, continuousParents), discreteParents),
BaseConditional(continuousFrontals.size()) {
}
bool CLGaussianConditional::equals(const HybridFactor &lf, double tol) const {
return false;
}
void CLGaussianConditional::print(const std::string &s, const KeyFormatter &formatter) const {
std::cout << s << ": ";
if (isContinuous_) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. ";
if (isHybrid_) std::cout << "Hybr. ";
BaseConditional::print("", formatter);
std::cout << "Discrete Keys = ";
for (auto &dk : discreteKeys_) {
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}
std::cout << "\n";
}
}

View File

@ -19,13 +19,30 @@
#include <gtsam/inference/Conditional.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/discrete/DecisionTree.h>
namespace gtsam {
class CLGaussianConditional : public HybridFactor, public Conditional<HybridFactor, CLGaussianConditional> {
public:
using This = CLGaussianConditional;
using shared_ptr = boost::shared_ptr<CLGaussianConditional>;
using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, CLGaussianConditional>;
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
public:
CLGaussianConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const Conditionals &factors);
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
void print(
const std::string &s = "CLGaussianConditional\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
};
}

View File

@ -21,6 +21,11 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
return false;
}
void HybridDiscreteFactor::print(const std::string &s, const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
inner->print("inner: ", formatter);
};
}

View File

@ -37,5 +37,7 @@ class HybridDiscreteFactor : public HybridFactor {
public:
virtual bool equals(const HybridFactor& lf, double tol) const override;
void print(const std::string &s = "HybridFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override;
};
}

View File

@ -16,3 +16,38 @@
*/
#include <gtsam/hybrid/HybridFactor.h>
namespace gtsam {
KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) {
KeyVector allKeys;
std::copy(continuousKeys.begin(), continuousKeys.end(), std::back_inserter(allKeys));
std::transform(discreteKeys.begin(),
discreteKeys.end(),
std::back_inserter(allKeys),
[](const DiscreteKey &k) { return k.first; });
return allKeys;
}
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) {
KeyVector allKeys;
std::copy(keys1.begin(), keys1.end(), std::back_inserter(allKeys));
std::copy(keys2.begin(), keys2.end(), std::back_inserter(allKeys));
return allKeys;
}
HybridFactor::HybridFactor() = default;
HybridFactor::HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {}
HybridFactor::HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys)
: Base(
CollectKeys(continuousKeys, discreteKeys)), isHybrid_(true), discreteKeys_(discreteKeys) {}
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)),
isDiscrete_(true),
discreteKeys_(discreteKeys) {}
HybridFactor::~HybridFactor() = default;
}

View File

@ -25,6 +25,9 @@
#include <string>
namespace gtsam {
KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
/**
* Base class for hybrid probabilistic factors
*/
@ -49,33 +52,21 @@ public:
/// @{
/** Default constructor creates empty factor */
HybridFactor() {}
HybridFactor();
/** Construct from container of keys. This constructor is used internally from derived factor
* constructors, either from a container of keys or from a boost::assign::list_of. */
// template<typename CONTAINER>
// HybridFactor(const CONTAINER &keys) : Base(keys) {}
HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {}
explicit HybridFactor(const KeyVector &keys);
static KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) {
KeyVector allKeys;
std::copy(continuousKeys.begin(), continuousKeys.end(), std::back_inserter(allKeys));
std::transform(discreteKeys.begin(),
discreteKeys.end(),
std::back_inserter(allKeys),
[](const DiscreteKey &k) { return k.first; });
return allKeys;
}
HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys);
HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) : Base(
CollectKeys(continuousKeys, discreteKeys)), isHybrid_(true) {}
HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), isDiscrete_(true) {}
explicit HybridFactor(const DiscreteKeys &discreteKeys);
/// Virtual destructor
virtual ~HybridFactor() {
}
virtual ~HybridFactor();
/// @}
/// @name Testable

View File

@ -34,6 +34,7 @@ EliminateHybrid(const HybridFactorGraph &factors,
// PREPROCESS: Identify the nature of the current elimination
KeySet allKeys;
// TODO: we do a mock by just doing the correct key thing
std::cout << "Begin Eliminate: ";
frontalKeys.print();
@ -43,6 +44,7 @@ EliminateHybrid(const HybridFactorGraph &factors,
factor->print();
allKeys.insert(factor->begin(), factor->end());
}
for (auto &k : frontalKeys) {
allKeys.erase(k);
}
@ -51,13 +53,14 @@ EliminateHybrid(const HybridFactorGraph &factors,
gttic(product);
HybridConditional sum(allKeys.size(), Ordering(allKeys));
// HybridDiscreteFactor product(DiscreteConditional());
// for (auto&& factor : factors) product = (*factor) * product;
// HybridDiscreteFactor product(DiscreteConditional());
// for (auto&& factor : factors) product = (*factor) * product;
gttoc(product);
// sum out frontals, this is the factor on the separator
gttic(sum);
// HybridFactor::shared_ptr sum = product.sum(frontalKeys);
// HybridFactor::shared_ptr sum = product.sum(frontalKeys);
gttoc(sum);
// Ordering keys for the conditional so that frontalKeys are really in front
@ -69,11 +72,11 @@ EliminateHybrid(const HybridFactorGraph &factors,
// now divide product/sum to get conditional
gttic(divide);
// auto conditional =
// boost::make_shared<HybridConditional>(product, *sum, orderedKeys);
// auto conditional =
// boost::make_shared<HybridConditional>(product, *sum, orderedKeys);
gttoc(divide);
// return std::make_pair(conditional, sum);
// return std::make_pair(conditional, sum);
return std::make_pair(boost::make_shared<HybridConditional>(frontalKeys.size(),
orderedKeys),
boost::make_shared<HybridConditional>(std::move(sum)));

View File

@ -18,6 +18,10 @@ HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf) : Base(jf.keys()
bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
return false;
}
void HybridGaussianFactor::print(const std::string &s, const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter);
inner->print("inner: ", formatter);
};
}

View File

@ -37,5 +37,7 @@ class HybridGaussianFactor : public HybridFactor {
public:
virtual bool equals(const HybridFactor& lf, double tol) const override;
void print(const std::string &s = "HybridFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override;
};
}

View File

@ -20,6 +20,7 @@
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/CLGaussianConditional.h>
#include <gtsam/hybrid/CGMixtureFactor.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
@ -47,6 +48,13 @@ TEST_UNSAFE(HybridFactorGraph, creation) {
HybridFactorGraph hfg;
hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1)));
CLGaussianConditional clgc(
{X(0)}, {X(1)},
DiscreteKeys(DiscreteKey{C(0), 2}),
CLGaussianConditional::Conditionals()
);
GTSAM_PRINT(clgc);
}
TEST_UNSAFE(HybridFactorGraph, eliminate) {
@ -84,12 +92,19 @@ TEST(HybridFactorGraph, eliminateFullMultifrontal) {
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
DecisionTree<Key, GaussianFactor::shared_ptr> dt;
DecisionTree<Key, GaussianFactor::shared_ptr> dt(C(1),
boost::make_shared<JacobianFactor>(X(1),
I_3x3,
Z_3x1),
boost::make_shared<JacobianFactor>(X(1),
I_3x3,
Vector3::Ones()));
hfg.add(CGMixtureFactor({X(1)}, { x }, dt));
hfg.add(CGMixtureFactor({X(1)}, {x}, dt));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(x, {2, 8})));
hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4")));
auto result = hfg.eliminateMultifrontal(Ordering::ColamdConstrainedLast(hfg, {C(1)}));
auto result = hfg.eliminateMultifrontal(Ordering::ColamdConstrainedLast(hfg, {C(1), C(2)}));
GTSAM_PRINT(*result);
GTSAM_PRINT(*result->marginalFactor(C(1)));