From 039ecfc3c3bf67232ffe4da524e02ba65018fe4c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Thu, 31 Mar 2022 10:03:23 -0400 Subject: [PATCH] add new visitLeaf method that provides the leaf as the function argument --- gtsam/discrete/DecisionTree-inl.h | 36 +++++++++++++++++++++++++++++++ gtsam/discrete/DecisionTree.h | 17 +++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index ed345461c..b3ad66721 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -715,6 +715,42 @@ namespace gtsam { visit(root_); } + /****************************************************************************/ + /** + * Functor performing depth-first visit with Leaf argument. + * + * NOTE: We differentiate between leaves and assignments. Concretely, a 3 + * binary variable tree will have 2^3=8 assignments, but based on pruning, it + * can have <8 leaves. For example, if a tree has all assignment values as 1, + * then pruning will cause the tree to have only 1 leaf yet 8 assignments. + */ + template + struct VisitLeaf { + using F = std::function::Leaf&)>; + explicit VisitLeaf(F f) : f(f) {} ///< Construct from folding function. + F f; ///< folding function object. + + /// Do a depth-first visit on the tree rooted at node. + void operator()(const typename DecisionTree::NodePtr& node) const { + using Leaf = typename DecisionTree::Leaf; + if (auto leaf = boost::dynamic_pointer_cast(node)) + return f(*leaf); + + using Choice = typename DecisionTree::Choice; + auto choice = boost::dynamic_pointer_cast(node); + if (!choice) + throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr"); + for (auto&& branch : choice->branches()) (*this)(branch); // recurse! + } + }; + + template + template + void DecisionTree::visitLeaf(Func f) const { + VisitLeaf visit(f); + visit(root_); + } + /****************************************************************************/ /** * Functor performing depth-first visit with Assignment argument. diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 9520d43bc..1f45d320b 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -248,6 +248,23 @@ namespace gtsam { template void visit(Func f) const; + /** + * @brief Visit all leaves in depth-first fashion. + * + * @param f (side-effect) Function taking the leaf node pointer. + * + * @note Due to pruning, the number of leaves may not be the same as the + * number of assignments. E.g. if we have a tree on 2 binary variables with + * all values being 1, then there are 2^2=4 assignments, but only 1 leaf. + * + * Example: + * int sum = 0; + * auto visitor = [&](int y) { sum += y; }; + * tree.visitWith(visitor); + */ + template + void visitLeaf(Func f) const; + /** * @brief Visit all leaves in depth-first fashion. *