Switch to using HybridValues

release/4.3a0
Frank Dellaert 2022-12-30 12:10:16 -05:00
parent b972be0b8f
commit 9cf3e5c26a
6 changed files with 38 additions and 50 deletions

View File

@ -22,6 +22,7 @@
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianFactorGraph.h>
@ -159,9 +160,9 @@ boost::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
}
/* ************************************************************************* */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
std::set<DiscreteKey> s;
s.insert(dkeys.begin(), dkeys.end());
s.insert(discreteKeys.begin(), discreteKeys.end());
return s;
}
@ -186,7 +187,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
DiscreteValues values(choices);
const DiscreteValues values(choices);
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
@ -256,11 +257,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
}
/* *******************************************************************************/
double GaussianMixture::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double GaussianMixture::error(const HybridValues &values) const {
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(discreteValues);
return conditional->error(continuousValues);
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous());
}
} // namespace gtsam

View File

@ -30,6 +30,7 @@
namespace gtsam {
class GaussianMixtureFactor;
class HybridValues;
/**
* @brief A conditional of gaussian mixtures indexed by discrete variables, as
@ -87,7 +88,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Constructors
/// @{
/// Defaut constructor, mainly for serialization.
/// Default constructor, mainly for serialization.
GaussianMixture() = default;
/**
@ -135,6 +136,7 @@ class GTSAM_EXPORT GaussianMixture
/// @name Standard API
/// @{
/// @brief Return the conditional Gaussian for the given discrete assignment.
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const;
@ -165,12 +167,10 @@ class GTSAM_EXPORT GaussianMixture
* @brief Compute the error of this Gaussian Mixture given the continuous
* values and a discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues The discrete assignment for a specific mode sequence.
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
double error(const HybridValues &values) const;
/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
@ -193,7 +193,7 @@ class GTSAM_EXPORT GaussianMixture
};
/// Return the DiscreteKey vector as a set.
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys);
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
// traits
template <>

View File

@ -1,5 +1,5 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* GTSAM Copyright 2010-2022, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
@ -12,6 +12,7 @@
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
* @author Frank Dellaert
* @date January 2022
*/
@ -321,10 +322,9 @@ HybridValues HybridBayesNet::sample() const {
}
/* ************************************************************************* */
double HybridBayesNet::error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
GaussianBayesNet gbn = choose(discreteValues);
return gbn.error(continuousValues);
double HybridBayesNet::error(const HybridValues &values) const {
GaussianBayesNet gbn = choose(values.discrete());
return gbn.error(values.continuous());
}
/* ************************************************************************* */

View File

@ -206,12 +206,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief 0.5 * sum of squared Mahalanobis distances
* for a specific discrete assignment.
*
* @param continuousValues Continuous values at which to compute the error.
* @param discreteValues Discrete assignment for a specific mode sequence.
* @param values Continuous values and discrete assignment.
* @return double
*/
double error(const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const;
double error(const HybridValues &values) const;
/**
* @brief Compute conditional error for each discrete assignment,

View File

@ -55,13 +55,14 @@
namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
using Y = GaussianFactorGraph;
// If the decision tree is not intiialized, then intialize it.
// If the decision tree is not initialized, then initialize it.
if (sum.empty()) {
GaussianFactorGraph result;
result.push_back(factor);
@ -89,8 +90,9 @@ GaussianMixtureFactor::Sum sumFrontals(
for (auto &f : factors) {
if (f->isHybrid()) {
if (auto cgmf = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = cgmf->add(sum);
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = gm->add(sum);
}
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum);
@ -184,7 +186,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
const KeySet &continuousSeparator,
const std::set<DiscreteKey> &discreteSeparatorSet) {
// NOTE: since we use the special JunctionTree,
// only possiblity is continuous conditioned on discrete.
// only possibility is continuous conditioned on discrete.
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end());
@ -251,8 +253,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Separate out decision tree into conditionals and remaining factors.
auto pair = unzip(eliminationResults);
const GaussianMixtureFactor::Factors &separatorFactors = pair.second;
const auto &separatorFactors = pair.second;
// Create the GaussianMixture from the conditionals
auto conditional = boost::make_shared<GaussianMixture>(
@ -460,6 +461,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
// Iterate over each factor.
for (size_t idx = 0; idx < size(); idx++) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
AlgebraicDecisionTree<Key> factor_error;
if (factors_.at(idx)->isHybrid()) {
@ -499,27 +501,26 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
}
/* ************************************************************************ */
double HybridGaussianFactorGraph::error(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0;
for (size_t idx = 0; idx < size(); idx++) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
auto factor = factors_.at(idx);
if (factor->isHybrid()) {
if (auto c = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += c->asMixture()->error(continuousValues, discreteValues);
error += c->asMixture()->error(values);
}
if (auto f = boost::dynamic_pointer_cast<GaussianMixtureFactor>(factor)) {
error += f->error(continuousValues, discreteValues);
error += f->error(values);
}
} else if (factor->isContinuous()) {
if (auto f = boost::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
error += f->inner()->error(continuousValues);
error += f->inner()->error(values.continuous());
}
if (auto cg = boost::dynamic_pointer_cast<HybridConditional>(factor)) {
error += cg->asGaussian()->error(continuousValues);
error += cg->asGaussian()->error(values.continuous());
}
}
}
@ -527,10 +528,8 @@ double HybridGaussianFactorGraph::error(
}
/* ************************************************************************ */
double HybridGaussianFactorGraph::probPrime(
const VectorValues &continuousValues,
const DiscreteValues &discreteValues) const {
double error = this->error(continuousValues, discreteValues);
double HybridGaussianFactorGraph::probPrime(const HybridValues &values) const {
double error = this->error(values);
// NOTE: The 0.5 term is handled by each factor
return std::exp(-error);
}

View File

@ -186,14 +186,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Compute error given a continuous vector values
* and a discrete assignment.
*
* @param continuousValues The continuous VectorValues
* for computing the error.
* @param discreteValues The specific discrete assignment
* whose error we wish to compute.
* @return double
*/
double error(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;
double error(const HybridValues& values) const;
/**
* @brief Compute unnormalized probability \f$ P(X | M, Z) \f$
@ -210,13 +205,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* @brief Compute the unnormalized posterior probability for a continuous
* vector values given a specific assignment.
*
* @param continuousValues The vector values for which to compute the
* posterior probability.
* @param discreteValues The specific assignment to use for the computation.
* @return double
*/
double probPrime(const VectorValues& continuousValues,
const DiscreteValues& discreteValues) const;
double probPrime(const HybridValues& values) const;
/**
* @brief Return a Colamd constrained ordering where the discrete keys are