efficient probabilities method
parent
2b85cfedd4
commit
3d24d0128f
|
@ -197,9 +197,39 @@ namespace gtsam {
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
std::vector<double> DecisionTreeFactor::probabilities() const {
|
std::vector<double> DecisionTreeFactor::probabilities() const {
|
||||||
std::vector<double> probs;
|
std::vector<double> probs;
|
||||||
for (auto&& [key, value] : enumerate()) {
|
|
||||||
probs.push_back(value);
|
// Get all the key cardinalities
|
||||||
|
std::map<Key, size_t> cardins;
|
||||||
|
for (auto [key, cardinality] : discreteKeys()) {
|
||||||
|
cardins[key] = cardinality;
|
||||||
}
|
}
|
||||||
|
// Set of all keys
|
||||||
|
std::set<Key> allKeys(keys().begin(), keys().end());
|
||||||
|
|
||||||
|
// Go through the tree
|
||||||
|
std::vector<double> ys;
|
||||||
|
this->apply([&](const Assignment<Key> a, double p) {
|
||||||
|
// Get all the keys in the current assignment
|
||||||
|
std::set<Key> assignment_keys;
|
||||||
|
for (auto&& [k, _] : a) {
|
||||||
|
assignment_keys.insert(k);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the keys missing in the assignment
|
||||||
|
std::vector<Key> diff;
|
||||||
|
std::set_difference(allKeys.begin(), allKeys.end(),
|
||||||
|
assignment_keys.begin(), assignment_keys.end(),
|
||||||
|
std::back_inserter(diff));
|
||||||
|
|
||||||
|
// Compute the total number of assignments in the (pruned) subtree
|
||||||
|
size_t nrAssignments = 1;
|
||||||
|
for (auto&& k : diff) {
|
||||||
|
nrAssignments *= cardins.at(k);
|
||||||
|
}
|
||||||
|
probs.insert(probs.end(), nrAssignments, p);
|
||||||
|
return p;
|
||||||
|
});
|
||||||
|
|
||||||
return probs;
|
return probs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -313,11 +343,7 @@ namespace gtsam {
|
||||||
const size_t N = maxNrAssignments;
|
const size_t N = maxNrAssignments;
|
||||||
|
|
||||||
// Get the probabilities in the decision tree so we can threshold.
|
// Get the probabilities in the decision tree so we can threshold.
|
||||||
std::vector<double> probabilities;
|
std::vector<double> probabilities = this->probabilities();
|
||||||
// NOTE(Varun) this is potentially slow due to the cartesian product
|
|
||||||
for (auto&& [assignment, prob] : this->enumerate()) {
|
|
||||||
probabilities.push_back(prob);
|
|
||||||
}
|
|
||||||
|
|
||||||
// The number of probabilities can be lower than max_leaves
|
// The number of probabilities can be lower than max_leaves
|
||||||
if (probabilities.size() <= N) {
|
if (probabilities.size() <= N) {
|
||||||
|
|
Loading…
Reference in New Issue