diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 826e54b95..48da55ce1 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -59,7 +59,7 @@ namespace gtsam { /** constant stored in this leaf */ Y constant_; - /** The number of assignments contained within this leaf + /** The number of assignments contained within this leaf. * Particularly useful when leaves have been pruned. */ size_t nrAssignments_; @@ -68,7 +68,7 @@ namespace gtsam { Leaf(const Y& constant, size_t nrAssignments = 1) : constant_(constant), nrAssignments_(nrAssignments) {} - /** return the constant */ + /// Return the constant const Y& constant() const { return constant_; } @@ -81,19 +81,19 @@ namespace gtsam { return constant_ == q.constant_; } - /// polymorphic equality: is q is a leaf, could be + /// polymorphic equality: is q a leaf and is it the same as this leaf? bool sameLeaf(const Node& q) const override { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality up to tolerance */ + /// equality up to tolerance bool equals(const Node& q, const CompareFunc& compare) const override { const Leaf* other = dynamic_cast(&q); if (!other) return false; return compare(this->constant_, other->constant_); } - /** print */ + /// print void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; @@ -122,8 +122,8 @@ namespace gtsam { /// Apply unary operator with assignment NodePtr apply(const UnaryAssignment& op, - const Assignment& choices) const override { - NodePtr f(new Leaf(op(choices, constant_), nrAssignments_)); + const Assignment& assignment) const override { + NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_)); return f; } @@ -168,7 +168,10 @@ namespace gtsam { std::vector branches_; private: - /** incremental allSame */ + /** + * Incremental allSame. + * Records if all the branches are the same leaf. + */ size_t allSame_; using ChoicePtr = boost::shared_ptr; @@ -181,7 +184,7 @@ namespace gtsam { #endif } - /** If all branches of a choice node f are the same, just return a branch */ + /// If all branches of a choice node f are the same, just return a branch. static NodePtr Unique(const ChoicePtr& f) { #ifndef DT_NO_PRUNING if (f->allSame_) { @@ -205,15 +208,13 @@ namespace gtsam { bool isLeaf() const override { return false; } - /** Constructor, given choice label and mandatory expected branch count */ + /// Constructor, given choice label and mandatory expected branch count. Choice(const L& label, size_t count) : label_(label), allSame_(true) { branches_.reserve(count); } - /** - * Construct from applying binary op to two Choice nodes - */ + /// Construct from applying binary op to two Choice nodes. Choice(const Choice& f, const Choice& g, const Binary& op) : allSame_(true) { // Choose what to do based on label @@ -241,6 +242,7 @@ namespace gtsam { } } + /// Return the label of this choice node. const L& label() const { return label_; } @@ -262,7 +264,7 @@ namespace gtsam { branches_.push_back(node); } - /** print (as a tree) */ + /// print (as a tree). void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { std::cout << s << " Choice("; @@ -308,7 +310,7 @@ namespace gtsam { return (q.isLeaf() && q.sameLeaf(*this)); } - /** equality */ + /// equality bool equals(const Node& q, const CompareFunc& compare) const override { const Choice* other = dynamic_cast(&q); if (!other) return false; @@ -321,7 +323,7 @@ namespace gtsam { return true; } - /** evaluate */ + /// evaluate const Y& operator()(const Assignment& x) const override { #ifndef NDEBUG typename Assignment::const_iterator it = x.find(label_); @@ -336,13 +338,13 @@ namespace gtsam { return (*child)(x); } - /** - * Construct from applying unary op to a Choice node - */ + /// Construct from applying unary op to a Choice node. Choice(const L& label, const Choice& f, const Unary& op) : label_(label), allSame_(true) { branches_.reserve(f.branches_.size()); // reserve space - for (const NodePtr& branch : f.branches_) push_back(branch->apply(op)); + for (const NodePtr& branch : f.branches_) { + push_back(branch->apply(op)); + } } /** @@ -353,28 +355,28 @@ namespace gtsam { * @param f The original choice node to apply the op on. * @param op Function to apply on the choice node. Takes Assignment and * value as arguments. - * @param choices The Assignment that will go to op. + * @param assignment The Assignment that will go to op. */ Choice(const L& label, const Choice& f, const UnaryAssignment& op, - const Assignment& choices) + const Assignment& assignment) : label_(label), allSame_(true) { branches_.reserve(f.branches_.size()); // reserve space - Assignment choices_ = choices; + Assignment assignment_ = assignment; for (size_t i = 0; i < f.branches_.size(); i++) { - choices_[label_] = i; // Set assignment for label to i + assignment_[label_] = i; // Set assignment for label to i const NodePtr branch = f.branches_[i]; - push_back(branch->apply(op, choices_)); + push_back(branch->apply(op, assignment_)); - // Remove the choice so we are backtracking - auto choice_it = choices_.find(label_); - choices_.erase(choice_it); + // Remove the assignment so we are backtracking + auto assignment_it = assignment_.find(label_); + assignment_.erase(assignment_it); } } - /** apply unary operator */ + /// apply unary operator. NodePtr apply(const Unary& op) const override { auto r = boost::make_shared(label_, *this, op); return Unique(r); @@ -382,8 +384,8 @@ namespace gtsam { /// Apply unary operator with assignment NodePtr apply(const UnaryAssignment& op, - const Assignment& choices) const override { - auto r = boost::make_shared(label_, *this, op, choices); + const Assignment& assignment) const override { + auto r = boost::make_shared(label_, *this, op, assignment); return Unique(r); } @@ -678,7 +680,14 @@ namespace gtsam { } /****************************************************************************/ - // Functor performing depth-first visit without Assignment argument. + /** + * Functor performing depth-first visit without Assignment argument. + * + * NOTE: We differentiate between leaves and assignments. Concretely, a 3 + * binary variable tree will have 2^3=8 assignments, but based on pruning, it + * can have <8 leaves. For example, if a tree has all assignment values as 1, + * then pruning will cause the tree to have only 1 leaf yet 8 assignments. + */ template struct Visit { using F = std::function; @@ -707,33 +716,36 @@ namespace gtsam { } /****************************************************************************/ - // Functor performing depth-first visit with Assignment argument. + /** + * Functor performing depth-first visit with Assignment argument. + * + * NOTE: Follows the same pruning semantics as `visit`. + */ template struct VisitWith { - using Choices = Assignment; - using F = std::function; + using F = std::function&, const Y&)>; explicit VisitWith(F f) : f(f) {} ///< Construct from folding function. - Choices choices; ///< Assignment, mutating through recursion. - F f; ///< folding function object. + Assignment assignment; ///< Assignment, mutating through recursion. + F f; ///< folding function object. /// Do a depth-first visit on the tree rooted at node. void operator()(const typename DecisionTree::NodePtr& node) { using Leaf = typename DecisionTree::Leaf; if (auto leaf = boost::dynamic_pointer_cast(node)) - return f(choices, leaf->constant()); + return f(assignment, leaf->constant()); using Choice = typename DecisionTree::Choice; auto choice = boost::dynamic_pointer_cast(node); if (!choice) throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr"); for (size_t i = 0; i < choice->nrChoices(); i++) { - choices[choice->label()] = i; // Set assignment for label to i + assignment[choice->label()] = i; // Set assignment for label to i (*this)(choice->branches()[i]); // recurse! // Remove the choice so we are backtracking - auto choice_it = choices.find(choice->label()); - choices.erase(choice_it); + auto choice_it = assignment.find(choice->label()); + assignment.erase(choice_it); } } }; @@ -763,12 +775,14 @@ namespace gtsam { } /****************************************************************************/ - // labels is just done with a visit + // Get (partial) labels by performing a visit. template std::set DecisionTree::labels() const { std::set unique; - auto f = [&](const Assignment& choices, const Y&) { - for (auto&& kv : choices) unique.insert(kv.first); + auto f = [&](const Assignment& assignment, const Y&) { + for (auto&& kv : assignment) { + unique.insert(kv.first); + } }; visitWith(f); return unique; @@ -817,8 +831,8 @@ namespace gtsam { throw std::runtime_error( "DecisionTree::apply(unary op) undefined for empty tree."); } - Assignment choices; - return DecisionTree(root_->apply(op, choices)); + Assignment assignment; + return DecisionTree(root_->apply(op, assignment)); } /****************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index c0a2a7a1c..9520d43bc 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -105,7 +105,7 @@ namespace gtsam { virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; virtual Ptr apply(const UnaryAssignment& op, - const Assignment& choices) const = 0; + const Assignment& assignment) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; @@ -153,7 +153,7 @@ namespace gtsam { /** Create a constant */ explicit DecisionTree(const Y& y); - /** Create a new leaf function splitting on a variable */ + /// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label` DecisionTree(const L& label, const Y& y1, const Y& y2); /** Allow Label+Cardinality for convenience */ @@ -219,9 +219,8 @@ namespace gtsam { /// @name Standard Interface /// @{ - /** Make virtual */ - virtual ~DecisionTree() { - } + /// Make virtual + virtual ~DecisionTree() {} /// Check if tree is empty. bool empty() const { return !root_; } @@ -234,11 +233,13 @@ namespace gtsam { /** * @brief Visit all leaves in depth-first fashion. - * - * @param f side-effect taking a value. - * - * @note Due to pruning, leaves might not exhaust choices. - * + * + * @param f (side-effect) Function taking a value. + * + * @note Due to pruning, the number of leaves may not be the same as the + * number of assignments. E.g. if we have a tree on 2 binary variables with + * all values being 1, then there are 2^2=4 assignments, but only 1 leaf. + * * Example: * int sum = 0; * auto visitor = [&](int y) { sum += y; }; @@ -249,14 +250,16 @@ namespace gtsam { /** * @brief Visit all leaves in depth-first fashion. - * - * @param f side-effect taking an assignment and a value. - * - * @note Due to pruning, leaves might not exhaust choices. - * + * + * @param f (side-effect) Function taking an assignment and a value. + * + * @note Due to pruning, the number of leaves may not be the same as the + * number of assignments. E.g. if we have a tree on 2 binary variables with + * all values being 1, then there are 2^2=4 assignments, but only 1 leaf. + * * Example: * int sum = 0; - * auto visitor = [&](const Assignment& choices, int y) { sum += y; }; + * auto visitor = [&](const Assignment& assignment, int y) { sum += y; }; * tree.visitWith(visitor); */ template @@ -275,7 +278,7 @@ namespace gtsam { * * @note X is always passed by value. * @note Due to pruning, leaves might not exhaust choices. - * + * * Example: * auto add = [](const double& y, double x) { return y + x; }; * double sum = tree.fold(add, 0.0);