refactor DecisionTree to make a distinction between leaves and assignments
parent
cc7f4992b7
commit
d5d5ecc3b3
|
|
@ -59,7 +59,7 @@ namespace gtsam {
|
|||
/** constant stored in this leaf */
|
||||
Y constant_;
|
||||
|
||||
/** The number of assignments contained within this leaf
|
||||
/** The number of assignments contained within this leaf.
|
||||
* Particularly useful when leaves have been pruned.
|
||||
*/
|
||||
size_t nrAssignments_;
|
||||
|
|
@ -68,7 +68,7 @@ namespace gtsam {
|
|||
Leaf(const Y& constant, size_t nrAssignments = 1)
|
||||
: constant_(constant), nrAssignments_(nrAssignments) {}
|
||||
|
||||
/** return the constant */
|
||||
/// Return the constant
|
||||
const Y& constant() const {
|
||||
return constant_;
|
||||
}
|
||||
|
|
@ -81,19 +81,19 @@ namespace gtsam {
|
|||
return constant_ == q.constant_;
|
||||
}
|
||||
|
||||
/// polymorphic equality: is q is a leaf, could be
|
||||
/// polymorphic equality: is q a leaf and is it the same as this leaf?
|
||||
bool sameLeaf(const Node& q) const override {
|
||||
return (q.isLeaf() && q.sameLeaf(*this));
|
||||
}
|
||||
|
||||
/** equality up to tolerance */
|
||||
/// equality up to tolerance
|
||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||
const Leaf* other = dynamic_cast<const Leaf*>(&q);
|
||||
if (!other) return false;
|
||||
return compare(this->constant_, other->constant_);
|
||||
}
|
||||
|
||||
/** print */
|
||||
/// print
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const override {
|
||||
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
|
||||
|
|
@ -122,8 +122,8 @@ namespace gtsam {
|
|||
|
||||
/// Apply unary operator with assignment
|
||||
NodePtr apply(const UnaryAssignment& op,
|
||||
const Assignment<L>& choices) const override {
|
||||
NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
|
||||
const Assignment<L>& assignment) const override {
|
||||
NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
|
||||
return f;
|
||||
}
|
||||
|
||||
|
|
@ -168,7 +168,10 @@ namespace gtsam {
|
|||
std::vector<NodePtr> branches_;
|
||||
|
||||
private:
|
||||
/** incremental allSame */
|
||||
/**
|
||||
* Incremental allSame.
|
||||
* Records if all the branches are the same leaf.
|
||||
*/
|
||||
size_t allSame_;
|
||||
|
||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||
|
|
@ -181,7 +184,7 @@ namespace gtsam {
|
|||
#endif
|
||||
}
|
||||
|
||||
/** If all branches of a choice node f are the same, just return a branch */
|
||||
/// If all branches of a choice node f are the same, just return a branch.
|
||||
static NodePtr Unique(const ChoicePtr& f) {
|
||||
#ifndef DT_NO_PRUNING
|
||||
if (f->allSame_) {
|
||||
|
|
@ -205,15 +208,13 @@ namespace gtsam {
|
|||
|
||||
bool isLeaf() const override { return false; }
|
||||
|
||||
/** Constructor, given choice label and mandatory expected branch count */
|
||||
/// Constructor, given choice label and mandatory expected branch count.
|
||||
Choice(const L& label, size_t count) :
|
||||
label_(label), allSame_(true) {
|
||||
branches_.reserve(count);
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct from applying binary op to two Choice nodes
|
||||
*/
|
||||
/// Construct from applying binary op to two Choice nodes.
|
||||
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
||||
allSame_(true) {
|
||||
// Choose what to do based on label
|
||||
|
|
@ -241,6 +242,7 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/// Return the label of this choice node.
|
||||
const L& label() const {
|
||||
return label_;
|
||||
}
|
||||
|
|
@ -262,7 +264,7 @@ namespace gtsam {
|
|||
branches_.push_back(node);
|
||||
}
|
||||
|
||||
/** print (as a tree) */
|
||||
/// print (as a tree).
|
||||
void print(const std::string& s, const LabelFormatter& labelFormatter,
|
||||
const ValueFormatter& valueFormatter) const override {
|
||||
std::cout << s << " Choice(";
|
||||
|
|
@ -308,7 +310,7 @@ namespace gtsam {
|
|||
return (q.isLeaf() && q.sameLeaf(*this));
|
||||
}
|
||||
|
||||
/** equality */
|
||||
/// equality
|
||||
bool equals(const Node& q, const CompareFunc& compare) const override {
|
||||
const Choice* other = dynamic_cast<const Choice*>(&q);
|
||||
if (!other) return false;
|
||||
|
|
@ -321,7 +323,7 @@ namespace gtsam {
|
|||
return true;
|
||||
}
|
||||
|
||||
/** evaluate */
|
||||
/// evaluate
|
||||
const Y& operator()(const Assignment<L>& x) const override {
|
||||
#ifndef NDEBUG
|
||||
typename Assignment<L>::const_iterator it = x.find(label_);
|
||||
|
|
@ -336,13 +338,13 @@ namespace gtsam {
|
|||
return (*child)(x);
|
||||
}
|
||||
|
||||
/**
|
||||
* Construct from applying unary op to a Choice node
|
||||
*/
|
||||
/// Construct from applying unary op to a Choice node.
|
||||
Choice(const L& label, const Choice& f, const Unary& op) :
|
||||
label_(label), allSame_(true) {
|
||||
branches_.reserve(f.branches_.size()); // reserve space
|
||||
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
||||
for (const NodePtr& branch : f.branches_) {
|
||||
push_back(branch->apply(op));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -353,28 +355,28 @@ namespace gtsam {
|
|||
* @param f The original choice node to apply the op on.
|
||||
* @param op Function to apply on the choice node. Takes Assignment and
|
||||
* value as arguments.
|
||||
* @param choices The Assignment that will go to op.
|
||||
* @param assignment The Assignment that will go to op.
|
||||
*/
|
||||
Choice(const L& label, const Choice& f, const UnaryAssignment& op,
|
||||
const Assignment<L>& choices)
|
||||
const Assignment<L>& assignment)
|
||||
: label_(label), allSame_(true) {
|
||||
branches_.reserve(f.branches_.size()); // reserve space
|
||||
|
||||
Assignment<L> choices_ = choices;
|
||||
Assignment<L> assignment_ = assignment;
|
||||
|
||||
for (size_t i = 0; i < f.branches_.size(); i++) {
|
||||
choices_[label_] = i; // Set assignment for label to i
|
||||
assignment_[label_] = i; // Set assignment for label to i
|
||||
|
||||
const NodePtr branch = f.branches_[i];
|
||||
push_back(branch->apply(op, choices_));
|
||||
push_back(branch->apply(op, assignment_));
|
||||
|
||||
// Remove the choice so we are backtracking
|
||||
auto choice_it = choices_.find(label_);
|
||||
choices_.erase(choice_it);
|
||||
// Remove the assignment so we are backtracking
|
||||
auto assignment_it = assignment_.find(label_);
|
||||
assignment_.erase(assignment_it);
|
||||
}
|
||||
}
|
||||
|
||||
/** apply unary operator */
|
||||
/// apply unary operator.
|
||||
NodePtr apply(const Unary& op) const override {
|
||||
auto r = boost::make_shared<Choice>(label_, *this, op);
|
||||
return Unique(r);
|
||||
|
|
@ -382,8 +384,8 @@ namespace gtsam {
|
|||
|
||||
/// Apply unary operator with assignment
|
||||
NodePtr apply(const UnaryAssignment& op,
|
||||
const Assignment<L>& choices) const override {
|
||||
auto r = boost::make_shared<Choice>(label_, *this, op, choices);
|
||||
const Assignment<L>& assignment) const override {
|
||||
auto r = boost::make_shared<Choice>(label_, *this, op, assignment);
|
||||
return Unique(r);
|
||||
}
|
||||
|
||||
|
|
@ -678,7 +680,14 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// Functor performing depth-first visit without Assignment<L> argument.
|
||||
/**
|
||||
* Functor performing depth-first visit without Assignment<L> argument.
|
||||
*
|
||||
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
|
||||
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
|
||||
* can have <8 leaves. For example, if a tree has all assignment values as 1,
|
||||
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
|
||||
*/
|
||||
template <typename L, typename Y>
|
||||
struct Visit {
|
||||
using F = std::function<void(const Y&)>;
|
||||
|
|
@ -707,33 +716,36 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// Functor performing depth-first visit with Assignment<L> argument.
|
||||
/**
|
||||
* Functor performing depth-first visit with Assignment<L> argument.
|
||||
*
|
||||
* NOTE: Follows the same pruning semantics as `visit`.
|
||||
*/
|
||||
template <typename L, typename Y>
|
||||
struct VisitWith {
|
||||
using Choices = Assignment<L>;
|
||||
using F = std::function<void(const Choices&, const Y&)>;
|
||||
using F = std::function<void(const Assignment<L>&, const Y&)>;
|
||||
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||
Choices choices; ///< Assignment, mutating through recursion.
|
||||
F f; ///< folding function object.
|
||||
Assignment<L> assignment; ///< Assignment, mutating through recursion.
|
||||
F f; ///< folding function object.
|
||||
|
||||
/// Do a depth-first visit on the tree rooted at node.
|
||||
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
|
||||
using Leaf = typename DecisionTree<L, Y>::Leaf;
|
||||
if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
|
||||
return f(choices, leaf->constant());
|
||||
return f(assignment, leaf->constant());
|
||||
|
||||
using Choice = typename DecisionTree<L, Y>::Choice;
|
||||
auto choice = boost::dynamic_pointer_cast<const Choice>(node);
|
||||
if (!choice)
|
||||
throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
|
||||
for (size_t i = 0; i < choice->nrChoices(); i++) {
|
||||
choices[choice->label()] = i; // Set assignment for label to i
|
||||
assignment[choice->label()] = i; // Set assignment for label to i
|
||||
|
||||
(*this)(choice->branches()[i]); // recurse!
|
||||
|
||||
// Remove the choice so we are backtracking
|
||||
auto choice_it = choices.find(choice->label());
|
||||
choices.erase(choice_it);
|
||||
auto choice_it = assignment.find(choice->label());
|
||||
assignment.erase(choice_it);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
@ -763,12 +775,14 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
/****************************************************************************/
|
||||
// labels is just done with a visit
|
||||
// Get (partial) labels by performing a visit.
|
||||
template <typename L, typename Y>
|
||||
std::set<L> DecisionTree<L, Y>::labels() const {
|
||||
std::set<L> unique;
|
||||
auto f = [&](const Assignment<L>& choices, const Y&) {
|
||||
for (auto&& kv : choices) unique.insert(kv.first);
|
||||
auto f = [&](const Assignment<L>& assignment, const Y&) {
|
||||
for (auto&& kv : assignment) {
|
||||
unique.insert(kv.first);
|
||||
}
|
||||
};
|
||||
visitWith(f);
|
||||
return unique;
|
||||
|
|
@ -817,8 +831,8 @@ namespace gtsam {
|
|||
throw std::runtime_error(
|
||||
"DecisionTree::apply(unary op) undefined for empty tree.");
|
||||
}
|
||||
Assignment<L> choices;
|
||||
return DecisionTree(root_->apply(op, choices));
|
||||
Assignment<L> assignment;
|
||||
return DecisionTree(root_->apply(op, assignment));
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ namespace gtsam {
|
|||
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
||||
virtual Ptr apply(const Unary& op) const = 0;
|
||||
virtual Ptr apply(const UnaryAssignment& op,
|
||||
const Assignment<L>& choices) const = 0;
|
||||
const Assignment<L>& assignment) const = 0;
|
||||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
||||
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
|
||||
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
||||
|
|
@ -153,7 +153,7 @@ namespace gtsam {
|
|||
/** Create a constant */
|
||||
explicit DecisionTree(const Y& y);
|
||||
|
||||
/** Create a new leaf function splitting on a variable */
|
||||
/// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
|
||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||
|
||||
/** Allow Label+Cardinality for convenience */
|
||||
|
|
@ -219,9 +219,8 @@ namespace gtsam {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Make virtual */
|
||||
virtual ~DecisionTree() {
|
||||
}
|
||||
/// Make virtual
|
||||
virtual ~DecisionTree() {}
|
||||
|
||||
/// Check if tree is empty.
|
||||
bool empty() const { return !root_; }
|
||||
|
|
@ -234,11 +233,13 @@ namespace gtsam {
|
|||
|
||||
/**
|
||||
* @brief Visit all leaves in depth-first fashion.
|
||||
*
|
||||
* @param f side-effect taking a value.
|
||||
*
|
||||
* @note Due to pruning, leaves might not exhaust choices.
|
||||
*
|
||||
*
|
||||
* @param f (side-effect) Function taking a value.
|
||||
*
|
||||
* @note Due to pruning, the number of leaves may not be the same as the
|
||||
* number of assignments. E.g. if we have a tree on 2 binary variables with
|
||||
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
|
||||
*
|
||||
* Example:
|
||||
* int sum = 0;
|
||||
* auto visitor = [&](int y) { sum += y; };
|
||||
|
|
@ -249,14 +250,16 @@ namespace gtsam {
|
|||
|
||||
/**
|
||||
* @brief Visit all leaves in depth-first fashion.
|
||||
*
|
||||
* @param f side-effect taking an assignment and a value.
|
||||
*
|
||||
* @note Due to pruning, leaves might not exhaust choices.
|
||||
*
|
||||
*
|
||||
* @param f (side-effect) Function taking an assignment and a value.
|
||||
*
|
||||
* @note Due to pruning, the number of leaves may not be the same as the
|
||||
* number of assignments. E.g. if we have a tree on 2 binary variables with
|
||||
* all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
|
||||
*
|
||||
* Example:
|
||||
* int sum = 0;
|
||||
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
||||
* auto visitor = [&](const Assignment<L>& assignment, int y) { sum += y; };
|
||||
* tree.visitWith(visitor);
|
||||
*/
|
||||
template <typename Func>
|
||||
|
|
@ -275,7 +278,7 @@ namespace gtsam {
|
|||
*
|
||||
* @note X is always passed by value.
|
||||
* @note Due to pruning, leaves might not exhaust choices.
|
||||
*
|
||||
*
|
||||
* Example:
|
||||
* auto add = [](const double& y, double x) { return y + x; };
|
||||
* double sum = tree.fold(add, 0.0);
|
||||
|
|
|
|||
Loading…
Reference in New Issue