refactor DecisionTree to make a distinction between leaves and assignments

release/4.3a0
Varun Agrawal 2022-03-31 06:10:40 -04:00
parent cc7f4992b7
commit d5d5ecc3b3
2 changed files with 80 additions and 63 deletions

View File

@@ -59,7 +59,7 @@ namespace gtsam {
/** constant stored in this leaf */ /** constant stored in this leaf */
Y constant_; Y constant_;
/** The number of assignments contained within this leaf /** The number of assignments contained within this leaf.
* Particularly useful when leaves have been pruned. * Particularly useful when leaves have been pruned.
*/ */
size_t nrAssignments_; size_t nrAssignments_;
@@ -68,7 +68,7 @@ namespace gtsam {
Leaf(const Y& constant, size_t nrAssignments = 1) Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {} : constant_(constant), nrAssignments_(nrAssignments) {}
/** return the constant */ /// Return the constant
const Y& constant() const { const Y& constant() const {
return constant_; return constant_;
} }
@@ -81,19 +81,19 @@ namespace gtsam {
return constant_ == q.constant_; return constant_ == q.constant_;
} }
/// polymorphic equality: is q is a leaf, could be /// polymorphic equality: is q a leaf and is it the same as this leaf?
bool sameLeaf(const Node& q) const override { bool sameLeaf(const Node& q) const override {
return (q.isLeaf() && q.sameLeaf(*this)); return (q.isLeaf() && q.sameLeaf(*this));
} }
/** equality up to tolerance */ /// equality up to tolerance
bool equals(const Node& q, const CompareFunc& compare) const override { bool equals(const Node& q, const CompareFunc& compare) const override {
const Leaf* other = dynamic_cast<const Leaf*>(&q); const Leaf* other = dynamic_cast<const Leaf*>(&q);
if (!other) return false; if (!other) return false;
return compare(this->constant_, other->constant_); return compare(this->constant_, other->constant_);
} }
/** print */ /// print
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
@@ -122,8 +122,8 @@ namespace gtsam {
/// Apply unary operator with assignment /// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op, NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override { const Assignment<L>& assignment) const override {
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_)); NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
return f; return f;
} }
@@ -168,7 +168,10 @@ namespace gtsam {
std::vector<NodePtr> branches_; std::vector<NodePtr> branches_;
private: private:
/** incremental allSame */ /**
* Incremental allSame.
* Records if all the branches are the same leaf.
*/
size_t allSame_; size_t allSame_;
using ChoicePtr = boost::shared_ptr<const Choice>; using ChoicePtr = boost::shared_ptr<const Choice>;
@@ -181,7 +184,7 @@ namespace gtsam {
#endif #endif
} }
/** If all branches of a choice node f are the same, just return a branch */ /// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) { static NodePtr Unique(const ChoicePtr& f) {
#ifndef DT_NO_PRUNING #ifndef DT_NO_PRUNING
if (f->allSame_) { if (f->allSame_) {
@@ -205,15 +208,13 @@ namespace gtsam {
bool isLeaf() const override { return false; } bool isLeaf() const override { return false; }
/** Constructor, given choice label and mandatory expected branch count */ /// Constructor, given choice label and mandatory expected branch count.
Choice(const L& label, size_t count) : Choice(const L& label, size_t count) :
label_(label), allSame_(true) { label_(label), allSame_(true) {
branches_.reserve(count); branches_.reserve(count);
} }
/** /// Construct from applying binary op to two Choice nodes.
* Construct from applying binary op to two Choice nodes
*/
Choice(const Choice& f, const Choice& g, const Binary& op) : Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) { allSame_(true) {
// Choose what to do based on label // Choose what to do based on label
@@ -241,6 +242,7 @@ namespace gtsam {
} }
} }
/// Return the label of this choice node.
const L& label() const { const L& label() const {
return label_; return label_;
} }
@@ -262,7 +264,7 @@ namespace gtsam {
branches_.push_back(node); branches_.push_back(node);
} }
/** print (as a tree) */ /// print (as a tree).
void print(const std::string& s, const LabelFormatter& labelFormatter, void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override { const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice("; std::cout << s << " Choice(";
@@ -308,7 +310,7 @@ namespace gtsam {
return (q.isLeaf() && q.sameLeaf(*this)); return (q.isLeaf() && q.sameLeaf(*this));
} }
/** equality */ /// equality
bool equals(const Node& q, const CompareFunc& compare) const override { bool equals(const Node& q, const CompareFunc& compare) const override {
const Choice* other = dynamic_cast<const Choice*>(&q); const Choice* other = dynamic_cast<const Choice*>(&q);
if (!other) return false; if (!other) return false;
@@ -321,7 +323,7 @@ namespace gtsam {
return true; return true;
} }
/** evaluate */ /// evaluate
const Y& operator()(const Assignment<L>& x) const override { const Y& operator()(const Assignment<L>& x) const override {
#ifndef NDEBUG #ifndef NDEBUG
typename Assignment<L>::const_iterator it = x.find(label_); typename Assignment<L>::const_iterator it = x.find(label_);
@@ -336,13 +338,13 @@ namespace gtsam {
return (*child)(x); return (*child)(x);
} }
/** /// Construct from applying unary op to a Choice node.
* Construct from applying unary op to a Choice node
*/
Choice(const L& label, const Choice& f, const Unary& op) : Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) { label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op)); for (const NodePtr& branch : f.branches_) {
push_back(branch->apply(op));
}
} }
/** /**
@@ -353,28 +355,28 @@ namespace gtsam {
* @param f The original choice node to apply the op on. * @param f The original choice node to apply the op on.
* @param op Function to apply on the choice node. Takes Assignment and * @param op Function to apply on the choice node. Takes Assignment and
* value as arguments. * value as arguments.
* @param choices The Assignment that will go to op. * @param assignment The Assignment that will go to op.
*/ */
Choice(const L& label, const Choice& f, const UnaryAssignment& op, Choice(const L& label, const Choice& f, const UnaryAssignment& op,
const Assignment<L>& choices) const Assignment<L>& assignment)
: label_(label), allSame_(true) { : label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space branches_.reserve(f.branches_.size()); // reserve space
Assignment<L> choices_ = choices; Assignment<L> assignment_ = assignment;
for (size_t i = 0; i < f.branches_.size(); i++) { for (size_t i = 0; i < f.branches_.size(); i++) {
choices_[label_] = i; // Set assignment for label to i assignment_[label_] = i; // Set assignment for label to i
const NodePtr branch = f.branches_[i]; const NodePtr branch = f.branches_[i];
push_back(branch->apply(op, choices_)); push_back(branch->apply(op, assignment_));
// Remove the choice so we are backtracking // Remove the assignment so we are backtracking
auto choice_it = choices_.find(label_); auto assignment_it = assignment_.find(label_);
choices_.erase(choice_it); assignment_.erase(assignment_it);
} }
} }
/** apply unary operator */ /// apply unary operator.
NodePtr apply(const Unary& op) const override { NodePtr apply(const Unary& op) const override {
auto r = boost::make_shared<Choice>(label_, *this, op); auto r = boost::make_shared<Choice>(label_, *this, op);
return Unique(r); return Unique(r);
@@ -382,8 +384,8 @@ namespace gtsam {
/// Apply unary operator with assignment /// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op, NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override { const Assignment<L>& assignment) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, choices); auto r = boost::make_shared<Choice>(label_, *this, op, assignment);
return Unique(r); return Unique(r);
} }
@@ -678,7 +680,14 @@ namespace gtsam {
} }
/****************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument. /**
* Functor performing depth-first visit without Assignment<L> argument.
*
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
* can have <8 leaves. For example, if a tree has all assignment values as 1,
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
*/
template <typename L, typename Y> template <typename L, typename Y>
struct Visit { struct Visit {
using F = std::function<void(const Y&)>; using F = std::function<void(const Y&)>;
@@ -707,33 +716,36 @@ namespace gtsam {
} }
/****************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument. /**
* Functor performing depth-first visit with Assignment<L> argument.
*
* NOTE: Follows the same pruning semantics as `visit`.
*/
template <typename L, typename Y> template <typename L, typename Y>
struct VisitWith { struct VisitWith {
using Choices = Assignment<L>; using F = std::function<void(const Assignment<L>&, const Y&)>;
using F = std::function<void(const Choices&, const Y&)>;
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function. explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion. Assignment<L> assignment; ///< Assignment, mutating through recursion.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
using Leaf = typename DecisionTree<L, Y>::Leaf; using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node)) if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(choices, leaf->constant()); return f(assignment, leaf->constant());
using Choice = typename DecisionTree<L, Y>::Choice; using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node); auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice) if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) { for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i assignment[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse! (*this)(choice->branches()[i]); // recurse!
// Remove the choice so we are backtracking // Remove the choice so we are backtracking
auto choice_it = choices.find(choice->label()); auto choice_it = assignment.find(choice->label());
choices.erase(choice_it); assignment.erase(choice_it);
} }
} }
}; };
@@ -763,12 +775,14 @@ namespace gtsam {
} }
/****************************************************************************/ /****************************************************************************/
// labels is just done with a visit // Get (partial) labels by performing a visit.
template <typename L, typename Y> template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const { std::set<L> DecisionTree<L, Y>::labels() const {
std::set<L> unique; std::set<L> unique;
auto f = [&](const Assignment<L>& choices, const Y&) { auto f = [&](const Assignment<L>& assignment, const Y&) {
for (auto&& kv : choices) unique.insert(kv.first); for (auto&& kv : assignment) {
unique.insert(kv.first);
}
}; };
visitWith(f); visitWith(f);
return unique; return unique;
@@ -817,8 +831,8 @@ namespace gtsam {
throw std::runtime_error( throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree."); "DecisionTree::apply(unary op) undefined for empty tree.");
} }
Assignment<L> choices; Assignment<L> assignment;
return DecisionTree(root_->apply(op, choices)); return DecisionTree(root_->apply(op, assignment));
} }
/****************************************************************************/ /****************************************************************************/

