Avoid using slow conditionals()
							parent
							
								
									6c9b25c45e
								
							
						
					
					
						commit
						1fe09f5e09
					
				|  | @ -197,11 +197,11 @@ void HybridGaussianConditional::print(const std::string &s, | |||
|   std::cout << std::endl | ||||
|             << " logNormalizationConstant: " << -negLogConstant() << std::endl | ||||
|             << std::endl; | ||||
|   conditionals().print( | ||||
|   factors().print( | ||||
|       "", [&](Key k) { return formatter(k); }, | ||||
|       [&](const GaussianConditional::shared_ptr &gf) -> std::string { | ||||
|       [&](const GaussianFactorValuePair &pair) -> std::string { | ||||
|         RedirectCout rd; | ||||
|         if (gf && !gf->empty()) { | ||||
|         if (auto gf = std::dynamic_pointer_cast<GaussianConditional>(pair.first)) { | ||||
|           gf->print("", formatter); | ||||
|           return rd.str(); | ||||
|         } else { | ||||
|  | @ -249,12 +249,18 @@ std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood( | |||
|   const DiscreteKeys discreteParentKeys = discreteKeys(); | ||||
|   const KeyVector continuousParentKeys = continuousParents(); | ||||
|   const HybridGaussianFactor::FactorValuePairs likelihoods( | ||||
|       conditionals(), | ||||
|       [&](const GaussianConditional::shared_ptr &conditional) | ||||
|           -> GaussianFactorValuePair { | ||||
|         const auto likelihood_m = conditional->likelihood(given); | ||||
|         const double Cgm_Kgcm = conditional->negLogConstant() - negLogConstant_; | ||||
|         return {likelihood_m, Cgm_Kgcm}; | ||||
|       factors(), | ||||
|       [&](const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { | ||||
|         if (auto conditional = | ||||
|                 std::dynamic_pointer_cast<GaussianConditional>(pair.first)) { | ||||
|           const auto likelihood_m = conditional->likelihood(given); | ||||
|           // scalar is already correct.
 | ||||
|           assert(pair.second == | ||||
|                  conditional->negLogConstant() - negLogConstant_); | ||||
|           return {likelihood_m, pair.second}; | ||||
|         } else { | ||||
|           return {nullptr, std::numeric_limits<double>::infinity()}; | ||||
|         } | ||||
|       }); | ||||
|   return std::make_shared<HybridGaussianFactor>(discreteParentKeys, | ||||
|                                                 likelihoods); | ||||
|  | @ -283,15 +289,19 @@ HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( | |||
| 
 | ||||
|   // Check the max value for every combination of our keys.
 | ||||
|   // If the max value is 0.0, we can prune the corresponding conditional.
 | ||||
|   auto pruner = [&](const Assignment<Key> &choices, | ||||
|                     const GaussianConditional::shared_ptr &conditional) | ||||
|       -> GaussianConditional::shared_ptr { | ||||
|     return (max->evaluate(choices) == 0.0) ? nullptr : conditional; | ||||
|   auto pruner = | ||||
|       [&](const Assignment<Key> &choices, | ||||
|           const GaussianFactorValuePair &pair) -> GaussianFactorValuePair { | ||||
|     if (max->evaluate(choices) == 0.0) | ||||
|       return {nullptr, std::numeric_limits<double>::infinity()}; | ||||
|     else | ||||
|       return pair; | ||||
|   }; | ||||
| 
 | ||||
|   auto pruned_conditionals = conditionals().apply(pruner); | ||||
|   return std::make_shared<HybridGaussianConditional>(discreteKeys(), | ||||
|                                                      pruned_conditionals); | ||||
|   FactorValuePairs prunedConditionals = factors().apply(pruner); | ||||
|   return std::shared_ptr<HybridGaussianConditional>( | ||||
|       new HybridGaussianConditional(discreteKeys(), nrFrontals_, | ||||
|                                     prunedConditionals, negLogConstant_)); | ||||
| } | ||||
| 
 | ||||
| /* *******************************************************************************/ | ||||
|  |  | |||
|  | @ -191,6 +191,7 @@ class GTSAM_EXPORT HybridGaussianConditional | |||
|       const VectorValues &given) const; | ||||
| 
 | ||||
|   /// Get Conditionals DecisionTree (dynamic cast from factors)
 | ||||
|   /// @note Slow: avoid using in favor of factors(), which uses existing tree.
 | ||||
|   const Conditionals conditionals() const; | ||||
| 
 | ||||
|   /**
 | ||||
|  | @ -229,6 +230,14 @@ class GTSAM_EXPORT HybridGaussianConditional | |||
|   HybridGaussianConditional(const DiscreteKeys &discreteParents, | ||||
|                             const Helper &helper); | ||||
| 
 | ||||
|   /// Private constructor used when constants have already been calculated.
 | ||||
|   HybridGaussianConditional(const DiscreteKeys &discreteKeys, int nrFrontals, | ||||
|                             const FactorValuePairs &factors, | ||||
|                             double negLogConstant) | ||||
|       : BaseFactor(discreteKeys, factors), | ||||
|         BaseConditional(nrFrontals), | ||||
|         negLogConstant_(negLogConstant) {} | ||||
| 
 | ||||
|   /// Check whether `given` has values for all frontal keys.
 | ||||
|   bool allFrontalsGiven(const VectorValues &given) const; | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue