refactor DecisionTree to make a distinction between leaves and assignments

release/4.3a0
Varun Agrawal 2022-03-31 06:10:40 -04:00
parent cc7f4992b7
commit d5d5ecc3b3
2 changed files with 80 additions and 63 deletions

View File

@ -59,7 +59,7 @@ namespace gtsam {
/** constant stored in this leaf */
Y constant_;
/** The number of assignments contained within this leaf
/** The number of assignments contained within this leaf.
* Particularly useful when leaves have been pruned.
*/
size_t nrAssignments_;
@ -68,7 +68,7 @@ namespace gtsam {
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
/** return the constant */
/// Return the constant
const Y& constant() const {
// Hands back a const reference to the stored member; Y is not copied here.
return constant_;
}
@ -81,19 +81,19 @@ namespace gtsam {
return constant_ == q.constant_;
}
/// polymorphic equality: is q is a leaf, could be
/// polymorphic equality: is q a leaf and is it the same as this leaf?
bool sameLeaf(const Node& q) const override {
// Short-circuit: q.sameLeaf is only invoked once q is known to be a leaf.
// The second call swaps receiver and argument, handing *this to q.
// NOTE(review): presumably this dispatches to a sameLeaf overload declared
// outside this view — confirm against the full overload set.
return (q.isLeaf() && q.sameLeaf(*this));
}
/** equality up to tolerance */
/// equality up to tolerance
bool equals(const Node& q, const CompareFunc& compare) const override {
// A node that is not a Leaf is never considered equal to this leaf.
const Leaf* other = dynamic_cast<const Leaf*>(&q);
if (!other) return false;
// Only the stored constants are compared, using the supplied comparator;
// nrAssignments_ does not take part in the comparison.
return compare(this->constant_, other->constant_);
}
/** print */
/// print
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
@ -122,8 +122,8 @@ namespace gtsam {
/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
const Assignment<L>& assignment) const override {
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
return f;
}
@ -168,7 +168,10 @@ namespace gtsam {
std::vector<NodePtr> branches_;
private:
/** incremental allSame */
/**
* Incremental allSame.
* Records if all the branches are the same leaf.
*/
size_t allSame_;
using ChoicePtr = boost::shared_ptr<const Choice>;
@ -181,7 +184,7 @@ namespace gtsam {
#endif
}
/** If all branches of a choice node f are the same, just return a branch */
/// If all branches of a choice node f are the same, just return a branch.
static NodePtr Unique(const ChoicePtr& f) {
#ifndef DT_NO_PRUNING
if (f->allSame_) {
@ -205,15 +208,13 @@ namespace gtsam {
bool isLeaf() const override { return false; }
/** Constructor, given choice label and mandatory expected branch count */
/// Constructor, given choice label and mandatory expected branch count.
Choice(const L& label, size_t count) :
label_(label), allSame_(true) {
branches_.reserve(count);
}
/**
* Construct from applying binary op to two Choice nodes
*/
/// Construct from applying binary op to two Choice nodes.
Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) {
// Choose what to do based on label
@ -241,6 +242,7 @@ namespace gtsam {
}
}
/// Return the label of this choice node.
const L& label() const {
// Hands back a const reference to the stored label_; L is not copied here.
return label_;
}
@ -262,7 +264,7 @@ namespace gtsam {
branches_.push_back(node);
}
/** print (as a tree) */
/// print (as a tree).
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Choice(";
@ -308,7 +310,7 @@ namespace gtsam {
return (q.isLeaf() && q.sameLeaf(*this));
}
/** equality */
/// equality
bool equals(const Node& q, const CompareFunc& compare) const override {
const Choice* other = dynamic_cast<const Choice*>(&q);
if (!other) return false;
@ -321,7 +323,7 @@ namespace gtsam {
return true;
}
/** evaluate */
/// evaluate
const Y& operator()(const Assignment<L>& x) const override {
#ifndef NDEBUG
typename Assignment<L>::const_iterator it = x.find(label_);
@ -336,13 +338,13 @@ namespace gtsam {
return (*child)(x);
}
/**
* Construct from applying unary op to a Choice node
*/
/// Construct from applying unary op to a Choice node.
Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
for (const NodePtr& branch : f.branches_) {
push_back(branch->apply(op));
}
}
/**
@ -353,28 +355,28 @@ namespace gtsam {
* @param f The original choice node to apply the op on.
* @param op Function to apply on the choice node. Takes Assignment and
* value as arguments.
* @param choices The Assignment that will go to op.
* @param assignment The Assignment that will go to op.
*/
Choice(const L& label, const Choice& f, const UnaryAssignment& op,
const Assignment<L>& choices)
const Assignment<L>& assignment)
: label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
Assignment<L> choices_ = choices;
Assignment<L> assignment_ = assignment;
for (size_t i = 0; i < f.branches_.size(); i++) {
choices_[label_] = i; // Set assignment for label to i
assignment_[label_] = i; // Set assignment for label to i
const NodePtr branch = f.branches_[i];
push_back(branch->apply(op, choices_));
push_back(branch->apply(op, assignment_));
// Remove the choice so we are backtracking
auto choice_it = choices_.find(label_);
choices_.erase(choice_it);
// Remove the assignment so we are backtracking
auto assignment_it = assignment_.find(label_);
assignment_.erase(assignment_it);
}
}
/** apply unary operator */
/// apply unary operator.
NodePtr apply(const Unary& op) const override {
auto r = boost::make_shared<Choice>(label_, *this, op);
return Unique(r);
@ -382,8 +384,8 @@ namespace gtsam {
/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, choices);
const Assignment<L>& assignment) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, assignment);
return Unique(r);
}
@ -678,7 +680,14 @@ namespace gtsam {
}
/****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument.
/**
* Functor performing depth-first visit without Assignment<L> argument.
*
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
* can have <8 leaves. For example, if a tree has all assignment values as 1,
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
*/
template <typename L, typename Y>
struct Visit {
using F = std::function<void(const Y&)>;
@ -707,33 +716,36 @@ namespace gtsam {
}
/****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument.
/**
* Functor performing depth-first visit with Assignment<L> argument.
*
* NOTE: Follows the same pruning semantics as `visit`.
*/
template <typename L, typename Y>
struct VisitWith {
using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>;
using F = std::function<void(const Assignment<L>&, const Y&)>;
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object.
Assignment<L> assignment; ///< Assignment, mutating through recursion.
F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
using Leaf = typename DecisionTree<L, Y>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
return f(choices, leaf->constant());
return f(assignment, leaf->constant());
using Choice = typename DecisionTree<L, Y>::Choice;
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
if (!choice)
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
for (size_t i = 0; i < choice->nrChoices(); i++) {
choices[choice->label()] = i; // Set assignment for label to i
assignment[choice->label()] = i; // Set assignment for label to i
(*this)(choice->branches()[i]); // recurse!
// Remove the assignment for this label so we are backtracking
auto choice_it = choices.find(choice->label());
choices.erase(choice_it);
auto choice_it = assignment.find(choice->label());
assignment.erase(choice_it);
}
}
};
@ -763,12 +775,14 @@ namespace gtsam {
}
/****************************************************************************/
// labels is just done with a visit
// Get (partial) labels by performing a visit.
template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const {
std::set<L> unique;
auto f = [&](const Assignment<L>& choices, const Y&) {
for (auto&& kv : choices) unique.insert(kv.first);
auto f = [&](const Assignment<L>& assignment, const Y&) {
for (auto&& kv : assignment) {
unique.insert(kv.first);
}
};
visitWith(f);
return unique;
@ -817,8 +831,8 @@ namespace gtsam {
throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree.");
}
Assignment<L> choices;
return DecisionTree(root_->apply(op, choices));
Assignment<L> assignment;
return DecisionTree(root_->apply(op, assignment));
}
/****************************************************************************/

View File

@ -105,7 +105,7 @@ namespace gtsam {
virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const = 0;
const Assignment<L>& assignment) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
@ -153,7 +153,7 @@ namespace gtsam {
/** Create a constant */
explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
DecisionTree(const L& label, const Y& y1, const Y& y2);
/** Allow Label+Cardinality for convenience */
@ -219,9 +219,8 @@ namespace gtsam {
/// @name Standard Interface
/// @{
/** Make virtual */
virtual ~DecisionTree() {
}
/// Make virtual
virtual ~DecisionTree() {}
/// Check if tree is empty, i.e. whether root_ evaluates to false.
bool empty() const { return !root_; }
@ -234,11 +233,13 @@ namespace gtsam {
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
*
* @param f (side-effect) Function taking a value.
*
* @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
*
* Example:
* int sum = 0;
* auto visitor = [&](int y) { sum += y; };
@ -249,14 +250,16 @@ namespace gtsam {
/**
* @brief Visit all leaves in depth-first fashion.
*
* @param f side-effect taking an assignment and a value.
*
* @note Due to pruning, leaves might not exhaust choices.
*
*
* @param f (side-effect) Function taking an assignment and a value.
*
* @note Due to pruning, the number of leaves may not be the same as the
* number of assignments. E.g. if we have a tree on 2 binary variables with
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
*
* Example:
* int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
* auto visitor = [&](const Assignment<L>& assignment, int y) { sum += y; };
* tree.visitWith(visitor);
*/
template <typename Func>
@ -275,7 +278,7 @@ namespace gtsam {
*
* @note X is always passed by value.
* @note Due to pruning, the number of leaves may be smaller than the number of assignments.
*
*
* Example:
* auto add = [](const double& y, double x) { return y + x; };
* double sum = tree.fold(add, 0.0);