add new visitLeaf method that provides the leaf as the function argument
parent
e81e04acf5
commit
039ecfc3c3
|
|
@ -715,6 +715,42 @@ namespace gtsam {
|
||||||
visit(root_);
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
/**
|
||||||
|
* Functor performing depth-first visit with Leaf argument.
|
||||||
|
*
|
||||||
|
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
|
||||||
|
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
|
||||||
|
* can have <8 leaves. For example, if a tree has all assignment values as 1,
|
||||||
|
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
|
||||||
|
*/
|
||||||
|
template <typename L, typename Y>
|
||||||
|
struct VisitLeaf {
|
||||||
|
using F = std::function<void(const typename DecisionTree<L, Y>::Leaf&)>;
|
||||||
|
explicit VisitLeaf(F f) : f(f) {} ///< Construct from folding function.
|
||||||
|
F f; ///< folding function object.
|
||||||
|
|
||||||
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
|
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
|
||||||
|
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||||
|
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||||
|
return f(*leaf);
|
||||||
|
|
||||||
|
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||||
|
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||||
|
if (!choice)
|
||||||
|
throw std::invalid_argument("DecisionTree::VisitLeaf: Invalid NodePtr");
|
||||||
|
for (auto&& branch : choice->branches()) (*this)(branch); // recurse!
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename L, typename Y>
|
||||||
|
template <typename Func>
|
||||||
|
void DecisionTree<L, Y>::visitLeaf(Func f) const {
|
||||||
|
VisitLeaf<L, Y> visit(f);
|
||||||
|
visit(root_);
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
/**
|
/**
|
||||||
* Functor performing depth-first visit with Assignment<L> argument.
|
* Functor performing depth-first visit with Assignment<L> argument.
|
||||||
|
|
|
||||||
|
|
@ -248,6 +248,23 @@ namespace gtsam {
|
||||||
template <typename Func>
|
template <typename Func>
|
||||||
void visit(Func f) const;
|
void visit(Func f) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Visit all leaves in depth-first fashion.
|
||||||
|
*
|
||||||
|
* @param f (side-effect) Function taking the leaf node pointer.
|
||||||
|
*
|
||||||
|
* @note Due to pruning, the number of leaves may not be the same as the
|
||||||
|
* number of assignments. E.g. if we have a tree on 2 binary variables with
|
||||||
|
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* int sum = 0;
|
||||||
|
* auto visitor = [&](int y) { sum += y; };
|
||||||
|
* tree.visitWith(visitor);
|
||||||
|
*/
|
||||||
|
template <typename Func>
|
||||||
|
void visitLeaf(Func f) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Visit all leaves in depth-first fashion.
|
* @brief Visit all leaves in depth-first fashion.
|
||||||
*
|
*
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue