Added mixture factor functionality
parent
2bae2865d7
commit
ee4f9d19f0
|
@ -4,6 +4,9 @@
|
|||
|
||||
#include <gtsam/hybrid/CGMixtureFactor.h>
|
||||
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/base/utilities.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
CGMixtureFactor::CGMixtureFactor(const KeyVector &continuousKeys,
|
||||
|
@ -14,4 +17,18 @@ bool CGMixtureFactor::equals(const HybridFactor &lf, double tol) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
void CGMixtureFactor::print(const std::string &s, const KeyFormatter &formatter) const {
|
||||
HybridFactor::print(s, formatter);
|
||||
factors_.print(
|
||||
"mixture = ",
|
||||
[&](Key k) {
|
||||
return formatter(k);
|
||||
}, [&](const GaussianFactor::shared_ptr &gf) -> std::string {
|
||||
RedirectCout rd;
|
||||
if (!gf->empty()) gf->print("", formatter);
|
||||
else return {"nullptr"};
|
||||
return rd.str();
|
||||
});
|
||||
}
|
||||
|
||||
}
|
|
@ -42,6 +42,9 @@ public:
|
|||
const Factors &factors);
|
||||
|
||||
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
|
||||
|
||||
void print(const std::string &s = "HybridFactor\n",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -18,3 +18,34 @@
|
|||
|
||||
#include <gtsam/hybrid/CLGaussianConditional.h>
|
||||
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
CLGaussianConditional::CLGaussianConditional(const KeyVector &continuousFrontals,
|
||||
const KeyVector &continuousParents,
|
||||
const DiscreteKeys &discreteParents,
|
||||
const CLGaussianConditional::Conditionals &factors)
|
||||
: BaseFactor(
|
||||
CollectKeys(continuousFrontals, continuousParents), discreteParents),
|
||||
BaseConditional(continuousFrontals.size()) {
|
||||
|
||||
}
|
||||
|
||||
bool CLGaussianConditional::equals(const HybridFactor &lf, double tol) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
void CLGaussianConditional::print(const std::string &s, const KeyFormatter &formatter) const {
|
||||
std::cout << s << ": ";
|
||||
if (isContinuous_) std::cout << "Cont. ";
|
||||
if (isDiscrete_) std::cout << "Disc. ";
|
||||
if (isHybrid_) std::cout << "Hybr. ";
|
||||
BaseConditional::print("", formatter);
|
||||
std::cout << "Discrete Keys = ";
|
||||
for (auto &dk : discreteKeys_) {
|
||||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
}
|
||||
}
|
|
@ -19,13 +19,30 @@
|
|||
#include <gtsam/inference/Conditional.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
|
||||
namespace gtsam {
|
||||
class CLGaussianConditional : public HybridFactor, public Conditional<HybridFactor, CLGaussianConditional> {
|
||||
public:
|
||||
using This = CLGaussianConditional;
|
||||
using shared_ptr = boost::shared_ptr<CLGaussianConditional>;
|
||||
using BaseFactor = HybridFactor;
|
||||
using BaseConditional = Conditional<HybridFactor, CLGaussianConditional>;
|
||||
|
||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||
|
||||
public:
|
||||
|
||||
CLGaussianConditional(const KeyVector &continuousFrontals,
|
||||
const KeyVector &continuousParents,
|
||||
const DiscreteKeys &discreteParents,
|
||||
const Conditionals &factors);
|
||||
|
||||
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
|
||||
|
||||
void print(
|
||||
const std::string &s = "CLGaussianConditional\n",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
};
|
||||
}
|
|
@ -21,6 +21,11 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
|
|||
|
||||
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
void HybridDiscreteFactor::print(const std::string &s, const KeyFormatter &formatter) const {
|
||||
HybridFactor::print(s, formatter);
|
||||
inner->print("inner: ", formatter);
|
||||
};
|
||||
|
||||
}
|
|
@ -37,5 +37,7 @@ class HybridDiscreteFactor : public HybridFactor {
|
|||
|
||||
public:
|
||||
virtual bool equals(const HybridFactor& lf, double tol) const override;
|
||||
|
||||
void print(const std::string &s = "HybridFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -16,3 +16,38 @@
|
|||
*/
|
||||
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) {
|
||||
KeyVector allKeys;
|
||||
std::copy(continuousKeys.begin(), continuousKeys.end(), std::back_inserter(allKeys));
|
||||
std::transform(discreteKeys.begin(),
|
||||
discreteKeys.end(),
|
||||
std::back_inserter(allKeys),
|
||||
[](const DiscreteKey &k) { return k.first; });
|
||||
return allKeys;
|
||||
}
|
||||
|
||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) {
|
||||
KeyVector allKeys;
|
||||
std::copy(keys1.begin(), keys1.end(), std::back_inserter(allKeys));
|
||||
std::copy(keys2.begin(), keys2.end(), std::back_inserter(allKeys));
|
||||
return allKeys;
|
||||
}
|
||||
|
||||
HybridFactor::HybridFactor() = default;
|
||||
|
||||
HybridFactor::HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {}
|
||||
|
||||
HybridFactor::HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys)
|
||||
: Base(
|
||||
CollectKeys(continuousKeys, discreteKeys)), isHybrid_(true), discreteKeys_(discreteKeys) {}
|
||||
|
||||
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)),
|
||||
isDiscrete_(true),
|
||||
discreteKeys_(discreteKeys) {}
|
||||
|
||||
HybridFactor::~HybridFactor() = default;
|
||||
|
||||
}
|
|
@ -25,6 +25,9 @@
|
|||
#include <string>
|
||||
namespace gtsam {
|
||||
|
||||
KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys);
|
||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
||||
|
||||
/**
|
||||
* Base class for hybrid probabilistic factors
|
||||
*/
|
||||
|
@ -49,33 +52,21 @@ public:
|
|||
/// @{
|
||||
|
||||
/** Default constructor creates empty factor */
|
||||
HybridFactor() {}
|
||||
HybridFactor();
|
||||
|
||||
/** Construct from container of keys. This constructor is used internally from derived factor
|
||||
* constructors, either from a container of keys or from a boost::assign::list_of. */
|
||||
// template<typename CONTAINER>
|
||||
// HybridFactor(const CONTAINER &keys) : Base(keys) {}
|
||||
|
||||
HybridFactor(const KeyVector &keys) : Base(keys), isContinuous_(true) {}
|
||||
explicit HybridFactor(const KeyVector &keys);
|
||||
|
||||
static KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) {
|
||||
KeyVector allKeys;
|
||||
std::copy(continuousKeys.begin(), continuousKeys.end(), std::back_inserter(allKeys));
|
||||
std::transform(discreteKeys.begin(),
|
||||
discreteKeys.end(),
|
||||
std::back_inserter(allKeys),
|
||||
[](const DiscreteKey &k) { return k.first; });
|
||||
return allKeys;
|
||||
}
|
||||
HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys);
|
||||
|
||||
HybridFactor(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys) : Base(
|
||||
CollectKeys(continuousKeys, discreteKeys)), isHybrid_(true) {}
|
||||
|
||||
HybridFactor(const DiscreteKeys &discreteKeys) : Base(CollectKeys({}, discreteKeys)), isDiscrete_(true) {}
|
||||
explicit HybridFactor(const DiscreteKeys &discreteKeys);
|
||||
|
||||
/// Virtual destructor
|
||||
virtual ~HybridFactor() {
|
||||
}
|
||||
virtual ~HybridFactor();
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
|
|
|
@ -34,6 +34,7 @@ EliminateHybrid(const HybridFactorGraph &factors,
|
|||
|
||||
// PREPROCESS: Identify the nature of the current elimination
|
||||
KeySet allKeys;
|
||||
|
||||
// TODO: we do a mock by just doing the correct key thing
|
||||
std::cout << "Begin Eliminate: ";
|
||||
frontalKeys.print();
|
||||
|
@ -43,6 +44,7 @@ EliminateHybrid(const HybridFactorGraph &factors,
|
|||
factor->print();
|
||||
allKeys.insert(factor->begin(), factor->end());
|
||||
}
|
||||
|
||||
for (auto &k : frontalKeys) {
|
||||
allKeys.erase(k);
|
||||
}
|
||||
|
@ -51,13 +53,14 @@ EliminateHybrid(const HybridFactorGraph &factors,
|
|||
gttic(product);
|
||||
|
||||
HybridConditional sum(allKeys.size(), Ordering(allKeys));
|
||||
// HybridDiscreteFactor product(DiscreteConditional());
|
||||
// for (auto&& factor : factors) product = (*factor) * product;
|
||||
|
||||
// HybridDiscreteFactor product(DiscreteConditional());
|
||||
// for (auto&& factor : factors) product = (*factor) * product;
|
||||
gttoc(product);
|
||||
|
||||
// sum out frontals, this is the factor on the separator
|
||||
gttic(sum);
|
||||
// HybridFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
// HybridFactor::shared_ptr sum = product.sum(frontalKeys);
|
||||
gttoc(sum);
|
||||
|
||||
// Ordering keys for the conditional so that frontalKeys are really in front
|
||||
|
@ -69,11 +72,11 @@ EliminateHybrid(const HybridFactorGraph &factors,
|
|||
|
||||
// now divide product/sum to get conditional
|
||||
gttic(divide);
|
||||
// auto conditional =
|
||||
// boost::make_shared<HybridConditional>(product, *sum, orderedKeys);
|
||||
// auto conditional =
|
||||
// boost::make_shared<HybridConditional>(product, *sum, orderedKeys);
|
||||
gttoc(divide);
|
||||
|
||||
// return std::make_pair(conditional, sum);
|
||||
// return std::make_pair(conditional, sum);
|
||||
return std::make_pair(boost::make_shared<HybridConditional>(frontalKeys.size(),
|
||||
orderedKeys),
|
||||
boost::make_shared<HybridConditional>(std::move(sum)));
|
||||
|
|
|
@ -18,6 +18,10 @@ HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf) : Base(jf.keys()
|
|||
|
||||
bool HybridGaussianFactor::equals(const HybridFactor& lf, double tol) const {
|
||||
return false;
|
||||
}
|
||||
void HybridGaussianFactor::print(const std::string &s, const KeyFormatter &formatter) const {
|
||||
HybridFactor::print(s, formatter);
|
||||
inner->print("inner: ", formatter);
|
||||
};
|
||||
|
||||
}
|
|
@ -37,5 +37,7 @@ class HybridGaussianFactor : public HybridFactor {
|
|||
|
||||
public:
|
||||
virtual bool equals(const HybridFactor& lf, double tol) const override;
|
||||
|
||||
void print(const std::string &s = "HybridFactor\n", const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <gtsam/hybrid/HybridFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridDiscreteFactor.h>
|
||||
#include <gtsam/hybrid/CLGaussianConditional.h>
|
||||
#include <gtsam/hybrid/CGMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridBayesTree.h>
|
||||
|
@ -47,6 +48,13 @@ TEST_UNSAFE(HybridFactorGraph, creation) {
|
|||
HybridFactorGraph hfg;
|
||||
|
||||
hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1)));
|
||||
|
||||
CLGaussianConditional clgc(
|
||||
{X(0)}, {X(1)},
|
||||
DiscreteKeys(DiscreteKey{C(0), 2}),
|
||||
CLGaussianConditional::Conditionals()
|
||||
);
|
||||
GTSAM_PRINT(clgc);
|
||||
}
|
||||
|
||||
TEST_UNSAFE(HybridFactorGraph, eliminate) {
|
||||
|
@ -84,12 +92,19 @@ TEST(HybridFactorGraph, eliminateFullMultifrontal) {
|
|||
hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1));
|
||||
hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1));
|
||||
|
||||
DecisionTree<Key, GaussianFactor::shared_ptr> dt;
|
||||
DecisionTree<Key, GaussianFactor::shared_ptr> dt(C(1),
|
||||
boost::make_shared<JacobianFactor>(X(1),
|
||||
I_3x3,
|
||||
Z_3x1),
|
||||
boost::make_shared<JacobianFactor>(X(1),
|
||||
I_3x3,
|
||||
Vector3::Ones()));
|
||||
|
||||
hfg.add(CGMixtureFactor({X(1)}, { x }, dt));
|
||||
hfg.add(CGMixtureFactor({X(1)}, {x}, dt));
|
||||
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(x, {2, 8})));
|
||||
hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4")));
|
||||
|
||||
auto result = hfg.eliminateMultifrontal(Ordering::ColamdConstrainedLast(hfg, {C(1)}));
|
||||
auto result = hfg.eliminateMultifrontal(Ordering::ColamdConstrainedLast(hfg, {C(1), C(2)}));
|
||||
|
||||
GTSAM_PRINT(*result);
|
||||
GTSAM_PRINT(*result->marginalFactor(C(1)));
|
||||
|
|
Loading…
Reference in New Issue