conditionalError method
parent
997d0b411b
commit
cc04003716
|
@ -330,19 +330,28 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
|||
return DecisionTree<Key, double>(conditionals_, probFunc);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
double GaussianMixture::conditionalError(
|
||||
const GaussianConditional::shared_ptr &conditional,
|
||||
const VectorValues &continuousValues) const {
|
||||
// Check if valid pointer
|
||||
if (conditional) {
|
||||
return conditional->error(continuousValues) + //
|
||||
logConstant_ - conditional->logNormalizationConstant();
|
||||
} else {
|
||||
// If not valid, pointer, it means this conditional was pruned,
|
||||
// so we return maximum error.
|
||||
// This way the negative exponential will give
|
||||
// a probability value close to 0.0.
|
||||
return std::numeric_limits<double>::max();
|
||||
}
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
||||
const VectorValues &continuousValues) const {
|
||||
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
|
||||
// Check if valid pointer
|
||||
if (conditional) {
|
||||
return conditional->error(continuousValues) + //
|
||||
logConstant_ - conditional->logNormalizationConstant();
|
||||
} else {
|
||||
// If not valid, pointer, it means this conditional was pruned,
|
||||
// so we return maximum error.
|
||||
return std::numeric_limits<double>::max();
|
||||
}
|
||||
return conditionalError(conditional, continuousValues);
|
||||
};
|
||||
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
|
||||
return error_tree;
|
||||
|
@ -350,33 +359,9 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
|||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixture::error(const HybridValues &values) const {
|
||||
// Check if discrete keys in discrete assignment are
|
||||
// present in the GaussianMixture
|
||||
KeyVector dKeys = this->discreteKeys_.indices();
|
||||
bool valid_assignment = false;
|
||||
for (auto &&kv : values.discrete()) {
|
||||
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
|
||||
valid_assignment = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// The discrete assignment is not valid so we throw an error.
|
||||
if (!valid_assignment) {
|
||||
throw std::runtime_error(
|
||||
"Invalid discrete values in values. Not all discrete keys specified.");
|
||||
}
|
||||
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto conditional = conditionals_(values.discrete());
|
||||
if (conditional) {
|
||||
return conditional->error(values.continuous()) + //
|
||||
logConstant_ - conditional->logNormalizationConstant();
|
||||
} else {
|
||||
// If not valid, pointer, it means this conditional was pruned,
|
||||
// so we return maximum error.
|
||||
return std::numeric_limits<double>::max();
|
||||
}
|
||||
return conditionalError(conditional, values.continuous());
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -256,6 +256,10 @@ class GTSAM_EXPORT GaussianMixture
|
|||
/// Check whether `given` has values for all frontal keys.
|
||||
bool allFrontalsGiven(const VectorValues &given) const;
|
||||
|
||||
/// Helper method to compute the error of a conditional.
|
||||
double conditionalError(const GaussianConditional::shared_ptr &conditional,
|
||||
const VectorValues &continuousValues) const;
|
||||
|
||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
|
|
Loading…
Reference in New Issue