View File

@@ -105,7 +105,7 @@ namespace gtsam {
virtual const Y& operator()(const Assignment<L>& x) const = 0; virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply(const UnaryAssignment& op, virtual Ptr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const = 0; const Assignment<L>& assignment) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
@@ -153,7 +153,7 @@ namespace gtsam {
/** Create a constant */ /** Create a constant */
explicit DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */ /// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
/** Allow Label+Cardinality for convenience */ /** Allow Label+Cardinality for convenience */
@@ -219,9 +219,8 @@ namespace gtsam {
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
/** Make virtual */ /// Make virtual
virtual ~DecisionTree() { virtual ~DecisionTree() {}
}
/// Check if tree is empty. /// Check if tree is empty.
bool empty() const { return !root_; } bool empty() const { return !root_; }
@@ -235,9 +234,11 @@ namespace gtsam {
/** /**
* @brief Visit all leaves in depth-first fashion. * @brief Visit all leaves in depth-first fashion.
* *
* @param f side-effect taking a value. * @param f (side-effect) Function taking a value.
* *
* @note Due to pruning, leaves might not exhaust choices. * @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
* *
* Example: * Example:
* int sum = 0; * int sum = 0;
@@ -250,13 +251,15 @@ namespace gtsam {
/** /**
* @brief Visit all leaves in depth-first fashion. * @brief Visit all leaves in depth-first fashion.
* *
* @param f side-effect taking an assignment and a value. * @param f (side-effect) Function taking an assignment and a value.
* *
* @note Due to pruning, leaves might not exhaust choices. * @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
* *
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; }; * auto visitor = [&](const Assignment<L>& assignment, int y) { sum += y; };
* tree.visitWith(visitor); * tree.visitWith(visitor);
*/ */
template <typename Func> template <typename Func>