Store the values
							parent
							
								
									acccef8024
								
							
						
					
					
						commit
						3797996e89
					
				| 
						 | 
				
			
			@ -27,14 +27,16 @@
 | 
			
		|||
#include <gtsam/linear/GaussianFactor.h>
 | 
			
		||||
#include <gtsam/linear/GaussianFactorGraph.h>
 | 
			
		||||
 | 
			
		||||
#include "gtsam/base/types.h"
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
HybridGaussianFactor::Factors HybridGaussianFactor::augment(
 | 
			
		||||
HybridGaussianFactor::FactorValuePairs HybridGaussianFactor::augment(
 | 
			
		||||
    const FactorValuePairs &factors) {
 | 
			
		||||
  // Find the minimum value so we can "proselytize" to positive values.
 | 
			
		||||
  // Done because we can't have sqrt of negative numbers.
 | 
			
		||||
  Factors gaussianFactors;
 | 
			
		||||
  DecisionTree<Key, GaussianFactor::shared_ptr> gaussianFactors;
 | 
			
		||||
  AlgebraicDecisionTree<Key> valueTree;
 | 
			
		||||
  std::tie(gaussianFactors, valueTree) = unzip(factors);
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -42,16 +44,16 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
 | 
			
		|||
  double min_value = valueTree.min();
 | 
			
		||||
 | 
			
		||||
  // Finally, update the [A|b] matrices.
 | 
			
		||||
  auto update = [&min_value](const GaussianFactorValuePair &gfv) {
 | 
			
		||||
  auto update = [&min_value](const auto &gfv) -> GaussianFactorValuePair {
 | 
			
		||||
    auto [gf, value] = gfv;
 | 
			
		||||
 | 
			
		||||
    auto jf = std::dynamic_pointer_cast<JacobianFactor>(gf);
 | 
			
		||||
    if (!jf) return gf;
 | 
			
		||||
    if (!jf) return {gf, 0.0};  // should this be zero or infinite?
 | 
			
		||||
 | 
			
		||||
    double normalized_value = value - min_value;
 | 
			
		||||
 | 
			
		||||
    // If the value is 0, do nothing
 | 
			
		||||
    if (normalized_value == 0.0) return gf;
 | 
			
		||||
    if (normalized_value == 0.0) return {gf, 0.0};
 | 
			
		||||
 | 
			
		||||
    GaussianFactorGraph gfg;
 | 
			
		||||
    gfg.push_back(jf);
 | 
			
		||||
| 
						 | 
				
			
			@ -62,18 +64,16 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
 | 
			
		|||
    auto constantFactor = std::make_shared<JacobianFactor>(c);
 | 
			
		||||
 | 
			
		||||
    gfg.push_back(constantFactor);
 | 
			
		||||
    return std::dynamic_pointer_cast<GaussianFactor>(
 | 
			
		||||
        std::make_shared<JacobianFactor>(gfg));
 | 
			
		||||
    return {std::make_shared<JacobianFactor>(gfg), normalized_value};
 | 
			
		||||
  };
 | 
			
		||||
  return Factors(factors, update);
 | 
			
		||||
  return FactorValuePairs(factors, update);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
struct HybridGaussianFactor::ConstructorHelper {
 | 
			
		||||
  KeyVector continuousKeys;   // Continuous keys extracted from factors
 | 
			
		||||
  DiscreteKeys discreteKeys;  // Discrete keys provided to the constructors
 | 
			
		||||
  FactorValuePairs pairs;     // Used only if factorsTree is empty
 | 
			
		||||
  Factors factorsTree;
 | 
			
		||||
  FactorValuePairs pairs;     // The decision tree with factors and scalars
 | 
			
		||||
 | 
			
		||||
  ConstructorHelper(const DiscreteKey &discreteKey,
 | 
			
		||||
                    const std::vector<GaussianFactor::shared_ptr> &factors)
 | 
			
		||||
| 
						 | 
				
			
			@ -85,9 +85,10 @@ struct HybridGaussianFactor::ConstructorHelper {
 | 
			
		|||
        break;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Build the DecisionTree from the factor vector
 | 
			
		||||
    factorsTree = Factors(discreteKeys, factors);
 | 
			
		||||
    // Build the FactorValuePairs DecisionTree
 | 
			
		||||
    pairs = FactorValuePairs(
 | 
			
		||||
        DecisionTree<Key, GaussianFactor::shared_ptr>(discreteKeys, factors),
 | 
			
		||||
        [](const auto &f) { return std::pair{f, 0.0}; });
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ConstructorHelper(const DiscreteKey &discreteKey,
 | 
			
		||||
| 
						 | 
				
			
			@ -109,6 +110,7 @@ struct HybridGaussianFactor::ConstructorHelper {
 | 
			
		|||
                    const FactorValuePairs &factorPairs)
 | 
			
		||||
      : discreteKeys(discreteKeys) {
 | 
			
		||||
    // Extract continuous keys from the first non-null factor
 | 
			
		||||
    // TODO: just stop after first non-null factor
 | 
			
		||||
    factorPairs.visit([&](const GaussianFactorValuePair &pair) {
 | 
			
		||||
      if (pair.first && continuousKeys.empty()) {
 | 
			
		||||
        continuousKeys = pair.first->keys();
 | 
			
		||||
| 
						 | 
				
			
			@ -123,14 +125,13 @@ struct HybridGaussianFactor::ConstructorHelper {
 | 
			
		|||
/* *******************************************************************************/
 | 
			
		||||
HybridGaussianFactor::HybridGaussianFactor(const ConstructorHelper &helper)
 | 
			
		||||
    : Base(helper.continuousKeys, helper.discreteKeys),
 | 
			
		||||
      factors_(helper.factorsTree.empty() ? augment(helper.pairs)
 | 
			
		||||
                                          : helper.factorsTree) {}
 | 
			
		||||
      factors_(augment(helper.pairs)) {}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
HybridGaussianFactor::HybridGaussianFactor(
 | 
			
		||||
    const DiscreteKey &discreteKey,
 | 
			
		||||
    const std::vector<GaussianFactor::shared_ptr> &factors)
 | 
			
		||||
    : HybridGaussianFactor(ConstructorHelper(discreteKey, factors)) {}
 | 
			
		||||
    const std::vector<GaussianFactor::shared_ptr> &factorPairs)
 | 
			
		||||
    : HybridGaussianFactor(ConstructorHelper(discreteKey, factorPairs)) {}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
HybridGaussianFactor::HybridGaussianFactor(
 | 
			
		||||
| 
						 | 
				
			
			@ -140,8 +141,8 @@ HybridGaussianFactor::HybridGaussianFactor(
 | 
			
		|||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
 | 
			
		||||
                                           const FactorValuePairs &factors)
 | 
			
		||||
    : HybridGaussianFactor(ConstructorHelper(discreteKeys, factors)) {}
 | 
			
		||||
                                           const FactorValuePairs &factorPairs)
 | 
			
		||||
    : HybridGaussianFactor(ConstructorHelper(discreteKeys, factorPairs)) {}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
 | 
			
		||||
| 
						 | 
				
			
			@ -153,10 +154,12 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
 | 
			
		|||
  if (factors_.empty() ^ e->factors_.empty()) return false;
 | 
			
		||||
 | 
			
		||||
  // Check the base and the factors:
 | 
			
		||||
  return Base::equals(*e, tol) &&
 | 
			
		||||
         factors_.equals(e->factors_, [tol](const auto &f1, const auto &f2) {
 | 
			
		||||
           return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
 | 
			
		||||
         });
 | 
			
		||||
  auto compareFunc = [tol](const auto &pair1, const auto &pair2) {
 | 
			
		||||
  	auto f1 = pair1.first, f2 = pair2.first;
 | 
			
		||||
    bool match = (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
 | 
			
		||||
    return match && gtsam::equal(pair1.second, pair2.second, tol);
 | 
			
		||||
  };
 | 
			
		||||
  return Base::equals(*e, tol) && factors_.equals(e->factors_, compareFunc);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
| 
						 | 
				
			
			@ -171,15 +174,16 @@ void HybridGaussianFactor::print(const std::string &s,
 | 
			
		|||
  } else {
 | 
			
		||||
    factors_.print(
 | 
			
		||||
        "", [&](Key k) { return formatter(k); },
 | 
			
		||||
        [&](const sharedFactor &gf) -> std::string {
 | 
			
		||||
        [&](const auto &pair) -> std::string {
 | 
			
		||||
          RedirectCout rd;
 | 
			
		||||
          std::cout << ":\n";
 | 
			
		||||
          if (gf) {
 | 
			
		||||
            gf->print("", formatter);
 | 
			
		||||
          if (pair.first) {
 | 
			
		||||
            pair.first->print("", formatter);
 | 
			
		||||
            return rd.str();
 | 
			
		||||
          } else {
 | 
			
		||||
            return "nullptr";
 | 
			
		||||
          }
 | 
			
		||||
          std::cout << "scalar: " << pair.second << "\n";
 | 
			
		||||
        });
 | 
			
		||||
  }
 | 
			
		||||
  std::cout << "}" << std::endl;
 | 
			
		||||
| 
						 | 
				
			
			@ -188,7 +192,7 @@ void HybridGaussianFactor::print(const std::string &s,
 | 
			
		|||
/* *******************************************************************************/
 | 
			
		||||
HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
 | 
			
		||||
    const DiscreteValues &assignment) const {
 | 
			
		||||
  return factors_(assignment);
 | 
			
		||||
  return factors_(assignment).first;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* *******************************************************************************/
 | 
			
		||||
| 
						 | 
				
			
			@ -207,7 +211,7 @@ GaussianFactorGraphTree HybridGaussianFactor::add(
 | 
			
		|||
/* *******************************************************************************/
 | 
			
		||||
GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
 | 
			
		||||
    const {
 | 
			
		||||
  auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
 | 
			
		||||
  auto wrap = [](const auto &pair) { return GaussianFactorGraph{pair.first}; };
 | 
			
		||||
  return {factors_, wrap};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -229,8 +233,8 @@ static double PotentiallyPrunedComponentError(
 | 
			
		|||
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
 | 
			
		||||
    const VectorValues &continuousValues) const {
 | 
			
		||||
  // functor to convert from sharedFactor to double error value.
 | 
			
		||||
  auto errorFunc = [&continuousValues](const sharedFactor &gf) {
 | 
			
		||||
    return PotentiallyPrunedComponentError(gf, continuousValues);
 | 
			
		||||
  auto errorFunc = [this, &continuousValues](const auto &pair) {
 | 
			
		||||
    return PotentiallyPrunedComponentError(pair.first, continuousValues);
 | 
			
		||||
  };
 | 
			
		||||
  DecisionTree<Key, double> error_tree(factors_, errorFunc);
 | 
			
		||||
  return error_tree;
 | 
			
		||||
| 
						 | 
				
			
			@ -239,8 +243,8 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
 | 
			
		|||
/* *******************************************************************************/
 | 
			
		||||
double HybridGaussianFactor::error(const HybridValues &values) const {
 | 
			
		||||
  // Directly index to get the component, no need to build the whole tree.
 | 
			
		||||
  const sharedFactor gf = factors_(values.discrete());
 | 
			
		||||
  return PotentiallyPrunedComponentError(gf, values.continuous());
 | 
			
		||||
  const auto pair = factors_(values.discrete());
 | 
			
		||||
  return PotentiallyPrunedComponentError(pair.first, values.continuous());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -66,12 +66,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
 | 
			
		|||
 | 
			
		||||
  /// typedef for Decision Tree of Gaussian factors and arbitrary value.
 | 
			
		||||
  using FactorValuePairs = DecisionTree<Key, GaussianFactorValuePair>;
 | 
			
		||||
  /// typedef for Decision Tree of Gaussian factors.
 | 
			
		||||
  using Factors = DecisionTree<Key, sharedFactor>;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  /// Decision tree of Gaussian factors indexed by discrete keys.
 | 
			
		||||
  Factors factors_;
 | 
			
		||||
  FactorValuePairs factors_;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  /// @name Constructors
 | 
			
		||||
| 
						 | 
				
			
			@ -110,10 +108,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
 | 
			
		|||
   * The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m.
 | 
			
		||||
   *
 | 
			
		||||
   * @param discreteKeys Discrete variables and their cardinalities.
 | 
			
		||||
   * @param factors The decision tree of Gaussian factor/scalar pairs.
 | 
			
		||||
   * @param factorPairs The decision tree of Gaussian factor/scalar pairs.
 | 
			
		||||
   */
 | 
			
		||||
  HybridGaussianFactor(const DiscreteKeys &discreteKeys,
 | 
			
		||||
                       const FactorValuePairs &factors);
 | 
			
		||||
                       const FactorValuePairs &factorPairs);
 | 
			
		||||
 | 
			
		||||
  /// @}
 | 
			
		||||
  /// @name Testable
 | 
			
		||||
| 
						 | 
				
			
			@ -158,7 +156,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
 | 
			
		|||
  double error(const HybridValues &values) const override;
 | 
			
		||||
 | 
			
		||||
  /// Getter for GaussianFactor decision tree
 | 
			
		||||
  const Factors &factors() const { return factors_; }
 | 
			
		||||
  const FactorValuePairs &factors() const { return factors_; }
 | 
			
		||||
 | 
			
		||||
  /// Add HybridNonlinearFactor to a Sum, syntactic sugar.
 | 
			
		||||
  friend GaussianFactorGraphTree &operator+=(
 | 
			
		||||
| 
						 | 
				
			
			@ -184,10 +182,9 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
 | 
			
		|||
   * value in the `b` vector as an additional row.
 | 
			
		||||
   *
 | 
			
		||||
   * @param factors DecisionTree of GaussianFactors and arbitrary scalars.
 | 
			
		||||
   * Gaussian factor in factors.
 | 
			
		||||
   * @return HybridGaussianFactor::Factors
 | 
			
		||||
   * @return FactorValuePairs
 | 
			
		||||
   */
 | 
			
		||||
  static Factors augment(const FactorValuePairs &factors);
 | 
			
		||||
  static FactorValuePairs augment(const FactorValuePairs &factors);
 | 
			
		||||
 | 
			
		||||
  /// Helper struct to assist private constructor below.
 | 
			
		||||
  struct ConstructorHelper;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -238,8 +238,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
 | 
			
		|||
    } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
 | 
			
		||||
      // Case where we have a HybridGaussianFactor with no continuous keys.
 | 
			
		||||
      // In this case, compute discrete probabilities.
 | 
			
		||||
      auto logProbability =
 | 
			
		||||
          [&](const GaussianFactor::shared_ptr &factor) -> double {
 | 
			
		||||
      auto logProbability = [&](const auto &pair) -> double {
 | 
			
		||||
        auto [factor, _] = pair;
 | 
			
		||||
        if (!factor) return 0.0;
 | 
			
		||||
        return factor->error(VectorValues());
 | 
			
		||||
      };
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -196,8 +196,8 @@ std::shared_ptr<HybridGaussianFactor> HybridNonlinearFactor::linearize(
 | 
			
		|||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  DecisionTree<Key, std::pair<GaussianFactor::shared_ptr, double>>
 | 
			
		||||
      linearized_factors(factors_, linearizeDT);
 | 
			
		||||
  HybridGaussianFactor::FactorValuePairs linearized_factors(factors_,
 | 
			
		||||
                                                            linearizeDT);
 | 
			
		||||
 | 
			
		||||
  return std::make_shared<HybridGaussianFactor>(discreteKeys_,
 | 
			
		||||
                                                linearized_factors);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -52,11 +52,11 @@ BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
 | 
			
		|||
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
 | 
			
		||||
 | 
			
		||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor, "gtsam_HybridGaussianFactor");
 | 
			
		||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors,
 | 
			
		||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs,
 | 
			
		||||
                        "gtsam_HybridGaussianFactor_Factors");
 | 
			
		||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Leaf,
 | 
			
		||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Leaf,
 | 
			
		||||
                        "gtsam_HybridGaussianFactor_Factors_Leaf");
 | 
			
		||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Choice,
 | 
			
		||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::FactorValuePairs::Choice,
 | 
			
		||||
                        "gtsam_HybridGaussianFactor_Factors_Choice");
 | 
			
		||||
 | 
			
		||||
BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue