Merge pull request #1362 from borglab/hybrid/test_with_evaluate

release/4.3a0
Varun Agrawal 2023-01-04 01:37:24 -05:00 committed by GitHub
commit 5cdff9e223
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 764 additions and 297 deletions

View File

@ -51,28 +51,28 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionalsList)) {} Conditionals(discreteParents, conditionalsList)) {}
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add( GaussianFactorGraphTree GaussianMixture::add(
const GaussianMixture::Sum &sum) const { const GaussianFactorGraphTree &sum) const {
using Y = GaussianMixtureFactor::GraphAndConstant; using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph; auto result = graph1.graph;
result.push_back(graph2.graph); result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant); return Y(result, graph1.constant + graph2.constant);
}; };
const Sum tree = asGaussianFactorGraphTree(); const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const { GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianConditional::shared_ptr &conditional) { auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(conditional); result.push_back(conditional);
if (conditional) { if (conditional) {
return GaussianMixtureFactor::GraphAndConstant( return GraphAndConstant(
result, conditional->logNormalizationConstant()); result, conditional->logNormalizationConstant());
} else { } else {
return GaussianMixtureFactor::GraphAndConstant(result, 0.0); return GraphAndConstant(result, 0.0);
} }
}; };
return {conditionals_, lambda}; return {conditionals_, lambda};
@ -103,7 +103,19 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
/* *******************************************************************************/ /* *******************************************************************************/
bool GaussianMixture::equals(const HybridFactor &lf, double tol) const { bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf); const This *e = dynamic_cast<const This *>(&lf);
return e != nullptr && BaseFactor::equals(*e, tol); if (e == nullptr) return false;
// This will return false if either conditionals_ is empty or e->conditionals_
// is empty, but not if both are empty or both are not empty:
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return f1->equals(*(f2), tol);
});
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -59,9 +59,6 @@ class GTSAM_EXPORT GaussianMixture
using BaseFactor = HybridFactor; using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixture>; using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianMixtureFactor::GraphAndConstant>;
/// typedef for Decision Tree of Gaussian Conditionals /// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
@ -71,7 +68,7 @@ class GTSAM_EXPORT GaussianMixture
/** /**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/ */
Sum asGaussianFactorGraphTree() const; GaussianFactorGraphTree asGaussianFactorGraphTree() const;
/** /**
* @brief Helper function to get the pruner functor. * @brief Helper function to get the pruner functor.
@ -172,6 +169,16 @@ class GTSAM_EXPORT GaussianMixture
*/ */
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
// /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const;
// /// Evaluate probability density, sugar.
// double operator()(const HybridValues &values) const { return
// evaluate(values); }
// /// Calculate log-density for given values `x`.
// double logDensity(const HybridValues &values) const;
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`. * `decisionTree`.
@ -186,9 +193,9 @@ class GTSAM_EXPORT GaussianMixture
* maintaining the decision tree structure. * maintaining the decision tree structure.
* *
* @param sum Decision Tree of Gaussian Factor Graphs * @param sum Decision Tree of Gaussian Factor Graphs
* @return Sum * @return GaussianFactorGraphTree
*/ */
Sum add(const Sum &sum) const; GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @} /// @}
}; };

View File

@ -81,32 +81,36 @@ void GaussianMixtureFactor::print(const std::string &s,
} }
/* *******************************************************************************/ /* *******************************************************************************/
const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const { GaussianFactor::shared_ptr GaussianMixtureFactor::factor(
return Mixture(factors_, [](const FactorAndConstant &factor_z) { const DiscreteValues &assignment) const {
return factor_z.factor; return factors_(assignment).factor;
});
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::add( double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
const GaussianMixtureFactor::Sum &sum) const { return factors_(assignment).constant;
using Y = GaussianMixtureFactor::GraphAndConstant; }
/* *******************************************************************************/
GaussianFactorGraphTree GaussianMixtureFactor::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph; auto result = graph1.graph;
result.push_back(graph2.graph); result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant); return Y(result, graph1.constant + graph2.constant);
}; };
const Sum tree = asGaussianFactorGraphTree(); const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
const { const {
auto wrap = [](const FactorAndConstant &factor_z) { auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor_z.factor); result.push_back(factor_z.factor);
return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant); return GraphAndConstant(result, factor_z.constant);
}; };
return {factors_, wrap}; return {factors_, wrap};
} }

View File

@ -62,6 +62,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
// Note: constant is log of normalization constant for probabilities. // Note: constant is log of normalization constant for probabilities.
// Errors is the negative log-likelihood, // Errors is the negative log-likelihood,
// hence we subtract the constant here. // hence we subtract the constant here.
if (!factor) return 0.0; // If nullptr, return 0.0 error
return factor->error(values) - constant; return factor->error(values) - constant;
} }
@ -71,22 +72,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
} }
}; };
/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}
// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
};
using Sum = DecisionTree<Key, GraphAndConstant>;
/// typedef for Decision Tree of Gaussian factors and log-constant. /// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>; using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>; using Mixture = DecisionTree<Key, sharedFactor>;
@ -99,9 +84,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Helper function to return factors and functional to create a * @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs. * DecisionTree of Gaussian Factor Graphs.
* *
* @return Sum (DecisionTree<Key, GaussianFactorGraph>) * @return GaussianFactorGraphTree
*/ */
Sum asGaussianFactorGraphTree() const; GaussianFactorGraphTree asGaussianFactorGraphTree() const;
public: public:
/// @name Constructors /// @name Constructors
@ -151,12 +136,16 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
void print( void print(
const std::string &s = "GaussianMixtureFactor\n", const std::string &s = "GaussianMixtureFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
/// @} /// @}
/// @name Standard API /// @name Standard API
/// @{ /// @{
/// Getter for the underlying Gaussian Factor Decision Tree. /// Get factor at a given discrete assignment.
const Mixture factors() const; sharedFactor factor(const DiscreteValues &assignment) const;
/// Get constant at a given discrete assignment.
double constant(const DiscreteValues &assignment) const;
/** /**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@ -166,7 +155,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* variables. * variables.
* @return Sum * @return Sum
*/ */
Sum add(const Sum &sum) const; GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/** /**
* @brief Compute error of the GaussianMixtureFactor as a tree. * @brief Compute error of the GaussianMixtureFactor as a tree.
@ -184,7 +173,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
/// Add MixtureFactor to a Sum, syntactic sugar. /// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { friend GaussianFactorGraphTree &operator+=(
GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum); sum = factor.add(sum);
return sum; return sum;
} }

View File

@ -26,6 +26,17 @@ static std::mt19937_64 kRandomNumberGenerator(42);
namespace gtsam { namespace gtsam {
/* ************************************************************************* */
void HybridBayesNet::print(const std::string &s,
const KeyFormatter &formatter) const {
Base::print(s, formatter);
}
/* ************************************************************************* */
bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree; AlgebraicDecisionTree<Key> decisionTree;
@ -271,12 +282,15 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {
// Iterate over each conditional. // Iterate over each conditional.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
// TODO: should be delegated to derived classes.
if (auto gm = conditional->asMixture()) { if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues); const auto component = (*gm)(discreteValues);
logDensity += component->logDensity(continuousValues); logDensity += component->logDensity(continuousValues);
} else if (auto gc = conditional->asGaussian()) { } else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply. // If continuous only, evaluate the probability and multiply.
logDensity += gc->logDensity(continuousValues); logDensity += gc->logDensity(continuousValues);
} else if (auto dc = conditional->asDiscrete()) { } else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability. // Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues); probability *= dc->operator()(discreteValues);

View File

@ -50,18 +50,14 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// @name Testable /// @name Testable
/// @{ /// @{
/** Check equality */ /// GTSAM-style printing
bool equals(const This &bn, double tol = 1e-9) const {
return Base::equals(bn, tol);
}
/// print graph
void print( void print(
const std::string &s = "", const std::string &s = "",
const KeyFormatter &formatter = DefaultKeyFormatter) const override { const KeyFormatter &formatter = DefaultKeyFormatter) const override;
Base::print(s, formatter);
}
/// GTSAM-style equals
bool equals(const This& fg, double tol = 1e-9) const;
/// @} /// @}
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{

View File

@ -17,6 +17,7 @@
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
@ -102,7 +103,38 @@ void HybridConditional::print(const std::string &s,
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const { bool HybridConditional::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other); const This *e = dynamic_cast<const This *>(&other);
return e != nullptr && BaseFactor::equals(*e, tol); if (e == nullptr) return false;
if (auto gm = asMixture()) {
auto other = e->asMixture();
return other != nullptr && gm->equals(*other, tol);
}
if (auto gc = asGaussian()) {
auto other = e->asGaussian();
return other != nullptr && gc->equals(*other, tol);
}
if (auto dc = asDiscrete()) {
auto other = e->asDiscrete();
return other != nullptr && dc->equals(*other, tol);
}
return inner_->equals(*(e->inner_), tol);
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}
/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
if (auto gm = asMixture()) {
return gm->error(values);
}
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
}
if (auto dc = asDiscrete()) {
return -log((*dc)(values.discrete()));
}
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
} }
} // namespace gtsam } // namespace gtsam

View File

@ -176,15 +176,7 @@ class GTSAM_EXPORT HybridConditional
boost::shared_ptr<Factor> inner() const { return inner_; } boost::shared_ptr<Factor> inner() const { return inner_; }
/// Return the error of the underlying conditional. /// Return the error of the underlying conditional.
/// Currently only implemented for Gaussian mixture. double error(const HybridValues& values) const override;
double error(const HybridValues& values) const override {
if (auto gm = asMixture()) {
return gm->error(values);
} else {
throw std::runtime_error(
"HybridConditional::error: only implemented for Gaussian mixture");
}
}
/// @} /// @}
@ -195,6 +187,7 @@ class GTSAM_EXPORT HybridConditional
void serialize(Archive& ar, const unsigned int /*version*/) { void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(inner_);
} }
}; // HybridConditional }; // HybridConditional

View File

@ -21,6 +21,8 @@
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Factor.h> #include <gtsam/inference/Factor.h>
#include <gtsam/nonlinear/Values.h> #include <gtsam/nonlinear/Values.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/discrete/DecisionTree.h>
#include <cstddef> #include <cstddef>
#include <string> #include <string>
@ -28,6 +30,36 @@ namespace gtsam {
class HybridValues; class HybridValues;
/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}
// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
// Implement GTSAM-style print:
void print(const std::string &s = "Graph: ",
const KeyFormatter &formatter = DefaultKeyFormatter) const {
graph.print(s, formatter);
std::cout << "Constant: " << constant << std::endl;
}
// Implement GTSAM-style equals:
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
return graph.equals(other.graph, tol) &&
fabs(constant - other.constant) < tol;
}
};
/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GraphAndConstant>;
KeyVector CollectKeys(const KeyVector &continuousKeys, KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys); const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
@ -160,4 +192,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
template <> template <>
struct traits<HybridFactor> : public Testable<HybridFactor> {}; struct traits<HybridFactor> : public Testable<HybridFactor> {};
template <>
struct traits<GraphAndConstant> : public Testable<GraphAndConstant> {};
} // namespace gtsam } // namespace gtsam

View File

@ -48,9 +48,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/** /**
* Constructor from shared_ptr of GaussianFactor. * Constructor from shared_ptr of GaussianFactor.
* Example: * Example:
* boost::shared_ptr<GaussianFactor> ptr = * auto ptr = boost::make_shared<JacobianFactor>(...);
* boost::make_shared<JacobianFactor>(...); * HybridGaussianFactor factor(ptr);
*
*/ */
explicit HybridGaussianFactor(const boost::shared_ptr<GaussianFactor> &ptr); explicit HybridGaussianFactor(const boost::shared_ptr<GaussianFactor> &ptr);

View File

@ -59,44 +59,44 @@ namespace gtsam {
template class EliminateableFactorGraph<HybridGaussianFactorGraph>; template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian( static GaussianFactorGraphTree addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { const GaussianFactorGraphTree &sum,
const GaussianFactor::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it. // If the decision tree is not initialized, then initialize it.
if (sum.empty()) { if (sum.empty()) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
sum = GaussianMixtureFactor::Sum( return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
GaussianMixtureFactor::GraphAndConstant(result, 0.0));
} else { } else {
auto add = [&factor]( auto add = [&factor](const GraphAndConstant &graph_z) {
const GaussianMixtureFactor::GraphAndConstant &graph_z) {
auto result = graph_z.graph; auto result = graph_z.graph;
result.push_back(factor); result.push_back(factor);
return GaussianMixtureFactor::GraphAndConstant(result, graph_z.constant); return GraphAndConstant(result, graph_z.constant);
}; };
sum = sum.apply(add); return sum.apply(add);
} }
return sum;
} }
/* ************************************************************************ */ /* ************************************************************************ */
GaussianMixtureFactor::Sum sumFrontals( // TODO(dellaert): We need to document why deferredFactors need to be
const HybridGaussianFactorGraph &factors) { // added last, which I would undo if possible. Implementation-wise, it's
// sum out frontals, this is the factor on the separator // probably more efficient to first collect the discrete keys, and then loop
gttic(sum); // over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
gttic(assembleGraphTree);
GaussianMixtureFactor::Sum sum; GaussianFactorGraphTree result;
std::vector<GaussianFactor::shared_ptr> deferredFactors; std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors) { for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (f->isHybrid()) { if (f->isHybrid()) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = gm->add(sum); result = gm->add(result);
} }
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) { if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum); result = gm->asMixture()->add(result);
} }
} else if (f->isContinuous()) { } else if (f->isContinuous()) {
@ -127,16 +127,16 @@ GaussianMixtureFactor::Sum sumFrontals(
} }
for (auto &f : deferredFactors) { for (auto &f : deferredFactors) {
sum = addGaussian(sum, f); result = addGaussian(result, f);
} }
gttoc(sum); gttoc(assembleGraphTree);
return sum; return result;
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
continuousElimination(const HybridGaussianFactorGraph &factors, continuousElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
GaussianFactorGraph gfg; GaussianFactorGraph gfg;
@ -157,7 +157,7 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
discreteElimination(const HybridGaussianFactorGraph &factors, discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) { const Ordering &frontalKeys) {
DiscreteFactorGraph dfg; DiscreteFactorGraph dfg;
@ -174,53 +174,52 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
} }
} }
auto result = EliminateForMPE(dfg, frontalKeys); // NOTE: This does sum-product. For max-product, use EliminateForMPE.
auto result = EliminateDiscrete(dfg, frontalKeys);
return {boost::make_shared<HybridConditional>(result.first), return {boost::make_shared<HybridConditional>(result.first),
boost::make_shared<HybridDiscreteFactor>(result.second)}; boost::make_shared<HybridDiscreteFactor>(result.second)};
} }
/* ************************************************************************ */ /* ************************************************************************ */
std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> // If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GraphAndConstant &graph_z) {
bool hasNull =
std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GraphAndConstant{GaussianFactorGraph(), 0.0} : graph_z;
};
return GaussianFactorGraphTree(sum, emptyGaussian);
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
hybridElimination(const HybridGaussianFactorGraph &factors, hybridElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys, const Ordering &frontalKeys,
const KeySet &continuousSeparator, const KeyVector &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) { const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree, // NOTE: since we use the special JunctionTree,
// only possibility is continuous conditioned on discrete. // only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end()); discreteSeparatorSet.end());
// Collect all the frontal factors to create Gaussian factor graphs // Collect all the factors to create a set of Gaussian factor graphs in a
// indexed on the discrete keys. // decision tree indexed by all discrete keys involved.
GaussianMixtureFactor::Sum sum = sumFrontals(factors); GaussianFactorGraphTree sum = factors.assembleGraphTree();
// If a tree leaf contains nullptr, // Convert factor graphs with a nullptr to an empty factor graph.
// convert that leaf to an empty GaussianFactorGraph. // This is done after assembly since it is non-trivial to keep track of which
// Needed since the DecisionTree will otherwise create // FG has a nullptr as we're looping over the factors.
// a GFG with a single (null) factor. sum = removeEmpty(sum);
auto emptyGaussian =
[](const GaussianMixtureFactor::GraphAndConstant &graph_z) {
bool hasNull = std::any_of(
graph_z.graph.begin(), graph_z.graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianMixtureFactor::GraphAndConstant(
GaussianFactorGraph(), 0.0)
: graph_z;
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>, using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>; GaussianMixtureFactor::FactorAndConstant>;
KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator;
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
auto eliminateFunc = auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
[&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
-> EliminationPair {
if (graph_z.graph.empty()) { if (graph_z.graph.empty()) {
return {nullptr, {nullptr, 0.0}}; return {nullptr, {nullptr, 0.0}};
} }
@ -229,25 +228,30 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
gttic_(hybrid_eliminate); gttic_(hybrid_eliminate);
#endif #endif
std::pair<boost::shared_ptr<GaussianConditional>, boost::shared_ptr<GaussianConditional> conditional;
boost::shared_ptr<GaussianFactor>> boost::shared_ptr<GaussianFactor> newFactor;
conditional_factor = boost::tie(conditional, newFactor) =
EliminatePreferCholesky(graph_z.graph, frontalKeys); EliminatePreferCholesky(graph_z.graph, frontalKeys);
// Initialize the keysOfEliminated to be the keys of the // Get the log of the log normalization constant inverse and
// eliminated GaussianConditional // add it to the previous constant.
keysOfEliminated = conditional_factor.first->keys(); const double logZ =
keysOfSeparator = conditional_factor.second->keys(); graph_z.constant - conditional->logNormalizationConstant();
GaussianConditional::shared_ptr conditional = conditional_factor.first;
// Get the log of the log normalization constant inverse. // Get the log of the log normalization constant inverse.
double logZ = -conditional->logNormalizationConstant() + graph_z.constant; // double logZ = -conditional->logNormalizationConstant();
// // IF this is the last continuous variable to eliminated, we need to
// // calculate the error here: the value of all factors at the mean, see
// // ml_map_rao.pdf.
// if (continuousSeparator.empty()) {
// const auto posterior_mean = conditional->solve(VectorValues());
// logZ += graph_z.graph.error(posterior_mean);
// }
#ifdef HYBRID_TIMING #ifdef HYBRID_TIMING
gttoc_(hybrid_eliminate); gttoc_(hybrid_eliminate);
#endif #endif
return {conditional, {conditional_factor.second, logZ}}; return {conditional, {newFactor, logZ}};
}; };
// Perform elimination! // Perform elimination!
@ -259,54 +263,50 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
#endif #endif
// Separate out decision tree into conditionals and remaining factors. // Separate out decision tree into conditionals and remaining factors.
auto pair = unzip(eliminationResults); GaussianMixture::Conditionals conditionals;
const auto &separatorFactors = pair.second; GaussianMixtureFactor::Factors newFactors;
std::tie(conditionals, newFactors) = unzip(eliminationResults);
// Create the GaussianMixture from the conditionals // Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>( auto gaussianMixture = boost::make_shared<GaussianMixture>(
frontalKeys, keysOfSeparator, discreteSeparator, pair.first); frontalKeys, continuousSeparator, discreteSeparator, conditionals);
// If there are no more continuous parents, then we should create here a // If there are no more continuous parents, then we should create a
// DiscreteFactor, with the error for each discrete choice. // DiscreteFactor here, with the error for each discrete choice.
if (keysOfSeparator.empty()) { if (continuousSeparator.empty()) {
auto factorProb = auto factorProb =
[&](const GaussianMixtureFactor::FactorAndConstant &factor_z) { [&](const GaussianMixtureFactor::FactorAndConstant &factor_z) {
GaussianFactor::shared_ptr factor = factor_z.factor; // This is the probability q(μ) at the MLE point.
if (!factor) { // factor_z.factor is a factor without keys,
return 0.0; // If nullptr, return 0.0 probability // just containing the residual.
} else { return exp(-factor_z.error(VectorValues()));
// This is the probability q(μ) at the MLE point.
double error = factor_z.error(VectorValues());
return std::exp(-error);
}
}; };
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
// Normalize the values of decision tree to be valid probabilities const DecisionTree<Key, double> fdt(newFactors, factorProb);
double sum = 0.0; // // Normalize the values of decision tree to be valid probabilities
auto visitor = [&](double y) { sum += y; }; // double sum = 0.0;
fdt.visit(visitor); // auto visitor = [&](double y) { sum += y; };
// Check if sum is 0, and update accordingly. // fdt.visit(visitor);
if (sum == 0) { // // Check if sum is 0, and update accordingly.
sum = 1.0; // if (sum == 0) {
} // sum = 1.0;
// }
// fdt = DecisionTree<Key, double>(fdt, // fdt = DecisionTree<Key, double>(fdt,
// [sum](const double &x) { return x / sum; // [sum](const double &x) { return x / sum;
// }); // });
const auto discreteFactor =
auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt); boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
return {boost::make_shared<HybridConditional>(conditional), return {boost::make_shared<HybridConditional>(gaussianMixture),
boost::make_shared<HybridDiscreteFactor>(discreteFactor)}; boost::make_shared<HybridDiscreteFactor>(discreteFactor)};
} else { } else {
// Create a resulting GaussianMixtureFactor on the separator. // Create a resulting GaussianMixtureFactor on the separator.
auto factor = boost::make_shared<GaussianMixtureFactor>( return {boost::make_shared<HybridConditional>(gaussianMixture),
KeyVector(continuousSeparator.begin(), continuousSeparator.end()), boost::make_shared<GaussianMixtureFactor>(
discreteSeparator, separatorFactors); continuousSeparator, discreteSeparator, newFactors)};
return {boost::make_shared<HybridConditional>(conditional), factor};
} }
} }
/* ************************************************************************ /* ************************************************************************
* Function to eliminate variables **under the following assumptions**: * Function to eliminate variables **under the following assumptions**:
* 1. When the ordering is fully continuous, and the graph only contains * 1. When the ordering is fully continuous, and the graph only contains
@ -403,12 +403,12 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
// Fill in discrete discrete separator keys and continuous separator keys. // Fill in discrete discrete separator keys and continuous separator keys.
std::set<DiscreteKey> discreteSeparatorSet; std::set<DiscreteKey> discreteSeparatorSet;
KeySet continuousSeparator; KeyVector continuousSeparator;
for (auto &k : separatorKeys) { for (auto &k : separatorKeys) {
if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) {
discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k));
} else { } else {
continuousSeparator.insert(k); continuousSeparator.push_back(k);
} }
} }

View File

@ -18,6 +18,7 @@
#pragma once #pragma once
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h> #include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
@ -118,14 +119,12 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
: Base(graph) {} : Base(graph) {}
/// @} /// @}
/// @name Adding factors.
/// @{
using Base::empty;
using Base::reserve;
using Base::size;
using Base::operator[];
using Base::add; using Base::add;
using Base::push_back; using Base::push_back;
using Base::resize; using Base::reserve;
/// Add a Jacobian factor to the factor graph. /// Add a Jacobian factor to the factor graph.
void add(JacobianFactor&& factor); void add(JacobianFactor&& factor);
@ -172,6 +171,25 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
} }
} }
/// @}
/// @name Testable
/// @{
// TODO(dellaert): customize print and equals.
// void print(const std::string& s = "HybridGaussianFactorGraph",
// const KeyFormatter& keyFormatter = DefaultKeyFormatter) const
// override;
// bool equals(const This& fg, double tol = 1e-9) const override;
/// @}
/// @name Standard Interface
/// @{
using Base::empty;
using Base::size;
using Base::operator[];
using Base::resize;
/** /**
* @brief Compute error for each discrete assignment, * @brief Compute error for each discrete assignment,
* and return as a tree. * and return as a tree.
@ -217,6 +235,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @return const Ordering * @return const Ordering
*/ */
const Ordering getHybridOrdering() const; const Ordering getHybridOrdering() const;
/**
* @brief Create a decision tree of factor graphs out of this hybrid factor
* graph.
*
* For example, if there are two mixture factors, one with a discrete key A
* and one with a discrete key B, then the decision tree will have two levels,
* one for A and one for B. The leaves of the tree will be the Gaussian
* factors that have only continuous keys.
*/
GaussianFactorGraphTree assembleGraphTree() const;
/// @}
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -168,6 +168,15 @@ class GTSAM_EXPORT HybridValues {
return *this; return *this;
} }
/// Extract continuous values with given keys.
VectorValues continuousSubset(const KeyVector& keys) const {
VectorValues measurements;
for (const auto& key : keys) {
measurements.insert(key, continuous_.at(key));
}
return measurements;
}
/// @} /// @}
/// @name Wrapper support /// @name Wrapper support
/// @{ /// @{

View File

@ -40,6 +40,15 @@ virtual class HybridFactor {
bool empty() const; bool empty() const;
size_t size() const; size_t size() const;
gtsam::KeyVector keys() const; gtsam::KeyVector keys() const;
// Standard interface:
double error(const gtsam::HybridValues &values) const;
bool isDiscrete() const;
bool isContinuous() const;
bool isHybrid() const;
size_t nrContinuous() const;
gtsam::DiscreteKeys discreteKeys() const;
gtsam::KeyVector continuousKeys() const;
}; };
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
@ -50,7 +59,13 @@ virtual class HybridConditional {
bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const; bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const;
size_t nrFrontals() const; size_t nrFrontals() const;
size_t nrParents() const; size_t nrParents() const;
// Standard interface:
gtsam::GaussianMixture* asMixture() const;
gtsam::GaussianConditional* asGaussian() const;
gtsam::DiscreteConditional* asDiscrete() const;
gtsam::Factor* inner(); gtsam::Factor* inner();
double error(const gtsam::HybridValues& values) const;
}; };
#include <gtsam/hybrid/HybridDiscreteFactor.h> #include <gtsam/hybrid/HybridDiscreteFactor.h>
@ -61,6 +76,7 @@ virtual class HybridDiscreteFactor {
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const; bool equals(const gtsam::HybridDiscreteFactor& other, double tol = 1e-9) const;
gtsam::Factor* inner(); gtsam::Factor* inner();
double error(const gtsam::HybridValues &values) const;
}; };
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>

View File

@ -0,0 +1,96 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010-2023, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/*
* @file TinyHybridExample.h
* @date December, 2022
* @author Frank Dellaert
*/
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#pragma once
namespace gtsam {
namespace tiny {
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;
// Create mode key: 0 is low-noise, 1 is high-noise.
const DiscreteKey mode{M(0), 2};
/**
* Create a tiny two variable hybrid model which represents
* the generative probability P(z,x,mode) = P(z|x,mode)P(x)P(mode).
*/
inline HybridBayesNet createHybridBayesNet(int num_measurements = 1) {
HybridBayesNet bayesNet;
// Create Gaussian mixture z_i = x0 + noise for each measurement.
for (int i = 0; i < num_measurements; i++) {
const auto conditional0 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 0.5));
const auto conditional1 = boost::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(i), I_1x1, X(0), Z_1x1, 3));
GaussianMixture gm({Z(i)}, {X(0)}, {mode}, {conditional0, conditional1});
bayesNet.emplaceMixture(gm); // copy :-(
}
// Create prior on X(0).
const auto prior_on_x0 =
GaussianConditional::FromMeanAndStddev(X(0), Vector1(5.0), 0.5);
bayesNet.emplaceGaussian(prior_on_x0); // copy :-(
// Add prior on mode.
bayesNet.emplaceDiscrete(mode, "4/6");
return bayesNet;
}
/**
* Convert a hybrid Bayes net to a hybrid Gaussian factor graph.
*/
inline HybridGaussianFactorGraph convertBayesNet(
const HybridBayesNet& bayesNet, const VectorValues& measurements) {
HybridGaussianFactorGraph fg;
int num_measurements = bayesNet.size() - 2;
for (int i = 0; i < num_measurements; i++) {
auto conditional = bayesNet.atMixture(i);
auto factor = conditional->likelihood({{Z(i), measurements.at(Z(i))}});
fg.push_back(factor);
}
fg.push_back(bayesNet.atGaussian(num_measurements));
fg.push_back(bayesNet.atDiscrete(num_measurements + 1));
return fg;
}
/**
* Create a tiny two variable hybrid factor graph which represents a discrete
* mode and a continuous variable x0, given a number of measurements of the
* continuous variable x0. If no measurements are given, they are sampled from
* the generative Bayes net model HybridBayesNet::Example(num_measurements)
*/
inline HybridGaussianFactorGraph createHybridGaussianFactorGraph(
int num_measurements = 1,
boost::optional<VectorValues> measurements = boost::none) {
auto bayesNet = createHybridBayesNet(num_measurements);
if (measurements) {
return convertBayesNet(bayesNet, *measurements);
} else {
return convertBayesNet(bayesNet, bayesNet.sample().continuous());
}
}
} // namespace tiny
} // namespace gtsam

View File

@ -80,7 +80,7 @@ TEST(GaussianMixtureFactor, Sum) {
// Create sum of two mixture factors: it will be a decision tree now on both // Create sum of two mixture factors: it will be a decision tree now on both
// discrete variables m1 and m2: // discrete variables m1 and m2:
GaussianMixtureFactor::Sum sum; GaussianFactorGraphTree sum;
sum += mixtureFactorA; sum += mixtureFactorA;
sum += mixtureFactorB; sum += mixtureFactorB;

View File

@ -24,6 +24,7 @@
#include <gtsam/nonlinear/NonlinearFactorGraph.h> #include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include "Switching.h" #include "Switching.h"
#include "TinyHybridExample.h"
// Include for test suite // Include for test suite
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
@ -63,7 +64,7 @@ TEST(HybridBayesNet, Add) {
/* ****************************************************************************/ /* ****************************************************************************/
// Test evaluate for a pure discrete Bayes net P(Asia). // Test evaluate for a pure discrete Bayes net P(Asia).
TEST(HybridBayesNet, evaluatePureDiscrete) { TEST(HybridBayesNet, EvaluatePureDiscrete) {
HybridBayesNet bayesNet; HybridBayesNet bayesNet;
bayesNet.emplaceDiscrete(Asia, "99/1"); bayesNet.emplaceDiscrete(Asia, "99/1");
HybridValues values; HybridValues values;
@ -71,6 +72,13 @@ TEST(HybridBayesNet, evaluatePureDiscrete) {
EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9); EXPECT_DOUBLES_EQUAL(0.99, bayesNet.evaluate(values), 1e-9);
} }
/* ****************************************************************************/
// Test creation of a tiny hybrid Bayes net.
TEST(HybridBayesNet, Tiny) {
auto bayesNet = tiny::createHybridBayesNet();
EXPECT_LONGS_EQUAL(3, bayesNet.size());
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia). // Test evaluate for a hybrid Bayes net P(X0|X1) P(X1|Asia) P(Asia).
TEST(HybridBayesNet, evaluateHybrid) { TEST(HybridBayesNet, evaluateHybrid) {
@ -205,7 +213,7 @@ TEST(HybridBayesNet, Optimize) {
} }
/* ****************************************************************************/ /* ****************************************************************************/
// Test bayes net error // Test Bayes net error
TEST(HybridBayesNet, Error) { TEST(HybridBayesNet, Error) {
Switching s(3); Switching s(3);
@ -236,7 +244,7 @@ TEST(HybridBayesNet, Error) {
EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9)); EXPECT(assert_equal(expected_pruned_error, pruned_error_tree, 1e-9));
// Verify error computation and check for specific error value // Verify error computation and check for specific error value
DiscreteValues discrete_values {{M(0), 1}, {M(1), 1}}; DiscreteValues discrete_values{{M(0), 1}, {M(1), 1}};
double total_error = 0; double total_error = 0;
for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) { for (size_t idx = 0; idx < hybridBayesNet->size(); idx++) {
@ -329,9 +337,11 @@ TEST(HybridBayesNet, Serialization) {
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
EXPECT(equalsObj<HybridBayesNet>(hbn)); // TODO(Varun) Serialization of inner factor doesn't work. Requires
EXPECT(equalsXML<HybridBayesNet>(hbn)); // serialization support for all hybrid factors.
EXPECT(equalsBinary<HybridBayesNet>(hbn)); // EXPECT(equalsObj<HybridBayesNet>(hbn));
// EXPECT(equalsXML<HybridBayesNet>(hbn));
// EXPECT(equalsBinary<HybridBayesNet>(hbn));
} }
/* ****************************************************************************/ /* ****************************************************************************/

View File

@ -155,7 +155,7 @@ TEST(HybridBayesTree, Optimize) {
dfg.push_back( dfg.push_back(
boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner())); boost::dynamic_pointer_cast<DecisionTreeFactor>(factor->inner()));
} }
// Add the probabilities for each branch // Add the probabilities for each branch
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}}; DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}, {M(2), 2}};
vector<double> probs = {0.012519475, 0.041280228, 0.075018647, 0.081663656, vector<double> probs = {0.012519475, 0.041280228, 0.075018647, 0.081663656,
@ -211,10 +211,10 @@ TEST(HybridBayesTree, Choose) {
ordering += M(0); ordering += M(0);
ordering += M(1); ordering += M(1);
ordering += M(2); ordering += M(2);
//TODO(Varun) get segfault if ordering not provided // TODO(Varun) get segfault if ordering not provided
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering); auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);
auto expected_gbt = bayesTree->choose(assignment); auto expected_gbt = bayesTree->choose(assignment);
EXPECT(assert_equal(expected_gbt, gbt)); EXPECT(assert_equal(expected_gbt, gbt));
@ -229,9 +229,11 @@ TEST(HybridBayesTree, Serialization) {
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); *(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
using namespace gtsam::serializationTestHelpers; using namespace gtsam::serializationTestHelpers;
EXPECT(equalsObj<HybridBayesTree>(hbt)); // TODO(Varun) Serialization of inner factor doesn't work. Requires
EXPECT(equalsXML<HybridBayesTree>(hbt)); // serialization support for all hybrid factors.
EXPECT(equalsBinary<HybridBayesTree>(hbt)); // EXPECT(equalsObj<HybridBayesTree>(hbt));
// EXPECT(equalsXML<HybridBayesTree>(hbt));
// EXPECT(equalsBinary<HybridBayesTree>(hbt));
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -283,11 +283,10 @@ AlgebraicDecisionTree<Key> getProbPrimeTree(
return probPrimeTree; return probPrimeTree;
} }
/****************************************************************************/ /*********************************************************************************
/**
* Test for correctness of different branches of the P'(Continuous | Discrete). * Test for correctness of different branches of the P'(Continuous | Discrete).
* The values should match those of P'(Continuous) for each discrete mode. * The values should match those of P'(Continuous) for each discrete mode.
*/ ********************************************************************************/
TEST(HybridEstimation, Probability) { TEST(HybridEstimation, Probability) {
constexpr size_t K = 4; constexpr size_t K = 4;
std::vector<double> measurements = {0, 1, 2, 2}; std::vector<double> measurements = {0, 1, 2, 2};
@ -444,20 +443,30 @@ static HybridGaussianFactorGraph::shared_ptr createHybridGaussianFactorGraph() {
* Do hybrid elimination and do regression test on discrete conditional. * Do hybrid elimination and do regression test on discrete conditional.
********************************************************************************/ ********************************************************************************/
TEST(HybridEstimation, eliminateSequentialRegression) { TEST(HybridEstimation, eliminateSequentialRegression) {
// 1. Create the factor graph from the nonlinear factor graph. // Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph(); HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph();
// 2. Eliminate into BN // Create expected discrete conditional on m0.
const Ordering ordering = fg->getHybridOrdering(); DiscreteKey m(M(0), 2);
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering); DiscreteConditional expected(m % "0.51341712/1"); // regression
// GTSAM_PRINT(*bn);
// TODO(dellaert): dc should be discrete conditional on m0, but it is an // Eliminate into BN using one ordering
// unnormalized factor? Ordering ordering1;
// DiscreteKey m(M(0), 2); ordering1 += X(0), X(1), M(0);
// DiscreteConditional expected(m % "0.51341712/1"); HybridBayesNet::shared_ptr bn1 = fg->eliminateSequential(ordering1);
// auto dc = bn->back()->asDiscrete();
// EXPECT(assert_equal(expected, *dc, 1e-9)); // Check that the discrete conditional matches the expected.
auto dc1 = bn1->back()->asDiscrete();
EXPECT(assert_equal(expected, *dc1, 1e-9));
// Eliminate into BN using a different ordering
Ordering ordering2;
ordering2 += X(0), X(1), M(0);
HybridBayesNet::shared_ptr bn2 = fg->eliminateSequential(ordering2);
// Check that the discrete conditional matches the expected.
auto dc2 = bn2->back()->asDiscrete();
EXPECT(assert_equal(expected, *dc2, 1e-9));
} }
/********************************************************************************* /*********************************************************************************
@ -472,46 +481,35 @@ TEST(HybridEstimation, eliminateSequentialRegression) {
********************************************************************************/ ********************************************************************************/
TEST(HybridEstimation, CorrectnessViaSampling) { TEST(HybridEstimation, CorrectnessViaSampling) {
// 1. Create the factor graph from the nonlinear factor graph. // 1. Create the factor graph from the nonlinear factor graph.
HybridGaussianFactorGraph::shared_ptr fg = createHybridGaussianFactorGraph(); const auto fg = createHybridGaussianFactorGraph();
// 2. Eliminate into BN // 2. Eliminate into BN
const Ordering ordering = fg->getHybridOrdering(); const Ordering ordering = fg->getHybridOrdering();
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering); const HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
// Set up sampling // Set up sampling
std::mt19937_64 rng(11); std::mt19937_64 rng(11);
// 3. Do sampling // Compute the log-ratio between the Bayes net and the factor graph.
int num_samples = 10; auto compute_ratio = [&](const HybridValues& sample) -> double {
return bn->evaluate(sample) / fg->probPrime(sample);
// Functor to compute the ratio between the
// Bayes net and the factor graph.
auto compute_ratio =
[](const HybridBayesNet::shared_ptr& bayesNet,
const HybridGaussianFactorGraph::shared_ptr& factorGraph,
const HybridValues& sample) -> double {
const DiscreteValues assignment = sample.discrete();
// Compute in log form for numerical stability
double log_ratio = bayesNet->error({sample.continuous(), assignment}) -
factorGraph->error({sample.continuous(), assignment});
double ratio = exp(-log_ratio);
return ratio;
}; };
// The error evaluated by the factor graph and the Bayes net should differ by // The error evaluated by the factor graph and the Bayes net should differ by
// the normalizing term computed via the Bayes net determinant. // the normalizing term computed via the Bayes net determinant.
const HybridValues sample = bn->sample(&rng); const HybridValues sample = bn->sample(&rng);
double ratio = compute_ratio(bn, fg, sample); double expected_ratio = compute_ratio(sample);
// regression // regression
EXPECT_DOUBLES_EQUAL(1.9477340410546764, ratio, 1e-9); EXPECT_DOUBLES_EQUAL(0.728588, expected_ratio, 1e-6);
// 4. Check that all samples == constant // 3. Do sampling
constexpr int num_samples = 10;
for (size_t i = 0; i < num_samples; i++) { for (size_t i = 0; i < num_samples; i++) {
// Sample from the bayes net // Sample from the bayes net
const HybridValues sample = bn->sample(&rng); const HybridValues sample = bn->sample(&rng);
// TODO(Varun) The ratio changes based on the mode // 4. Check that the ratio is constant.
// EXPECT_DOUBLES_EQUAL(ratio, compute_ratio(bn, fg, sample), 1e-9); EXPECT_DOUBLES_EQUAL(expected_ratio, compute_ratio(sample), 1e-6);
} }
} }

View File

@ -47,6 +47,7 @@
#include <vector> #include <vector>
#include "Switching.h" #include "Switching.h"
#include "TinyHybridExample.h"
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -133,8 +134,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
auto dc = result->at(2)->asDiscrete(); auto dc = result->at(2)->asDiscrete();
DiscreteValues dv; DiscreteValues dv;
dv[M(1)] = 0; dv[M(1)] = 0;
// regression // Regression test
EXPECT_DOUBLES_EQUAL(8.5730017810851127, dc->operator()(dv), 1e-3); EXPECT_DOUBLES_EQUAL(0.62245933120185448, dc->operator()(dv), 1e-3);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -613,6 +614,108 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
EXPECT(assert_equal(expected_probs, probs, 1e-7)); EXPECT(assert_equal(expected_probs, probs, 1e-7));
} }
/* ****************************************************************************/
// Check that assembleGraphTree assembles Gaussian factor graphs for each
// assignment.
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
using symbol_shorthand::Z;
const int num_measurements = 1;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
EXPECT_LONGS_EQUAL(3, fg.size());
auto sum = fg.assembleGraphTree();
// Get mixture factor:
auto mixture = boost::dynamic_pointer_cast<GaussianMixtureFactor>(fg.at(0));
using GF = GaussianFactor::shared_ptr;
// Get prior factor:
const GF prior =
boost::dynamic_pointer_cast<HybridGaussianFactor>(fg.at(1))->inner();
// Create DiscreteValues for both 0 and 1:
DiscreteValues d0{{M(0), 0}}, d1{{M(0), 1}};
// Expected decision tree with two factor graphs:
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
GaussianFactorGraphTree expectedSum{
M(0),
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
mixture->constant(d0)},
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d1), prior}),
mixture->constant(d1)}};
EXPECT(assert_equal(expectedSum(d0), sum(d0), 1e-5));
EXPECT(assert_equal(expectedSum(d1), sum(d1), 1e-5));
}
/* ****************************************************************************/
// Check that eliminating tiny net with 1 measurement yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny1) {
using symbol_shorthand::Z;
const int num_measurements = 1;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
// Create expected Bayes Net:
HybridBayesNet expectedBayesNet;
// Create Gaussian mixture on X(0).
using tiny::mode;
// regression, but mean checked to be 5.0 in both cases:
const auto conditional0 = boost::make_shared<GaussianConditional>(
X(0), Vector1(14.1421), I_1x1 * 2.82843),
conditional1 = boost::make_shared<GaussianConditional>(
X(0), Vector1(10.1379), I_1x1 * 2.02759);
GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1});
expectedBayesNet.emplaceMixture(gm); // copy :-(
// Add prior on mode.
expectedBayesNet.emplaceDiscrete(mode, "74/26");
// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
const auto posterior = fg.eliminateSequential(ordering);
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
}
/* ****************************************************************************/
// Check that eliminating tiny net with 2 measurements yields correct result.
TEST(HybridGaussianFactorGraph, EliminateTiny2) {
// Create factor graph with 2 measurements such that posterior mean = 5.0.
using symbol_shorthand::Z;
const int num_measurements = 2;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements,
VectorValues{{Z(0), Vector1(4.0)}, {Z(1), Vector1(6.0)}});
// Create expected Bayes Net:
HybridBayesNet expectedBayesNet;
// Create Gaussian mixture on X(0).
using tiny::mode;
// regression, but mean checked to be 5.0 in both cases:
const auto conditional0 = boost::make_shared<GaussianConditional>(
X(0), Vector1(17.3205), I_1x1 * 3.4641),
conditional1 = boost::make_shared<GaussianConditional>(
X(0), Vector1(10.274), I_1x1 * 2.0548);
GaussianMixture gm({X(0)}, {}, {mode}, {conditional0, conditional1});
expectedBayesNet.emplaceMixture(gm); // copy :-(
// Add prior on mode.
expectedBayesNet.emplaceDiscrete(mode, "23/77");
// Test elimination
Ordering ordering;
ordering.push_back(X(0));
ordering.push_back(M(0));
const auto posterior = fg.eliminateSequential(ordering);
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -177,19 +177,19 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// Test the probability values with regression tests. // Test the probability values with regression tests.
DiscreteValues assignment; DiscreteValues assignment;
EXPECT(assert_equal(0.000956191, m00_prob, 1e-5)); EXPECT(assert_equal(0.0952922, m00_prob, 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.000956191, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 1; assignment[M(0)] = 1;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.00283728, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT(assert_equal(0.00315253, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 1; assignment[M(0)] = 1;
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT(assert_equal(0.00308831, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5));
// Check if the clique conditional generated from incremental elimination // Check if the clique conditional generated from incremental elimination
// matches that of batch elimination. // matches that of batch elimination.
@ -199,7 +199,7 @@ TEST(HybridGaussianElimination, IncrementalInference) {
isam[M(1)]->conditional()->inner()); isam[M(1)]->conditional()->inner());
// Account for the probability terms from evaluating continuous FGs // Account for the probability terms from evaluating continuous FGs
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}}; DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
vector<double> probs = {0.00095619114, 0.0031525308, 0.0028372777, 0.0030883072}; vector<double> probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485};
auto expectedConditional = auto expectedConditional =
boost::make_shared<DecisionTreeFactor>(discrete_keys, probs); boost::make_shared<DecisionTreeFactor>(discrete_keys, probs);
EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6)); EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6));

View File

@ -443,7 +443,7 @@ TEST(HybridFactorGraph, Full_Elimination) {
ordering.clear(); ordering.clear();
for (size_t k = 0; k < self.K - 1; k++) ordering += M(k); for (size_t k = 0; k < self.K - 1; k++) ordering += M(k);
discreteBayesNet = discreteBayesNet =
*discrete_fg.eliminateSequential(ordering, EliminateForMPE); *discrete_fg.eliminateSequential(ordering, EliminateDiscrete);
} }
// Create ordering. // Create ordering.
@ -638,22 +638,30 @@ conditional 2: Hybrid P( x2 | m0 m1)
0 0 Leaf p(x2) 0 0 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.1489 ] d = [ -10.1489 ]
mean: 1 elements
x2: -1.0099
No noise model No noise model
0 1 Leaf p(x2) 0 1 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.1479 ] d = [ -10.1479 ]
mean: 1 elements
x2: -1.0098
No noise model No noise model
1 Choice(m0) 1 Choice(m0)
1 0 Leaf p(x2) 1 0 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.0504 ] d = [ -10.0504 ]
mean: 1 elements
x2: -1.0001
No noise model No noise model
1 1 Leaf p(x2) 1 1 Leaf p(x2)
R = [ 10.0494 ] R = [ 10.0494 ]
d = [ -10.0494 ] d = [ -10.0494 ]
mean: 1 elements
x2: -1
No noise model No noise model
)"; )";

View File

@ -195,19 +195,19 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
// Test the probability values with regression tests. // Test the probability values with regression tests.
DiscreteValues assignment; DiscreteValues assignment;
EXPECT(assert_equal(0.000956191, m00_prob, 1e-5)); EXPECT(assert_equal(0.0952922, m00_prob, 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.000956191, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.0952922, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 1; assignment[M(0)] = 1;
assignment[M(1)] = 0; assignment[M(1)] = 0;
EXPECT(assert_equal(0.00283728, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.282758, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 0; assignment[M(0)] = 0;
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT(assert_equal(0.00315253, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.314175, (*discreteConditional)(assignment), 1e-5));
assignment[M(0)] = 1; assignment[M(0)] = 1;
assignment[M(1)] = 1; assignment[M(1)] = 1;
EXPECT(assert_equal(0.00308831, (*discreteConditional)(assignment), 1e-5)); EXPECT(assert_equal(0.307775, (*discreteConditional)(assignment), 1e-5));
// Check if the clique conditional generated from incremental elimination // Check if the clique conditional generated from incremental elimination
// matches that of batch elimination. // matches that of batch elimination.
@ -216,7 +216,7 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
bayesTree[M(1)]->conditional()->inner()); bayesTree[M(1)]->conditional()->inner());
// Account for the probability terms from evaluating continuous FGs // Account for the probability terms from evaluating continuous FGs
DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}}; DiscreteKeys discrete_keys = {{M(0), 2}, {M(1), 2}};
vector<double> probs = {0.00095619114, 0.0031525308, 0.0028372777, 0.0030883072}; vector<double> probs = {0.095292197, 0.31417524, 0.28275772, 0.30777485};
auto expectedConditional = auto expectedConditional =
boost::make_shared<DecisionTreeFactor>(discrete_keys, probs); boost::make_shared<DecisionTreeFactor>(discrete_keys, probs);
EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6)); EXPECT(assert_equal(*expectedConditional, *actualConditional, 1e-6));

View File

@ -120,6 +120,10 @@ namespace gtsam {
<< endl; << endl;
} }
cout << formatMatrixIndented(" d = ", getb(), true) << "\n"; cout << formatMatrixIndented(" d = ", getb(), true) << "\n";
if (nrParents() == 0) {
const auto mean = solve({}); // solve for mean.
mean.print(" mean");
}
if (model_) if (model_)
model_->print(" Noise model: "); model_->print(" Noise model: ");
else else

View File

@ -507,6 +507,8 @@ TEST(GaussianConditional, Print) {
" R = [ 1 0 ]\n" " R = [ 1 0 ]\n"
" [ 0 1 ]\n" " [ 0 1 ]\n"
" d = [ 20 40 ]\n" " d = [ 20 40 ]\n"
" mean: 1 elements\n"
" x0: 20 40\n"
"isotropic dim=2 sigma=3\n"; "isotropic dim=2 sigma=3\n";
EXPECT(assert_print_equal(expected, conditional, "GaussianConditional")); EXPECT(assert_print_equal(expected, conditional, "GaussianConditional"));

View File

@ -18,9 +18,9 @@ from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
GaussianMixture, GaussianMixtureFactor, HybridBayesNet, HybridValues, GaussianMixture, GaussianMixtureFactor, HybridBayesNet,
HybridGaussianFactorGraph, JacobianFactor, Ordering, HybridGaussianFactorGraph, HybridValues, JacobianFactor,
noiseModel) Ordering, noiseModel)
class TestHybridGaussianFactorGraph(GtsamTestCase): class TestHybridGaussianFactorGraph(GtsamTestCase):
@ -82,10 +82,12 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertEqual(hv.atDiscrete(C(0)), 1) self.assertEqual(hv.atDiscrete(C(0)), 1)
@staticmethod @staticmethod
def tiny(num_measurements: int = 1) -> HybridBayesNet: def tiny(num_measurements: int = 1, prior_mean: float = 5.0,
prior_sigma: float = 0.5) -> HybridBayesNet:
""" """
Create a tiny two variable hybrid model which represents Create a tiny two variable hybrid model which represents
the generative probability P(z, x, n) = P(z | x, n)P(x)P(n). the generative probability P(Z, x0, mode) = P(Z|x0, mode)P(x0)P(mode).
num_measurements: number of measurements in Z = {z0, z1...}
""" """
# Create hybrid Bayes net. # Create hybrid Bayes net.
bayesNet = HybridBayesNet() bayesNet = HybridBayesNet()
@ -94,23 +96,24 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
mode = (M(0), 2) mode = (M(0), 2)
# Create Gaussian mixture Z(0) = X(0) + noise for each measurement. # Create Gaussian mixture Z(0) = X(0) + noise for each measurement.
I = np.eye(1) I_1x1 = np.eye(1)
keys = DiscreteKeys() keys = DiscreteKeys()
keys.push_back(mode) keys.push_back(mode)
for i in range(num_measurements): for i in range(num_measurements):
conditional0 = GaussianConditional.FromMeanAndStddev(Z(i), conditional0 = GaussianConditional.FromMeanAndStddev(Z(i),
I, I_1x1,
X(0), [0], X(0), [0],
sigma=0.5) sigma=0.5)
conditional1 = GaussianConditional.FromMeanAndStddev(Z(i), conditional1 = GaussianConditional.FromMeanAndStddev(Z(i),
I, I_1x1,
X(0), [0], X(0), [0],
sigma=3) sigma=3)
bayesNet.emplaceMixture([Z(i)], [X(0)], keys, bayesNet.emplaceMixture([Z(i)], [X(0)], keys,
[conditional0, conditional1]) [conditional0, conditional1])
# Create prior on X(0). # Create prior on X(0).
prior_on_x0 = GaussianConditional.FromMeanAndStddev(X(0), [5.0], 5.0) prior_on_x0 = GaussianConditional.FromMeanAndStddev(
X(0), [prior_mean], prior_sigma)
bayesNet.addGaussian(prior_on_x0) bayesNet.addGaussian(prior_on_x0)
# Add prior on mode. # Add prior on mode.
@ -118,8 +121,41 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
return bayesNet return bayesNet
def test_evaluate(self):
"""Test evaluate with two different prior noise models."""
# TODO(dellaert): really a HBN test
# Create a tiny Bayes net P(x0) P(m0) P(z0|x0)
bayesNet1 = self.tiny(prior_sigma=0.5, num_measurements=1)
bayesNet2 = self.tiny(prior_sigma=5.0, num_measurements=1)
# bn1: # 1/sqrt(2*pi*0.5^2)
# bn2: # 1/sqrt(2*pi*5.0^2)
expected_ratio = np.sqrt(2*np.pi*5.0**2)/np.sqrt(2*np.pi*0.5**2)
mean0 = HybridValues()
mean0.insert(X(0), [5.0])
mean0.insert(Z(0), [5.0])
mean0.insert(M(0), 0)
self.assertAlmostEqual(bayesNet1.evaluate(mean0) /
bayesNet2.evaluate(mean0), expected_ratio,
delta=1e-9)
mean1 = HybridValues()
mean1.insert(X(0), [5.0])
mean1.insert(Z(0), [5.0])
mean1.insert(M(0), 1)
self.assertAlmostEqual(bayesNet1.evaluate(mean1) /
bayesNet2.evaluate(mean1), expected_ratio,
delta=1e-9)
@staticmethod @staticmethod
def factor_graph_from_bayes_net(bayesNet: HybridBayesNet, sample: HybridValues): def measurements(sample: HybridValues, indices) -> gtsam.VectorValues:
"""Create measurements from a sample, grabbing Z(i) for indices."""
measurements = gtsam.VectorValues()
for i in indices:
measurements.insert(Z(i), sample.at(Z(i)))
return measurements
@classmethod
def factor_graph_from_bayes_net(cls, bayesNet: HybridBayesNet,
sample: HybridValues):
"""Create a factor graph from the Bayes net with sampled measurements. """Create a factor graph from the Bayes net with sampled measurements.
The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...` The factor graph is `P(x)P(n) ϕ(x, n; z0) ϕ(x, n; z1) ...`
and thus represents the same joint probability as the Bayes net. and thus represents the same joint probability as the Bayes net.
@ -128,31 +164,27 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
num_measurements = bayesNet.size() - 2 num_measurements = bayesNet.size() - 2
for i in range(num_measurements): for i in range(num_measurements):
conditional = bayesNet.atMixture(i) conditional = bayesNet.atMixture(i)
measurement = gtsam.VectorValues() factor = conditional.likelihood(cls.measurements(sample, [i]))
measurement.insert(Z(i), sample.at(Z(i)))
factor = conditional.likelihood(measurement)
fg.push_back(factor) fg.push_back(factor)
fg.push_back(bayesNet.atGaussian(num_measurements)) fg.push_back(bayesNet.atGaussian(num_measurements))
fg.push_back(bayesNet.atDiscrete(num_measurements+1)) fg.push_back(bayesNet.atDiscrete(num_measurements+1))
return fg return fg
@classmethod @classmethod
def estimate_marginals(cls, bayesNet: HybridBayesNet, sample: HybridValues, N=10000): def estimate_marginals(cls, target, proposal_density: HybridBayesNet,
"""Do importance sampling to get an estimate of the discrete marginal P(mode).""" N=10000):
# Use prior on x0, mode as proposal density. """Do importance sampling to estimate discrete marginal P(mode)."""
prior = cls.tiny(num_measurements=0) # just P(x0)P(mode) # Allocate space for marginals on mode.
# Allocate space for marginals.
marginals = np.zeros((2,)) marginals = np.zeros((2,))
# Do importance sampling. # Do importance sampling.
num_measurements = bayesNet.size() - 2
for s in range(N): for s in range(N):
proposed = prior.sample() proposed = proposal_density.sample() # sample from proposal
for i in range(num_measurements): target_proposed = target(proposed) # evaluate target
z_i = sample.at(Z(i)) # print(target_proposed, proposal_density.evaluate(proposed))
proposed.insert(Z(i), z_i) weight = target_proposed / proposal_density.evaluate(proposed)
weight = bayesNet.evaluate(proposed) / prior.evaluate(proposed) # print weight:
# print(f"weight: {weight}")
marginals[proposed.atDiscrete(M(0))] += weight marginals[proposed.atDiscrete(M(0))] += weight
# print marginals: # print marginals:
@ -161,72 +193,146 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
def test_tiny(self): def test_tiny(self):
"""Test a tiny two variable hybrid model.""" """Test a tiny two variable hybrid model."""
bayesNet = self.tiny() # P(x0)P(mode)P(z0|x0,mode)
sample = bayesNet.sample() prior_sigma = 0.5
# print(sample) bayesNet = self.tiny(prior_sigma=prior_sigma)
# Deterministic values exactly at the mean, for both x and Z:
values = HybridValues()
values.insert(X(0), [5.0])
values.insert(M(0), 0) # low-noise, standard deviation 0.5
z0: float = 5.0
values.insert(Z(0), [z0])
def unnormalized_posterior(x):
"""Posterior is proportional to joint, centered at 5.0 as well."""
x.insert(Z(0), [z0])
return bayesNet.evaluate(x)
# Create proposal density on (x0, mode), making sure it has same mean:
posterior_information = 1/(prior_sigma**2) + 1/(0.5**2)
posterior_sigma = posterior_information**(-0.5)
proposal_density = self.tiny(
num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma)
# Estimate marginals using importance sampling. # Estimate marginals using importance sampling.
marginals = self.estimate_marginals(bayesNet, sample) marginals = self.estimate_marginals(
# print(f"True mode: {sample.atDiscrete(M(0))}") target=unnormalized_posterior, proposal_density=proposal_density)
# print(f"True mode: {values.atDiscrete(M(0))}")
# print(f"P(mode=0; Z) = {marginals[0]}")
# print(f"P(mode=1; Z) = {marginals[1]}")
# Check that the estimate is close to the true value.
self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
fg = self.factor_graph_from_bayes_net(bayesNet, values)
self.assertEqual(fg.size(), 3)
# Test elimination.
ordering = gtsam.Ordering()
ordering.push_back(X(0))
ordering.push_back(M(0))
posterior = fg.eliminateSequential(ordering)
def true_posterior(x):
"""Posterior from elimination."""
x.insert(Z(0), [z0])
return posterior.evaluate(x)
# Estimate marginals using importance sampling.
marginals = self.estimate_marginals(
target=true_posterior, proposal_density=proposal_density)
# print(f"True mode: {values.atDiscrete(M(0))}")
# print(f"P(mode=0; z0) = {marginals[0]}") # print(f"P(mode=0; z0) = {marginals[0]}")
# print(f"P(mode=1; z0) = {marginals[1]}") # print(f"P(mode=1; z0) = {marginals[1]}")
# Check that the estimate is close to the true value. # Check that the estimate is close to the true value.
self.assertAlmostEqual(marginals[0], 0.4, delta=0.1) self.assertAlmostEqual(marginals[0], 0.74, delta=0.01)
self.assertAlmostEqual(marginals[1], 0.6, delta=0.1) self.assertAlmostEqual(marginals[1], 0.26, delta=0.01)
fg = self.factor_graph_from_bayes_net(bayesNet, sample)
self.assertEqual(fg.size(), 3)
@staticmethod @staticmethod
def calculate_ratio(bayesNet: HybridBayesNet, def calculate_ratio(bayesNet: HybridBayesNet,
fg: HybridGaussianFactorGraph, fg: HybridGaussianFactorGraph,
sample: HybridValues): sample: HybridValues):
"""Calculate ratio between Bayes net probability and the factor graph.""" """Calculate ratio between Bayes net and factor graph."""
return bayesNet.evaluate(sample) / fg.probPrime(sample) if fg.probPrime(sample) > 0 else 0 return bayesNet.evaluate(sample) / fg.probPrime(sample) if \
fg.probPrime(sample) > 0 else 0
def test_ratio(self): def test_ratio(self):
""" """
Given a tiny two variable hybrid model, with 2 measurements, Given a tiny two variable hybrid model, with 2 measurements, test the
test the ratio of the bayes net model representing P(z, x, n)=P(z|x, n)P(x)P(n) ratio of the bayes net model representing P(z,x,n)=P(z|x, n)P(x)P(n)
and the factor graph P(x, n | z)=P(x | n, z)P(n|z), and the factor graph P(x, n | z)=P(x | n, z)P(n|z),
both of which represent the same posterior. both of which represent the same posterior.
""" """
# Create the Bayes net representing the generative model P(z, x, n)=P(z|x, n)P(x)P(n) # Create generative model P(z, x, n)=P(z|x, n)P(x)P(n)
bayesNet = self.tiny(num_measurements=2) prior_sigma = 0.5
# Sample from the Bayes net. bayesNet = self.tiny(prior_sigma=prior_sigma, num_measurements=2)
sample: HybridValues = bayesNet.sample()
# print(sample) # Deterministic values exactly at the mean, for both x and Z:
values = HybridValues()
values.insert(X(0), [5.0])
values.insert(M(0), 0) # high-noise, standard deviation 3
measurements = gtsam.VectorValues()
measurements.insert(Z(0), [4.0])
measurements.insert(Z(1), [6.0])
values.insert(measurements)
def unnormalized_posterior(x):
"""Posterior is proportional to joint, centered at 5.0 as well."""
x.insert(measurements)
return bayesNet.evaluate(x)
# Create proposal density on (x0, mode), making sure it has same mean:
posterior_information = 1/(prior_sigma**2) + 2.0/(3.0**2)
posterior_sigma = posterior_information**(-0.5)
proposal_density = self.tiny(
num_measurements=0, prior_mean=5.0, prior_sigma=posterior_sigma)
# Estimate marginals using importance sampling. # Estimate marginals using importance sampling.
marginals = self.estimate_marginals(bayesNet, sample) marginals = self.estimate_marginals(
# print(f"True mode: {sample.atDiscrete(M(0))}") target=unnormalized_posterior, proposal_density=proposal_density)
# print(f"P(mode=0; z0, z1) = {marginals[0]}") # print(f"True mode: {values.atDiscrete(M(0))}")
# print(f"P(mode=1; z0, z1) = {marginals[1]}") # print(f"P(mode=0; Z) = {marginals[0]}")
# print(f"P(mode=1; Z) = {marginals[1]}")
# Check marginals based on sampled mode. # Check that the estimate is close to the true value.
if sample.atDiscrete(M(0)) == 0: self.assertAlmostEqual(marginals[0], 0.23, delta=0.01)
self.assertGreater(marginals[0], marginals[1]) self.assertAlmostEqual(marginals[1], 0.77, delta=0.01)
else:
self.assertGreater(marginals[1], marginals[0])
fg = self.factor_graph_from_bayes_net(bayesNet, sample) # Convert to factor graph using measurements.
fg = self.factor_graph_from_bayes_net(bayesNet, values)
self.assertEqual(fg.size(), 4) self.assertEqual(fg.size(), 4)
# Calculate ratio between Bayes net probability and the factor graph: # Calculate ratio between Bayes net probability and the factor graph:
expected_ratio = self.calculate_ratio(bayesNet, fg, sample) expected_ratio = self.calculate_ratio(bayesNet, fg, values)
# print(f"expected_ratio: {expected_ratio}\n") # print(f"expected_ratio: {expected_ratio}\n")
# Create measurements from the sample.
measurements = gtsam.VectorValues()
for i in range(2):
measurements.insert(Z(i), sample.at(Z(i)))
# Check with a number of other samples. # Check with a number of other samples.
for i in range(10): for i in range(10):
other = bayesNet.sample() samples = bayesNet.sample()
other.update(measurements) samples.update(measurements)
ratio = self.calculate_ratio(bayesNet, fg, other) ratio = self.calculate_ratio(bayesNet, fg, samples)
# print(f"Ratio: {ratio}\n")
if (ratio > 0):
self.assertAlmostEqual(ratio, expected_ratio)
# Test elimination.
ordering = gtsam.Ordering()
ordering.push_back(X(0))
ordering.push_back(M(0))
posterior = fg.eliminateSequential(ordering)
# Calculate ratio between Bayes net probability and the factor graph:
expected_ratio = self.calculate_ratio(posterior, fg, values)
# print(f"expected_ratio: {expected_ratio}\n")
# Check with a number of other samples.
for i in range(10):
samples = posterior.sample()
samples.insert(measurements)
ratio = self.calculate_ratio(posterior, fg, samples)
# print(f"Ratio: {ratio}\n") # print(f"Ratio: {ratio}\n")
if (ratio > 0): if (ratio > 0):
self.assertAlmostEqual(ratio, expected_ratio) self.assertAlmostEqual(ratio, expected_ratio)