diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 4aa5e5759..19cc8e230 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -197,9 +197,39 @@ namespace gtsam { /* ************************************************************************ */ std::vector DecisionTreeFactor::probabilities() const { std::vector probs; - for (auto&& [key, value] : enumerate()) { - probs.push_back(value); + + // Get all the key cardinalities + std::map cardins; + for (auto [key, cardinality] : discreteKeys()) { + cardins[key] = cardinality; } + // Set of all keys + std::set allKeys(keys().begin(), keys().end()); + + // Go through the tree + std::vector ys; + this->apply([&](const Assignment a, double p) { + // Get all the keys in the current assignment + std::set assignment_keys; + for (auto&& [k, _] : a) { + assignment_keys.insert(k); + } + + // Find the keys missing in the assignment + std::vector diff; + std::set_difference(allKeys.begin(), allKeys.end(), + assignment_keys.begin(), assignment_keys.end(), + std::back_inserter(diff)); + + // Compute the total number of assignments in the (pruned) subtree + size_t nrAssignments = 1; + for (auto&& k : diff) { + nrAssignments *= cardins.at(k); + } + probs.insert(probs.end(), nrAssignments, p); + return p; + }); + return probs; } @@ -313,11 +343,7 @@ namespace gtsam { const size_t N = maxNrAssignments; // Get the probabilities in the decision tree so we can threshold. - std::vector probabilities; - // NOTE(Varun) this is potentially slow due to the cartesian product - for (auto&& [assignment, prob] : this->enumerate()) { - probabilities.push_back(prob); - } + std::vector probabilities = this->probabilities(); // The number of probabilities can be lower than max_leaves if (probabilities.size() <= N) {