update HybridGaussianFactor to leverage constant hiding for the Tree of Pairs

release/4.3a0
Varun Agrawal 2024-09-14 14:54:36 -04:00
parent a7f5173b88
commit 9360165ef6
2 changed files with 91 additions and 23 deletions

View File

@ -28,11 +28,54 @@
namespace gtsam { namespace gtsam {
/**
* @brief Helper function to augment the [A|b] matrices in the factor components
* with the normalizer values.
* This is done by storing the normalizer value in
* the `b` vector as an additional row.
*
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
* Gaussian factor in factors.
* @return HybridGaussianFactor::Factors
*/
HybridGaussianFactor::Factors augment(
const HybridGaussianFactor::FactorValuePairs &factors) {
// Find the minimum value so we can "proselytize" to positive values.
// Done because we can't have sqrt of negative numbers.
auto unzipped_pair = unzip(factors);
const HybridGaussianFactor::Factors gaussianFactors = unzipped_pair.first;
const AlgebraicDecisionTree<Key> valueTree = unzipped_pair.second;
double min_value = valueTree.min();
AlgebraicDecisionTree<Key> values =
valueTree.apply([&min_value](double n) { return n - min_value; });
// Finally, update the [A|b] matrices.
auto update = [&values](const Assignment<Key> &assignment,
const HybridGaussianFactor::sharedFactor &gf) {
auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
if (!jf) return gf;
// If the log_normalizer is 0, do nothing
if (values(assignment) == 0.0) return gf;
GaussianFactorGraph gfg;
gfg.push_back(jf);
Vector c(1);
c << std::sqrt(values(assignment));
auto constantFactor = std::make_shared<JacobianFactor>(c);
gfg.push_back(constantFactor);
return std::dynamic_pointer_cast<GaussianFactor>(
std::make_shared<JacobianFactor>(gfg));
};
return gaussianFactors.apply(update);
}
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const KeyVector &continuousKeys, HybridGaussianFactor::HybridGaussianFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const Factors &factors) const FactorValuePairs &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {} : Base(continuousKeys, discreteKeys), factors_(augment(factors)) {}
/* *******************************************************************************/ /* *******************************************************************************/
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const { bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
@ -45,9 +88,9 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
// Check the base and the factors: // Check the base and the factors:
return Base::equals(*e, tol) && return Base::equals(*e, tol) &&
factors_.equals(e->factors_, [tol](const GaussianFactorValuePair &f1, factors_.equals(e->factors_,
const GaussianFactorValuePair &f2) { [tol](const sharedFactor &f1, const sharedFactor &f2) {
return f1.first->equals(*f2.first, tol) && (f1.second == f2.second); return f1->equals(*f2, tol);
}); });
} }
@ -63,13 +106,11 @@ void HybridGaussianFactor::print(const std::string &s,
} else { } else {
factors_.print( factors_.print(
"", [&](Key k) { return formatter(k); }, "", [&](Key k) { return formatter(k); },
[&](const GaussianFactorValuePair &gfv) -> std::string { [&](const sharedFactor &gf) -> std::string {
auto [gf, val] = gfv;
RedirectCout rd; RedirectCout rd;
std::cout << ":\n"; std::cout << ":\n";
if (gf) { if (gf) {
gf->print("", formatter); gf->print("", formatter);
std::cout << "value: " << val << std::endl;
return rd.str(); return rd.str();
} else { } else {
return "nullptr"; return "nullptr";
@ -80,7 +121,7 @@ void HybridGaussianFactor::print(const std::string &s,
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianFactorValuePair HybridGaussianFactor::operator()( HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
const DiscreteValues &assignment) const { const DiscreteValues &assignment) const {
return factors_(assignment); return factors_(assignment);
} }
@ -101,9 +142,7 @@ GaussianFactorGraphTree HybridGaussianFactor::add(
/* *******************************************************************************/ /* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree() GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
const { const {
auto wrap = [](const GaussianFactorValuePair &gfv) { auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
return GaussianFactorGraph{gfv.first};
};
return {factors_, wrap}; return {factors_, wrap};
} }
@ -111,9 +150,8 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree( AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value. // functor to convert from sharedFactor to double error value.
auto errorFunc = [&continuousValues](const GaussianFactorValuePair &gfv) { auto errorFunc = [&continuousValues](const sharedFactor &gf) {
auto [gf, v] = gfv; return gf->error(continuousValues);
return gf->error(continuousValues) + (0.5 * v * v);
}; };
DecisionTree<Key, double> error_tree(factors_, errorFunc); DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree; return error_tree;
@ -121,8 +159,24 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
/* *******************************************************************************/ /* *******************************************************************************/
double HybridGaussianFactor::error(const HybridValues &values) const { double HybridGaussianFactor::error(const HybridValues &values) const {
auto &&[gf, val] = factors_(values.discrete()); const sharedFactor gf = factors_(values.discrete());
return gf->error(values.continuous()) + val; return gf->error(values.continuous());
}
/* *******************************************************************************/
double ComputeLogNormalizer(
const noiseModel::Gaussian::shared_ptr &noise_model) {
// Since noise models are Gaussian, we can get the logDeterminant using
// the same trick as in GaussianConditional
double logDetR = noise_model->R()
.diagonal()
.unaryExpr([](double x) { return log(x); })
.sum();
double logDeterminantSigma = -2.0 * logDetR;
size_t n = noise_model->dim();
constexpr double log2pi = 1.8378770664093454835606594728112;
return n * log2pi + logDeterminantSigma;
} }
} // namespace gtsam } // namespace gtsam

View File

@ -55,8 +55,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
using sharedFactor = std::shared_ptr<GaussianFactor>; using sharedFactor = std::shared_ptr<GaussianFactor>;
/// typedef for Decision Tree of Gaussian factors and log-constant. /// typedef for Decision Tree of Gaussian factors and arbitrary value.
using Factors = DecisionTree<Key, GaussianFactorValuePair>; using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
/// typedef for Decision Tree of Gaussian factors.
using Factors = DecisionTree<Key, sharedFactor>;
private: private:
/// Decision tree of Gaussian factors indexed by discrete keys. /// Decision tree of Gaussian factors indexed by discrete keys.
@ -87,7 +89,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/ */
HybridGaussianFactor(const KeyVector &continuousKeys, HybridGaussianFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const Factors &factors); const FactorValuePairs &factors);
/** /**
* @brief Construct a new HybridGaussianFactor object using a vector of * @brief Construct a new HybridGaussianFactor object using a vector of
@ -102,7 +104,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
const DiscreteKeys &discreteKeys, const DiscreteKeys &discreteKeys,
const std::vector<GaussianFactorValuePair> &factors) const std::vector<GaussianFactorValuePair> &factors)
: HybridGaussianFactor(continuousKeys, discreteKeys, : HybridGaussianFactor(continuousKeys, discreteKeys,
Factors(discreteKeys, factors)) {} FactorValuePairs(discreteKeys, factors)) {}
/// @} /// @}
/// @name Testable /// @name Testable
@ -118,7 +120,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// @{ /// @{
/// Get the factor and scalar at a given discrete assignment. /// Get the factor and scalar at a given discrete assignment.
GaussianFactorValuePair operator()(const DiscreteValues &assignment) const; sharedFactor operator()(const DiscreteValues &assignment) const;
/** /**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
@ -173,4 +175,16 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
template <> template <>
struct traits<HybridGaussianFactor> : public Testable<HybridGaussianFactor> {}; struct traits<HybridGaussianFactor> : public Testable<HybridGaussianFactor> {};
/**
* @brief Helper function to compute the sqrt(|2πΣ|) normalizer values
* for a Gaussian noise model.
* We compute this in the log-space for numerical accuracy.
*
* @param noise_model The Gaussian noise model
* whose normalizer we wish to compute.
* @return double
*/
GTSAM_EXPORT double ComputeLogNormalizer(
const noiseModel::Gaussian::shared_ptr &noise_model);
} // namespace gtsam } // namespace gtsam