Avoid calculating negLogK twice
							parent
							
								
									8d4233587c
								
							
						
					
					
						commit
						1365a0904a
					
				| 
						 | 
				
			
			@ -57,10 +57,20 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
 | 
			
		|||
 | 
			
		||||
using std::dynamic_pointer_cast;
 | 
			
		||||
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
 | 
			
		||||
using Result =
 | 
			
		||||
    std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
 | 
			
		||||
using ResultValuePair = std::pair<Result, double>;
 | 
			
		||||
using ResultTree = DecisionTree<Key, ResultValuePair>;
 | 
			
		||||
 | 
			
		||||
/// Result from elimination.
 | 
			
		||||
struct Result {
 | 
			
		||||
  GaussianConditional::shared_ptr conditional;
 | 
			
		||||
  double negLogK;
 | 
			
		||||
  GaussianFactor::shared_ptr factor;
 | 
			
		||||
  double scalar;
 | 
			
		||||
 | 
			
		||||
  bool operator==(const Result &other) const {
 | 
			
		||||
    return conditional == other.conditional && negLogK == other.negLogK &&
 | 
			
		||||
           factor == other.factor && scalar == other.scalar;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
using ResultTree = DecisionTree<Key, Result>;
 | 
			
		||||
 | 
			
		||||
static const VectorValues kEmpty;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -294,17 +304,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
 | 
			
		|||
static std::shared_ptr<Factor> createDiscreteFactor(
 | 
			
		||||
    const ResultTree &eliminationResults,
 | 
			
		||||
    const DiscreteKeys &discreteSeparator) {
 | 
			
		||||
  auto calculateError = [&](const auto &pair) -> double {
 | 
			
		||||
    const auto &[conditional, factor] = pair.first;
 | 
			
		||||
    const double scalar = pair.second;
 | 
			
		||||
    if (conditional && factor) {
 | 
			
		||||
  auto calculateError = [&](const Result &result) -> double {
 | 
			
		||||
    if (result.conditional && result.factor) {
 | 
			
		||||
      // `error` has the following contributions:
 | 
			
		||||
      // - the scalar is the sum of all mode-dependent constants
 | 
			
		||||
      // - factor->error(kempty) is the error remaining after elimination
 | 
			
		||||
      // - negLogK is what is given to the conditional to normalize
 | 
			
		||||
      const double negLogK = conditional->negLogConstant();
 | 
			
		||||
      return scalar + factor->error(kEmpty) - negLogK;
 | 
			
		||||
    } else if (!conditional && !factor) {
 | 
			
		||||
      return result.scalar + result.factor->error(kEmpty) - result.negLogK;
 | 
			
		||||
    } else if (!result.conditional && !result.factor) {
 | 
			
		||||
      // If the factor has been pruned, return infinite error
 | 
			
		||||
      return std::numeric_limits<double>::infinity();
 | 
			
		||||
    } else {
 | 
			
		||||
| 
						 | 
				
			
			@ -323,13 +330,10 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
 | 
			
		|||
    const ResultTree &eliminationResults,
 | 
			
		||||
    const DiscreteKeys &discreteSeparator) {
 | 
			
		||||
  // Correct for the normalization constant used up by the conditional
 | 
			
		||||
  auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
 | 
			
		||||
    const auto &[conditional, factor] = pair.first;
 | 
			
		||||
    const double scalar = pair.second;
 | 
			
		||||
    if (conditional && factor) {
 | 
			
		||||
      const double negLogK = conditional->negLogConstant();
 | 
			
		||||
      return {factor, scalar - negLogK};
 | 
			
		||||
    } else if (!conditional && !factor) {
 | 
			
		||||
  auto correct = [&](const Result &result) -> GaussianFactorValuePair {
 | 
			
		||||
    if (result.conditional && result.factor) {
 | 
			
		||||
      return {result.factor, result.scalar - result.negLogK};
 | 
			
		||||
    } else if (!result.conditional && !result.factor) {
 | 
			
		||||
      return {nullptr, std::numeric_limits<double>::infinity()};
 | 
			
		||||
    } else {
 | 
			
		||||
      throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
 | 
			
		||||
| 
						 | 
				
			
			@ -370,23 +374,23 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
 | 
			
		|||
 | 
			
		||||
  // This is the elimination method on the leaf nodes
 | 
			
		||||
  bool someContinuousLeft = false;
 | 
			
		||||
  auto eliminate = [&](const std::pair<GaussianFactorGraph, double> &pair)
 | 
			
		||||
      -> std::pair<Result, double> {
 | 
			
		||||
  auto eliminate =
 | 
			
		||||
      [&](const std::pair<GaussianFactorGraph, double> &pair) -> Result {
 | 
			
		||||
    const auto &[graph, scalar] = pair;
 | 
			
		||||
 | 
			
		||||
    if (graph.empty()) {
 | 
			
		||||
      return {{nullptr, nullptr}, 0.0};
 | 
			
		||||
      return {nullptr, 0.0, nullptr, 0.0};
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Expensive elimination of product factor.
 | 
			
		||||
    auto result =
 | 
			
		||||
    auto [conditional, factor] =
 | 
			
		||||
        EliminatePreferCholesky(graph, keys);  /// <<<<<< MOST COMPUTE IS HERE
 | 
			
		||||
 | 
			
		||||
    // Record whether there any continuous variables left
 | 
			
		||||
    someContinuousLeft |= !result.second->empty();
 | 
			
		||||
    someContinuousLeft |= !factor->empty();
 | 
			
		||||
 | 
			
		||||
    // We pass on the scalar unmodified.
 | 
			
		||||
    return {result, scalar};
 | 
			
		||||
    return {conditional, conditional->negLogConstant(), factor, scalar};
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // Perform elimination!
 | 
			
		||||
| 
						 | 
				
			
			@ -400,12 +404,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
 | 
			
		|||
          ? createHybridGaussianFactor(eliminationResults, discreteSeparator)
 | 
			
		||||
          : createDiscreteFactor(eliminationResults, discreteSeparator);
 | 
			
		||||
 | 
			
		||||
  // Create the HybridGaussianConditional from the conditionals
 | 
			
		||||
  HybridGaussianConditional::Conditionals conditionals(
 | 
			
		||||
      eliminationResults,
 | 
			
		||||
      [](const ResultValuePair &pair) { return pair.first.first; });
 | 
			
		||||
  auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
 | 
			
		||||
      discreteSeparator, conditionals);
 | 
			
		||||
  // Create the HybridGaussianConditional without re-calculating constants:
 | 
			
		||||
  HybridGaussianConditional::FactorValuePairs pairs(
 | 
			
		||||
      eliminationResults, [](const Result &result) -> GaussianFactorValuePair {
 | 
			
		||||
        return {result.conditional, result.negLogK};
 | 
			
		||||
      });
 | 
			
		||||
  auto hybridGaussian =
 | 
			
		||||
      std::make_shared<HybridGaussianConditional>(discreteSeparator, pairs);
 | 
			
		||||
 | 
			
		||||
  return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue