conditionalError method
parent
997d0b411b
commit
cc04003716
|
@ -330,10 +330,10 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
||||||
return DecisionTree<Key, double>(conditionals_, probFunc);
|
return DecisionTree<Key, double>(conditionals_, probFunc);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
double GaussianMixture::conditionalError(
|
||||||
|
const GaussianConditional::shared_ptr &conditional,
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
|
|
||||||
// Check if valid pointer
|
// Check if valid pointer
|
||||||
if (conditional) {
|
if (conditional) {
|
||||||
return conditional->error(continuousValues) + //
|
return conditional->error(continuousValues) + //
|
||||||
|
@ -341,8 +341,17 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
||||||
} else {
|
} else {
|
||||||
// If not valid, pointer, it means this conditional was pruned,
|
// If not valid, pointer, it means this conditional was pruned,
|
||||||
// so we return maximum error.
|
// so we return maximum error.
|
||||||
|
// This way the negative exponential will give
|
||||||
|
// a probability value close to 0.0.
|
||||||
return std::numeric_limits<double>::max();
|
return std::numeric_limits<double>::max();
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
||||||
|
const VectorValues &continuousValues) const {
|
||||||
|
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
|
||||||
|
return conditionalError(conditional, continuousValues);
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
|
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
|
||||||
return error_tree;
|
return error_tree;
|
||||||
|
@ -350,33 +359,9 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixture::error(const HybridValues &values) const {
|
double GaussianMixture::error(const HybridValues &values) const {
|
||||||
// Check if discrete keys in discrete assignment are
|
|
||||||
// present in the GaussianMixture
|
|
||||||
KeyVector dKeys = this->discreteKeys_.indices();
|
|
||||||
bool valid_assignment = false;
|
|
||||||
for (auto &&kv : values.discrete()) {
|
|
||||||
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
|
|
||||||
valid_assignment = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// The discrete assignment is not valid so we throw an error.
|
|
||||||
if (!valid_assignment) {
|
|
||||||
throw std::runtime_error(
|
|
||||||
"Invalid discrete values in values. Not all discrete keys specified.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
// Directly index to get the conditional, no need to build the whole tree.
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto conditional = conditionals_(values.discrete());
|
||||||
if (conditional) {
|
return conditionalError(conditional, values.continuous());
|
||||||
return conditional->error(values.continuous()) + //
|
|
||||||
logConstant_ - conditional->logNormalizationConstant();
|
|
||||||
} else {
|
|
||||||
// If not valid, pointer, it means this conditional was pruned,
|
|
||||||
// so we return maximum error.
|
|
||||||
return std::numeric_limits<double>::max();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
@ -256,6 +256,10 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/// Check whether `given` has values for all frontal keys.
|
/// Check whether `given` has values for all frontal keys.
|
||||||
bool allFrontalsGiven(const VectorValues &given) const;
|
bool allFrontalsGiven(const VectorValues &given) const;
|
||||||
|
|
||||||
|
/// Helper method to compute the error of a conditional.
|
||||||
|
double conditionalError(const GaussianConditional::shared_ptr &conditional,
|
||||||
|
const VectorValues &continuousValues) const;
|
||||||
|
|
||||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
/** Serialization function */
|
/** Serialization function */
|
||||||
friend class boost::serialization::access;
|
friend class boost::serialization::access;
|
||||||
|
|
Loading…
Reference in New Issue