diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index fa4248921..325c32f95 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -330,19 +330,28 @@ AlgebraicDecisionTree GaussianMixture::logProbability( return DecisionTree(conditionals_, probFunc); } +/* ************************************************************************* */ +double GaussianMixture::conditionalError( + const GaussianConditional::shared_ptr &conditional, + const VectorValues &continuousValues) const { + // Check if valid pointer + if (conditional) { + return conditional->error(continuousValues) + // + logConstant_ - conditional->logNormalizationConstant(); + } else { + // If not valid, pointer, it means this conditional was pruned, + // so we return maximum error. + // This way the negative exponential will give + // a probability value close to 0.0. + return std::numeric_limits::max(); + } +} + /* *******************************************************************************/ AlgebraicDecisionTree GaussianMixture::errorTree( const VectorValues &continuousValues) const { auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) { - // Check if valid pointer - if (conditional) { - return conditional->error(continuousValues) + // - logConstant_ - conditional->logNormalizationConstant(); - } else { - // If not valid, pointer, it means this conditional was pruned, - // so we return maximum error. - return std::numeric_limits::max(); - } + return conditionalError(conditional, continuousValues); }; DecisionTree error_tree(conditionals_, errorFunc); return error_tree; @@ -350,33 +359,9 @@ AlgebraicDecisionTree GaussianMixture::errorTree( /* *******************************************************************************/ double GaussianMixture::error(const HybridValues &values) const { - // Check if discrete keys in discrete assignment are - // present in the GaussianMixture - KeyVector dKeys = this->discreteKeys_.indices(); - bool valid_assignment = false; - for (auto &&kv : values.discrete()) { - if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) { - valid_assignment = true; - break; - } - } - - // The discrete assignment is not valid so we throw an error. - if (!valid_assignment) { - throw std::runtime_error( - "Invalid discrete values in values. Not all discrete keys specified."); - } - // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(values.discrete()); - if (conditional) { - return conditional->error(values.continuous()) + // - logConstant_ - conditional->logNormalizationConstant(); - } else { - // If not valid, pointer, it means this conditional was pruned, - // so we return maximum error. - return std::numeric_limits::max(); - } + return conditionalError(conditional, values.continuous()); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 41692bb1c..714c00272 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -256,6 +256,10 @@ class GTSAM_EXPORT GaussianMixture /// Check whether `given` has values for all frontal keys. bool allFrontalsGiven(const VectorValues &given) const; + /// Helper method to compute the error of a conditional. + double conditionalError(const GaussianConditional::shared_ptr &conditional, + const VectorValues &continuousValues) const; + #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ friend class boost::serialization::access;