Merge pull request #1151 from borglab/feature/decision-tree-factor-prune
						commit
						0850e89b0e
					
				| 
						 | 
					@ -286,5 +286,39 @@ namespace gtsam {
 | 
				
			||||||
        AlgebraicDecisionTree<Key>(keys, table),
 | 
					        AlgebraicDecisionTree<Key>(keys, table),
 | 
				
			||||||
        cardinalities_(keys.cardinalities()) {}
 | 
					        cardinalities_(keys.cardinalities()) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /* ************************************************************************ */
 | 
				
			||||||
 | 
					  DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const {
 | 
				
			||||||
 | 
					    const size_t N = maxNrLeaves;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Get the probabilities in the decision tree so we can threshold.
 | 
				
			||||||
 | 
					    std::vector<double> probabilities;
 | 
				
			||||||
 | 
					    this->visit([&](const double& prob) { probabilities.emplace_back(prob); });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // The number of probabilities can be lower than max_leaves
 | 
				
			||||||
 | 
					    if (probabilities.size() <= N) {
 | 
				
			||||||
 | 
					      return *this;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::sort(probabilities.begin(), probabilities.end(),
 | 
				
			||||||
 | 
					              std::greater<double>{});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    double threshold = probabilities[N - 1];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Now threshold the decision tree
 | 
				
			||||||
 | 
					    size_t total = 0;
 | 
				
			||||||
 | 
					    auto thresholdFunc = [threshold, &total, N](const double& value) {
 | 
				
			||||||
 | 
					      if (value < threshold || total >= N) {
 | 
				
			||||||
 | 
					        return 0.0;
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        total += 1;
 | 
				
			||||||
 | 
					        return value;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					    DecisionTree<Key, double> thresholded(*this, thresholdFunc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Create pruned decision tree factor and return.
 | 
				
			||||||
 | 
					    return DecisionTreeFactor(this->discreteKeys(), thresholded);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /* ************************************************************************ */
 | 
					  /* ************************************************************************ */
 | 
				
			||||||
}  // namespace gtsam
 | 
					}  // namespace gtsam
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -170,6 +170,18 @@ namespace gtsam {
 | 
				
			||||||
    /// Return all the discrete keys associated with this factor.
 | 
					    /// Return all the discrete keys associated with this factor.
 | 
				
			||||||
    DiscreteKeys discreteKeys() const;
 | 
					    DiscreteKeys discreteKeys() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    /**
 | 
				
			||||||
 | 
					     * @brief Prune the decision tree of discrete variables.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * Pruning will set the leaves to be "pruned" to 0 indicating a 0
 | 
				
			||||||
 | 
					     * probability.
 | 
				
			||||||
 | 
					     * A leaf is pruned if it is not in the top `maxNrLeaves` values.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * @param maxNrLeaves The maximum number of leaves to keep.
 | 
				
			||||||
 | 
					     * @return DecisionTreeFactor
 | 
				
			||||||
 | 
					     */
 | 
				
			||||||
 | 
					    DecisionTreeFactor prune(size_t maxNrLeaves) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /// @}
 | 
					    /// @}
 | 
				
			||||||
    /// @name Wrapper support
 | 
					    /// @name Wrapper support
 | 
				
			||||||
    /// @{
 | 
					    /// @{
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -106,6 +106,27 @@ TEST(DecisionTreeFactor, enumerate) {
 | 
				
			||||||
  EXPECT(actual == expected);
 | 
					  EXPECT(actual == expected);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Check pruning of the decision tree works as expected.
 | 
				
			||||||
 | 
					TEST(DecisionTreeFactor, Prune) {
 | 
				
			||||||
 | 
					  DiscreteKey A(1, 2), B(2, 2), C(3, 2);
 | 
				
			||||||
 | 
					  DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Only keep the leaves with the top 5 values.
 | 
				
			||||||
 | 
					  size_t maxNrLeaves = 5;
 | 
				
			||||||
 | 
					  auto pruned5 = f.prune(maxNrLeaves);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Pruned leaves should be 0
 | 
				
			||||||
 | 
					  DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expected, pruned5));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Check for more extreme pruning where we only keep the top 2 leaves
 | 
				
			||||||
 | 
					  maxNrLeaves = 2;
 | 
				
			||||||
 | 
					  auto pruned2 = f.prune(maxNrLeaves);
 | 
				
			||||||
 | 
					  DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expected2, pruned2));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ************************************************************************* */
 | 
					/* ************************************************************************* */
 | 
				
			||||||
TEST(DecisionTreeFactor, DotWithNames) {
 | 
					TEST(DecisionTreeFactor, DotWithNames) {
 | 
				
			||||||
  DiscreteKey A(12, 3), B(5, 2);
 | 
					  DiscreteKey A(12, 3), B(5, 2);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue