update HybridGaussianFactor to leverage constant hiding for the Tree of Pairs

release/4.3a0
Varun Agrawal 2024-09-14 14:54:36 -04:00
parent a7f5173b88
commit 9360165ef6
2 changed files with 91 additions and 23 deletions

View File

@ -28,11 +28,54 @@
namespace gtsam {
/**
* @brief Helper function to augment the [A|b] matrices in the factor components
* with the normalizer values.
* This is done by storing the normalizer value in
* the `b` vector as an additional row.
*
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
* Gaussian factor in factors.
* @return HybridGaussianFactor::Factors
*/
HybridGaussianFactor::Factors augment(
const HybridGaussianFactor::FactorValuePairs &factors) {
// Find the minimum value so we can "proselytize" to positive values.
// Done because we can't have sqrt of negative numbers.
auto unzipped_pair = unzip(factors);
const HybridGaussianFactor::Factors gaussianFactors = unzipped_pair.first;
const AlgebraicDecisionTree<Key> valueTree = unzipped_pair.second;
double min_value = valueTree.min();
AlgebraicDecisionTree<Key> values =
valueTree.apply([&min_value](double n) { return n - min_value; });
// Finally, update the [A|b] matrices.
auto update = [&values](const Assignment<Key> &assignment,
const HybridGaussianFactor::sharedFactor &gf) {
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
if (!jf) return gf;
// If the log_normalizer is 0, do nothing
if (values(assignment) == 0.0) return gf;
GaussianFactorGraph gfg;
gfg.push_back(jf);
Vector c(1);
c << std::sqrt(values(assignment));
auto constantFactor = std::make_shared<JacobianFactor>(c);
gfg.push_back(constantFactor);
return std::dynamic_pointer_cast<GaussianFactor>(
std::make_shared<JacobianFactor>(gfg));
};
return gaussianFactors.apply(update);
}
/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}
const FactorValuePairs &factors)
: Base(continuousKeys, discreteKeys), factors_(augment(factors)) {}
/* *******************************************************************************/
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
@ -45,10 +88,10 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_, [tol](const GaussianFactorValuePair &f1,
const GaussianFactorValuePair &f2) {
return f1.first->equals(*f2.first, tol) && (f1.second == f2.second);
});
factors_.equals(e->factors_,
[tol](const sharedFactor &f1, const sharedFactor &f2) {
return f1->equals(*f2, tol);
});
}
/* *******************************************************************************/
@ -63,13 +106,11 @@ void HybridGaussianFactor::print(const std::string &s,
} else {
factors_.print(
"", [&](Key k) { return formatter(k); },
[&](const GaussianFactorValuePair &gfv) -> std::string {
auto [gf, val] = gfv;
[&](const sharedFactor &gf) -> std::string {
RedirectCout rd;
std::cout << ":\n";
if (gf) {
gf->print("", formatter);
std::cout << "value: " << val << std::endl;
return rd.str();
} else {
return "nullptr";
@ -80,7 +121,7 @@ void HybridGaussianFactor::print(const std::string &s,
}
/* *******************************************************************************/
GaussianFactorValuePair HybridGaussianFactor::operator()(
HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
const DiscreteValues &assignment) const {
return factors_(assignment);
}
@ -101,9 +142,7 @@ GaussianFactorGraphTree HybridGaussianFactor::add(
/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const GaussianFactorValuePair &gfv) {
return GaussianFactorGraph{gfv.first};
};
auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
return {factors_, wrap};
}
@ -111,9 +150,8 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const GaussianFactorValuePair &gfv) {
auto [gf, v] = gfv;
return gf->error(continuousValues) + (0.5 * v * v);
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return gf->error(continuousValues);
};
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
@ -121,8 +159,24 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
/* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues &values) const {
auto &&[gf, val] = factors_(values.discrete());
return gf->error(values.continuous()) + val;
const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous());
}
/* *******************************************************************************/
double ComputeLogNormalizer(
const noiseModel::Gaussian::shared_ptr &noise_model) {
// Since noise models are Gaussian, we can get the logDeterminant using
// the same trick as in GaussianConditional
double logDetR = noise_model->R()
.diagonal()
.unaryExpr([](double x) { return log(x); })
.sum();
double logDeterminantSigma = -2.0 * logDetR;
size_t n = noise_model->dim();
constexpr double log2pi = 1.8378770664093454835606594728112;
return n * log2pi + logDeterminantSigma;
}
} // namespace gtsam

View File

@ -55,8 +55,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
using sharedFactor = std::shared_ptr<GaussianFactor>;
/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, GaussianFactorValuePair>;
/// typedef for Decision Tree of Gaussian factors and arbitrary value.
using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
/// typedef for Decision Tree of Gaussian factors.
using Factors = DecisionTree<Key, sharedFactor>;
private:
/// Decision tree of Gaussian factors indexed by discrete keys.
@ -87,7 +89,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/
HybridGaussianFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors);
const FactorValuePairs &factors);
/**
* @brief Construct a new HybridGaussianFactor object using a vector of
@ -102,7 +104,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactorValuePair> &factors)
: HybridGaussianFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {}
FactorValuePairs(discreteKeys, factors)) {}
/// @}
/// @name Testable
@ -118,7 +120,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// @{
/// Get the factor and scalar at a given discrete assignment.
GaussianFactorValuePair operator()(const DiscreteValues &assignment) const;
sharedFactor operator()(const DiscreteValues &assignment) const;
/**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@ -173,4 +175,16 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
template <>
struct traits<HybridGaussianFactor> : public Testable<HybridGaussianFactor> {};
/**
* @brief Helper function to compute the sqrt(|2πΣ|) normalizer values
* for a Gaussian noise model.
* We compute this in the log-space for numerical accuracy.
*
* @param noise_model The Gaussian noise model
* whose normalizer we wish to compute.
* @return double
*/
GTSAM_EXPORT double ComputeLogNormalizer(
const noiseModel::Gaussian::shared_ptr &noise_model);
} // namespace gtsam