Avoid calculating negLogK twice
							parent
							
								
									8d4233587c
								
							
						
					
					
						commit
						1365a0904a
					
				| 
						 | 
					@ -57,10 +57,20 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
using std::dynamic_pointer_cast;
 | 
					using std::dynamic_pointer_cast;
 | 
				
			||||||
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
 | 
					using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
 | 
				
			||||||
using Result =
 | 
					
 | 
				
			||||||
    std::pair<std::shared_ptr<GaussianConditional>, GaussianFactor::shared_ptr>;
 | 
					/// Result from elimination.
 | 
				
			||||||
using ResultValuePair = std::pair<Result, double>;
 | 
					struct Result {
 | 
				
			||||||
using ResultTree = DecisionTree<Key, ResultValuePair>;
 | 
					  GaussianConditional::shared_ptr conditional;
 | 
				
			||||||
 | 
					  double negLogK;
 | 
				
			||||||
 | 
					  GaussianFactor::shared_ptr factor;
 | 
				
			||||||
 | 
					  double scalar;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  bool operator==(const Result &other) const {
 | 
				
			||||||
 | 
					    return conditional == other.conditional && negLogK == other.negLogK &&
 | 
				
			||||||
 | 
					           factor == other.factor && scalar == other.scalar;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					using ResultTree = DecisionTree<Key, Result>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static const VectorValues kEmpty;
 | 
					static const VectorValues kEmpty;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -294,17 +304,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
 | 
				
			||||||
static std::shared_ptr<Factor> createDiscreteFactor(
 | 
					static std::shared_ptr<Factor> createDiscreteFactor(
 | 
				
			||||||
    const ResultTree &eliminationResults,
 | 
					    const ResultTree &eliminationResults,
 | 
				
			||||||
    const DiscreteKeys &discreteSeparator) {
 | 
					    const DiscreteKeys &discreteSeparator) {
 | 
				
			||||||
  auto calculateError = [&](const auto &pair) -> double {
 | 
					  auto calculateError = [&](const Result &result) -> double {
 | 
				
			||||||
    const auto &[conditional, factor] = pair.first;
 | 
					    if (result.conditional && result.factor) {
 | 
				
			||||||
    const double scalar = pair.second;
 | 
					 | 
				
			||||||
    if (conditional && factor) {
 | 
					 | 
				
			||||||
      // `error` has the following contributions:
 | 
					      // `error` has the following contributions:
 | 
				
			||||||
      // - the scalar is the sum of all mode-dependent constants
 | 
					      // - the scalar is the sum of all mode-dependent constants
 | 
				
			||||||
      // - factor->error(kempty) is the error remaining after elimination
 | 
					      // - factor->error(kempty) is the error remaining after elimination
 | 
				
			||||||
      // - negLogK is what is given to the conditional to normalize
 | 
					      // - negLogK is what is given to the conditional to normalize
 | 
				
			||||||
      const double negLogK = conditional->negLogConstant();
 | 
					      return result.scalar + result.factor->error(kEmpty) - result.negLogK;
 | 
				
			||||||
      return scalar + factor->error(kEmpty) - negLogK;
 | 
					    } else if (!result.conditional && !result.factor) {
 | 
				
			||||||
    } else if (!conditional && !factor) {
 | 
					 | 
				
			||||||
      // If the factor has been pruned, return infinite error
 | 
					      // If the factor has been pruned, return infinite error
 | 
				
			||||||
      return std::numeric_limits<double>::infinity();
 | 
					      return std::numeric_limits<double>::infinity();
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
| 
						 | 
					@ -323,13 +330,10 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
 | 
				
			||||||
    const ResultTree &eliminationResults,
 | 
					    const ResultTree &eliminationResults,
 | 
				
			||||||
    const DiscreteKeys &discreteSeparator) {
 | 
					    const DiscreteKeys &discreteSeparator) {
 | 
				
			||||||
  // Correct for the normalization constant used up by the conditional
 | 
					  // Correct for the normalization constant used up by the conditional
 | 
				
			||||||
  auto correct = [&](const ResultValuePair &pair) -> GaussianFactorValuePair {
 | 
					  auto correct = [&](const Result &result) -> GaussianFactorValuePair {
 | 
				
			||||||
    const auto &[conditional, factor] = pair.first;
 | 
					    if (result.conditional && result.factor) {
 | 
				
			||||||
    const double scalar = pair.second;
 | 
					      return {result.factor, result.scalar - result.negLogK};
 | 
				
			||||||
    if (conditional && factor) {
 | 
					    } else if (!result.conditional && !result.factor) {
 | 
				
			||||||
      const double negLogK = conditional->negLogConstant();
 | 
					 | 
				
			||||||
      return {factor, scalar - negLogK};
 | 
					 | 
				
			||||||
    } else if (!conditional && !factor) {
 | 
					 | 
				
			||||||
      return {nullptr, std::numeric_limits<double>::infinity()};
 | 
					      return {nullptr, std::numeric_limits<double>::infinity()};
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
 | 
					      throw std::runtime_error("createHybridGaussianFactors has mixed NULLs");
 | 
				
			||||||
| 
						 | 
					@ -370,23 +374,23 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // This is the elimination method on the leaf nodes
 | 
					  // This is the elimination method on the leaf nodes
 | 
				
			||||||
  bool someContinuousLeft = false;
 | 
					  bool someContinuousLeft = false;
 | 
				
			||||||
  auto eliminate = [&](const std::pair<GaussianFactorGraph, double> &pair)
 | 
					  auto eliminate =
 | 
				
			||||||
      -> std::pair<Result, double> {
 | 
					      [&](const std::pair<GaussianFactorGraph, double> &pair) -> Result {
 | 
				
			||||||
    const auto &[graph, scalar] = pair;
 | 
					    const auto &[graph, scalar] = pair;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (graph.empty()) {
 | 
					    if (graph.empty()) {
 | 
				
			||||||
      return {{nullptr, nullptr}, 0.0};
 | 
					      return {nullptr, 0.0, nullptr, 0.0};
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Expensive elimination of product factor.
 | 
					    // Expensive elimination of product factor.
 | 
				
			||||||
    auto result =
 | 
					    auto [conditional, factor] =
 | 
				
			||||||
        EliminatePreferCholesky(graph, keys);  /// <<<<<< MOST COMPUTE IS HERE
 | 
					        EliminatePreferCholesky(graph, keys);  /// <<<<<< MOST COMPUTE IS HERE
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Record whether there any continuous variables left
 | 
					    // Record whether there any continuous variables left
 | 
				
			||||||
    someContinuousLeft |= !result.second->empty();
 | 
					    someContinuousLeft |= !factor->empty();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // We pass on the scalar unmodified.
 | 
					    // We pass on the scalar unmodified.
 | 
				
			||||||
    return {result, scalar};
 | 
					    return {conditional, conditional->negLogConstant(), factor, scalar};
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Perform elimination!
 | 
					  // Perform elimination!
 | 
				
			||||||
| 
						 | 
					@ -400,12 +404,13 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
 | 
				
			||||||
          ? createHybridGaussianFactor(eliminationResults, discreteSeparator)
 | 
					          ? createHybridGaussianFactor(eliminationResults, discreteSeparator)
 | 
				
			||||||
          : createDiscreteFactor(eliminationResults, discreteSeparator);
 | 
					          : createDiscreteFactor(eliminationResults, discreteSeparator);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Create the HybridGaussianConditional from the conditionals
 | 
					  // Create the HybridGaussianConditional without re-calculating constants:
 | 
				
			||||||
  HybridGaussianConditional::Conditionals conditionals(
 | 
					  HybridGaussianConditional::FactorValuePairs pairs(
 | 
				
			||||||
      eliminationResults,
 | 
					      eliminationResults, [](const Result &result) -> GaussianFactorValuePair {
 | 
				
			||||||
      [](const ResultValuePair &pair) { return pair.first.first; });
 | 
					        return {result.conditional, result.negLogK};
 | 
				
			||||||
  auto hybridGaussian = std::make_shared<HybridGaussianConditional>(
 | 
					      });
 | 
				
			||||||
      discreteSeparator, conditionals);
 | 
					  auto hybridGaussian =
 | 
				
			||||||
 | 
					      std::make_shared<HybridGaussianConditional>(discreteSeparator, pairs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
 | 
					  return {std::make_shared<HybridConditional>(hybridGaussian), newFactor};
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue