conditionalError method

release/4.3a0
Varun Agrawal 2024-09-04 15:18:27 -04:00
parent 997d0b411b
commit cc04003716
2 changed files with 23 additions and 34 deletions

View File

@ -330,19 +330,28 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
return DecisionTree<Key, double>(conditionals_, probFunc);
}
/* ************************************************************************* */
double GaussianMixture::conditionalError(
const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const {
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
return conditionalError(conditional, continuousValues);
};
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
@ -350,33 +359,9 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
// Check if discrete keys in discrete assignment are
// present in the GaussianMixture
KeyVector dKeys = this->discreteKeys_.indices();
bool valid_assignment = false;
for (auto &&kv : values.discrete()) {
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
valid_assignment = true;
break;
}
}
// The discrete assignment is not valid so we throw an error.
if (!valid_assignment) {
throw std::runtime_error(
"Invalid discrete values in values. Not all discrete keys specified.");
}
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
if (conditional) {
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
return conditionalError(conditional, values.continuous());
}
/* *******************************************************************************/

View File

@ -256,6 +256,10 @@ class GTSAM_EXPORT GaussianMixture
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;
/// Helper method to compute the error of a conditional.
double conditionalError(const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const;
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
friend class boost::serialization::access;