conditionalError method

release/4.3a0
Varun Agrawal 2024-09-04 15:18:27 -04:00
parent 997d0b411b
commit cc04003716
2 changed files with 23 additions and 34 deletions

View File

@ -330,10 +330,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
return DecisionTree<Key, double>(conditionals_, probFunc); return DecisionTree<Key, double>(conditionals_, probFunc);
} }
/* *******************************************************************************/ /* ************************************************************************* */
AlgebraicDecisionTree<Key> GaussianMixture::errorTree( double GaussianMixture::conditionalError(
const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const { const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
// Check if valid pointer // Check if valid pointer
if (conditional) { if (conditional) {
return conditional->error(continuousValues) + // return conditional->error(continuousValues) + //
@ -341,8 +341,17 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
} else { } else {
// If not valid, pointer, it means this conditional was pruned, // If not valid, pointer, it means this conditional was pruned,
// so we return maximum error. // so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max(); return std::numeric_limits<double>::max();
} }
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditionalError(conditional, continuousValues);
}; };
DecisionTree<Key, double> error_tree(conditionals_, errorFunc); DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree; return error_tree;
@ -350,33 +359,9 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
/* *******************************************************************************/ /* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const { double GaussianMixture::error(const HybridValues &values) const {
// Check if discrete keys in discrete assignment are
// present in the GaussianMixture
KeyVector dKeys = this->discreteKeys_.indices();
bool valid_assignment = false;
for (auto &&kv : values.discrete()) {
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
valid_assignment = true;
break;
}
}
// The discrete assignment is not valid so we throw an error.
if (!valid_assignment) {
throw std::runtime_error(
"Invalid discrete values in values. Not all discrete keys specified.");
}
// Directly index to get the conditional, no need to build the whole tree. // Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete()); auto conditional = conditionals_(values.discrete());
if (conditional) { return conditionalError(conditional, values.continuous());
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -256,6 +256,10 @@ class GTSAM_EXPORT GaussianMixture
/// Check whether `given` has values for all frontal keys. /// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const; bool allFrontalsGiven(const VectorValues &given) const;
/// Helper method to compute the error of a conditional.
double conditionalError(const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const;
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */ /** Serialization function */
friend class boost::serialization::access; friend class boost::serialization::access;