refactor DecisionTree to make a distinction between leaves and assignments
							parent
							
								
									cc7f4992b7
								
							
						
					
					
						commit
						d5d5ecc3b3
					
				| 
						 | 
				
			
			@ -59,7 +59,7 @@ namespace gtsam {
 | 
			
		|||
    /** constant stored in this leaf */
 | 
			
		||||
    Y constant_;
 | 
			
		||||
 | 
			
		||||
    /** The number of assignments contained within this leaf
 | 
			
		||||
    /** The number of assignments contained within this leaf.
 | 
			
		||||
     * Particularly useful when leaves have been pruned.
 | 
			
		||||
     */
 | 
			
		||||
    size_t nrAssignments_;
 | 
			
		||||
| 
						 | 
				
			
			@ -68,7 +68,7 @@ namespace gtsam {
 | 
			
		|||
    Leaf(const Y& constant, size_t nrAssignments = 1)
 | 
			
		||||
        : constant_(constant), nrAssignments_(nrAssignments) {}
 | 
			
		||||
 | 
			
		||||
    /** return the constant */
 | 
			
		||||
    /// Return the constant
 | 
			
		||||
    const Y& constant() const {
 | 
			
		||||
      return constant_;
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			@ -81,19 +81,19 @@ namespace gtsam {
 | 
			
		|||
      return constant_ == q.constant_;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// polymorphic equality: is q is a leaf, could be
 | 
			
		||||
    /// polymorphic equality: is q a leaf and is it the same as this leaf?
 | 
			
		||||
    bool sameLeaf(const Node& q) const override {
 | 
			
		||||
      return (q.isLeaf() && q.sameLeaf(*this));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /** equality up to tolerance */
 | 
			
		||||
    /// equality up to tolerance
 | 
			
		||||
    bool equals(const Node& q, const CompareFunc& compare) const override {
 | 
			
		||||
      const Leaf* other = dynamic_cast<const Leaf*>(&q);
 | 
			
		||||
      if (!other) return false;
 | 
			
		||||
      return compare(this->constant_, other->constant_);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /** print */
 | 
			
		||||
    /// print
 | 
			
		||||
    void print(const std::string& s, const LabelFormatter& labelFormatter,
 | 
			
		||||
               const ValueFormatter& valueFormatter) const override {
 | 
			
		||||
      std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
 | 
			
		||||
| 
						 | 
				
			
			@ -122,8 +122,8 @@ namespace gtsam {
 | 
			
		|||
 | 
			
		||||
    /// Apply unary operator with assignment
 | 
			
		||||
    NodePtr apply(const UnaryAssignment& op,
 | 
			
		||||
                  const Assignment<L>& choices) const override {
 | 
			
		||||
      NodePtr f(new Leaf(op(choices, constant_), nrAssignments_));
 | 
			
		||||
                  const Assignment<L>& assignment) const override {
 | 
			
		||||
      NodePtr f(new Leaf(op(assignment, constant_), nrAssignments_));
 | 
			
		||||
      return f;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -168,7 +168,10 @@ namespace gtsam {
 | 
			
		|||
    std::vector<NodePtr> branches_;
 | 
			
		||||
 | 
			
		||||
   private:
 | 
			
		||||
    /** incremental allSame */
 | 
			
		||||
    /**
 | 
			
		||||
     * Incremental allSame.
 | 
			
		||||
     * Records if all the branches are the same leaf.
 | 
			
		||||
     */
 | 
			
		||||
    size_t allSame_;
 | 
			
		||||
 | 
			
		||||
    using ChoicePtr = boost::shared_ptr<const Choice>;
 | 
			
		||||
| 
						 | 
				
			
			@ -181,7 +184,7 @@ namespace gtsam {
 | 
			
		|||
#endif
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /** If all branches of a choice node f are the same, just return a branch */
 | 
			
		||||
    /// If all branches of a choice node f are the same, just return a branch.
 | 
			
		||||
    static NodePtr Unique(const ChoicePtr& f) {
 | 
			
		||||
#ifndef DT_NO_PRUNING
 | 
			
		||||
      if (f->allSame_) {
 | 
			
		||||
| 
						 | 
				
			
			@ -205,15 +208,13 @@ namespace gtsam {
 | 
			
		|||
 | 
			
		||||
    bool isLeaf() const override { return false; }
 | 
			
		||||
 | 
			
		||||
    /** Constructor, given choice label and mandatory expected branch count */
 | 
			
		||||
    /// Constructor, given choice label and mandatory expected branch count.
 | 
			
		||||
    Choice(const L& label, size_t count) :
 | 
			
		||||
      label_(label), allSame_(true) {
 | 
			
		||||
      branches_.reserve(count);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Construct from applying binary op to two Choice nodes
 | 
			
		||||
     */
 | 
			
		||||
    /// Construct from applying binary op to two Choice nodes.
 | 
			
		||||
    Choice(const Choice& f, const Choice& g, const Binary& op) :
 | 
			
		||||
      allSame_(true) {
 | 
			
		||||
      // Choose what to do based on label
 | 
			
		||||
| 
						 | 
				
			
			@ -241,6 +242,7 @@ namespace gtsam {
 | 
			
		|||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Return the label of this choice node.
 | 
			
		||||
    const L& label() const {
 | 
			
		||||
      return label_;
 | 
			
		||||
    }
 | 
			
		||||
| 
						 | 
				
			
			@ -262,7 +264,7 @@ namespace gtsam {
 | 
			
		|||
      branches_.push_back(node);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /** print (as a tree) */
 | 
			
		||||
    /// print (as a tree).
 | 
			
		||||
    void print(const std::string& s, const LabelFormatter& labelFormatter,
 | 
			
		||||
               const ValueFormatter& valueFormatter) const override {
 | 
			
		||||
      std::cout << s << " Choice(";
 | 
			
		||||
| 
						 | 
				
			
			@ -308,7 +310,7 @@ namespace gtsam {
 | 
			
		|||
      return (q.isLeaf() && q.sameLeaf(*this));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /** equality */
 | 
			
		||||
    /// equality
 | 
			
		||||
    bool equals(const Node& q, const CompareFunc& compare) const override {
 | 
			
		||||
      const Choice* other = dynamic_cast<const Choice*>(&q);
 | 
			
		||||
      if (!other) return false;
 | 
			
		||||
| 
						 | 
				
			
			@ -321,7 +323,7 @@ namespace gtsam {
 | 
			
		|||
      return true;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /** evaluate */
 | 
			
		||||
    /// evaluate
 | 
			
		||||
    const Y& operator()(const Assignment<L>& x) const override {
 | 
			
		||||
#ifndef NDEBUG
 | 
			
		||||
      typename Assignment<L>::const_iterator it = x.find(label_);
 | 
			
		||||
| 
						 | 
				
			
			@ -336,13 +338,13 @@ namespace gtsam {
 | 
			
		|||
      return (*child)(x);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * Construct from applying unary op to a Choice node
 | 
			
		||||
     */
 | 
			
		||||
    /// Construct from applying unary op to a Choice node.
 | 
			
		||||
    Choice(const L& label, const Choice& f, const Unary& op) :
 | 
			
		||||
      label_(label), allSame_(true) {
 | 
			
		||||
      branches_.reserve(f.branches_.size());  // reserve space
 | 
			
		||||
      for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
 | 
			
		||||
      for (const NodePtr& branch : f.branches_) {
 | 
			
		||||
        push_back(branch->apply(op));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
| 
						 | 
				
			
			@ -353,28 +355,28 @@ namespace gtsam {
 | 
			
		|||
     * @param f The original choice node to apply the op on.
 | 
			
		||||
     * @param op Function to apply on the choice node. Takes Assignment and
 | 
			
		||||
     * value as arguments.
 | 
			
		||||
     * @param choices The Assignment that will go to op.
 | 
			
		||||
     * @param assignment The Assignment that will go to op.
 | 
			
		||||
     */
 | 
			
		||||
    Choice(const L& label, const Choice& f, const UnaryAssignment& op,
 | 
			
		||||
           const Assignment<L>& choices)
 | 
			
		||||
           const Assignment<L>& assignment)
 | 
			
		||||
        : label_(label), allSame_(true) {
 | 
			
		||||
      branches_.reserve(f.branches_.size());  // reserve space
 | 
			
		||||
 | 
			
		||||
      Assignment<L> choices_ = choices;
 | 
			
		||||
      Assignment<L> assignment_ = assignment;
 | 
			
		||||
 | 
			
		||||
      for (size_t i = 0; i < f.branches_.size(); i++) {
 | 
			
		||||
        choices_[label_] = i;  // Set assignment for label to i
 | 
			
		||||
        assignment_[label_] = i;  // Set assignment for label to i
 | 
			
		||||
 | 
			
		||||
        const NodePtr branch = f.branches_[i];
 | 
			
		||||
        push_back(branch->apply(op, choices_));
 | 
			
		||||
        push_back(branch->apply(op, assignment_));
 | 
			
		||||
 | 
			
		||||
        // Remove the choice so we are backtracking
 | 
			
		||||
        auto choice_it = choices_.find(label_);
 | 
			
		||||
        choices_.erase(choice_it);
 | 
			
		||||
        // Remove the assignment so we are backtracking
 | 
			
		||||
        auto assignment_it = assignment_.find(label_);
 | 
			
		||||
        assignment_.erase(assignment_it);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /** apply unary operator */
 | 
			
		||||
    /// apply unary operator.
 | 
			
		||||
    NodePtr apply(const Unary& op) const override {
 | 
			
		||||
      auto r = boost::make_shared<Choice>(label_, *this, op);
 | 
			
		||||
      return Unique(r);
 | 
			
		||||
| 
						 | 
				
			
			@ -382,8 +384,8 @@ namespace gtsam {
 | 
			
		|||
 | 
			
		||||
    /// Apply unary operator with assignment
 | 
			
		||||
    NodePtr apply(const UnaryAssignment& op,
 | 
			
		||||
                  const Assignment<L>& choices) const override {
 | 
			
		||||
      auto r = boost::make_shared<Choice>(label_, *this, op, choices);
 | 
			
		||||
                  const Assignment<L>& assignment) const override {
 | 
			
		||||
      auto r = boost::make_shared<Choice>(label_, *this, op, assignment);
 | 
			
		||||
      return Unique(r);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -678,7 +680,14 @@ namespace gtsam {
 | 
			
		|||
  }
 | 
			
		||||
 | 
			
		||||
  /****************************************************************************/
 | 
			
		||||
  // Functor performing depth-first visit without Assignment<L> argument.
 | 
			
		||||
  /**
 | 
			
		||||
   * Functor performing depth-first visit without Assignment<L> argument.
 | 
			
		||||
   *
 | 
			
		||||
   * NOTE: We differentiate between leaves and assignments. Concretely, a 3
 | 
			
		||||
   * binary variable tree will have 2^3=8 assignments, but based on pruning, it
 | 
			
		||||
   * can have <8 leaves. For example, if a tree has all assignment values as 1,
 | 
			
		||||
   * then pruning will cause the tree to have only 1 leaf yet 8 assignments.
 | 
			
		||||
   */
 | 
			
		||||
  template <typename L, typename Y>
 | 
			
		||||
  struct Visit {
 | 
			
		||||
    using F = std::function<void(const Y&)>;
 | 
			
		||||
| 
						 | 
				
			
			@ -707,33 +716,36 @@ namespace gtsam {
 | 
			
		|||
  }
 | 
			
		||||
 | 
			
		||||
  /****************************************************************************/
 | 
			
		||||
  // Functor performing depth-first visit with Assignment<L> argument.
 | 
			
		||||
  /**
 | 
			
		||||
   * Functor performing depth-first visit with Assignment<L> argument.
 | 
			
		||||
   *
 | 
			
		||||
   * NOTE: Follows the same pruning semantics as `visit`.
 | 
			
		||||
   */
 | 
			
		||||
  template <typename L, typename Y>
 | 
			
		||||
  struct VisitWith {
 | 
			
		||||
    using Choices = Assignment<L>;
 | 
			
		||||
    using F = std::function<void(const Choices&, const Y&)>;
 | 
			
		||||
    using F = std::function<void(const Assignment<L>&, const Y&)>;
 | 
			
		||||
    explicit VisitWith(F f) : f(f) {}  ///< Construct from folding function.
 | 
			
		||||
    Choices choices;  ///< Assignment, mutating through recursion.
 | 
			
		||||
    Assignment<L> assignment;  ///< Assignment, mutating through recursion.
 | 
			
		||||
    F f;                       ///< folding function object.
 | 
			
		||||
 | 
			
		||||
    /// Do a depth-first visit on the tree rooted at node.
 | 
			
		||||
    void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
 | 
			
		||||
      using Leaf = typename DecisionTree<L, Y>::Leaf;
 | 
			
		||||
      if (auto leaf = boost::dynamic_pointer_cast<const Leaf>(node))
 | 
			
		||||
        return f(choices, leaf->constant());
 | 
			
		||||
        return f(assignment, leaf->constant());
 | 
			
		||||
 | 
			
		||||
      using Choice = typename DecisionTree<L, Y>::Choice;
 | 
			
		||||
      auto choice = boost::dynamic_pointer_cast<const Choice>(node);
 | 
			
		||||
      if (!choice)
 | 
			
		||||
        throw std::invalid_argument("DecisionTree::VisitWith: Invalid NodePtr");
 | 
			
		||||
      for (size_t i = 0; i < choice->nrChoices(); i++) {
 | 
			
		||||
        choices[choice->label()] = i;  // Set assignment for label to i
 | 
			
		||||
        assignment[choice->label()] = i;  // Set assignment for label to i
 | 
			
		||||
 | 
			
		||||
        (*this)(choice->branches()[i]);  // recurse!
 | 
			
		||||
 | 
			
		||||
        // Remove the choice so we are backtracking
 | 
			
		||||
        auto choice_it = choices.find(choice->label());
 | 
			
		||||
        choices.erase(choice_it);
 | 
			
		||||
        auto choice_it = assignment.find(choice->label());
 | 
			
		||||
        assignment.erase(choice_it);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
| 
						 | 
				
			
			@ -763,12 +775,14 @@ namespace gtsam {
 | 
			
		|||
  }
 | 
			
		||||
 | 
			
		||||
  /****************************************************************************/
 | 
			
		||||
  // labels is just done with a visit
 | 
			
		||||
  // Get (partial) labels by performing a visit.
 | 
			
		||||
  template <typename L, typename Y>
 | 
			
		||||
  std::set<L> DecisionTree<L, Y>::labels() const {
 | 
			
		||||
    std::set<L> unique;
 | 
			
		||||
    auto f = [&](const Assignment<L>& choices, const Y&) {
 | 
			
		||||
      for (auto&& kv : choices) unique.insert(kv.first);
 | 
			
		||||
    auto f = [&](const Assignment<L>& assignment, const Y&) {
 | 
			
		||||
      for (auto&& kv : assignment) {
 | 
			
		||||
        unique.insert(kv.first);
 | 
			
		||||
      }
 | 
			
		||||
    };
 | 
			
		||||
    visitWith(f);
 | 
			
		||||
    return unique;
 | 
			
		||||
| 
						 | 
				
			
			@ -817,8 +831,8 @@ namespace gtsam {
 | 
			
		|||
      throw std::runtime_error(
 | 
			
		||||
          "DecisionTree::apply(unary op) undefined for empty tree.");
 | 
			
		||||
    }
 | 
			
		||||
    Assignment<L> choices;
 | 
			
		||||
    return DecisionTree(root_->apply(op, choices));
 | 
			
		||||
    Assignment<L> assignment;
 | 
			
		||||
    return DecisionTree(root_->apply(op, assignment));
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /****************************************************************************/
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -105,7 +105,7 @@ namespace gtsam {
 | 
			
		|||
      virtual const Y& operator()(const Assignment<L>& x) const = 0;
 | 
			
		||||
      virtual Ptr apply(const Unary& op) const = 0;
 | 
			
		||||
      virtual Ptr apply(const UnaryAssignment& op,
 | 
			
		||||
                        const Assignment<L>& choices) const = 0;
 | 
			
		||||
                        const Assignment<L>& assignment) const = 0;
 | 
			
		||||
      virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
 | 
			
		||||
      virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
 | 
			
		||||
      virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
 | 
			
		||||
| 
						 | 
				
			
			@ -153,7 +153,7 @@ namespace gtsam {
 | 
			
		|||
    /** Create a constant */
 | 
			
		||||
    explicit DecisionTree(const Y& y);
 | 
			
		||||
 | 
			
		||||
    /** Create a new leaf function splitting on a variable */
 | 
			
		||||
    /// Create tree with 2 assignments `y1`, `y2`, splitting on variable `label`
 | 
			
		||||
    DecisionTree(const L& label, const Y& y1, const Y& y2);
 | 
			
		||||
 | 
			
		||||
    /** Allow Label+Cardinality for convenience */
 | 
			
		||||
| 
						 | 
				
			
			@ -219,9 +219,8 @@ namespace gtsam {
 | 
			
		|||
    /// @name Standard Interface
 | 
			
		||||
    /// @{
 | 
			
		||||
 | 
			
		||||
    /** Make virtual */
 | 
			
		||||
    virtual ~DecisionTree() {
 | 
			
		||||
    }
 | 
			
		||||
    /// Make virtual
 | 
			
		||||
    virtual ~DecisionTree() {}
 | 
			
		||||
 | 
			
		||||
    /// Check if tree is empty.
 | 
			
		||||
    bool empty() const { return !root_; }
 | 
			
		||||
| 
						 | 
				
			
			@ -235,9 +234,11 @@ namespace gtsam {
 | 
			
		|||
    /**
 | 
			
		||||
     * @brief Visit all leaves in depth-first fashion.
 | 
			
		||||
     *
 | 
			
		||||
     * @param f side-effect taking a value.
 | 
			
		||||
     * @param f (side-effect) Function taking a value.
 | 
			
		||||
     *
 | 
			
		||||
     * @note Due to pruning, leaves might not exhaust choices.
 | 
			
		||||
     * @note Due to pruning, the number of leaves may not be the same as the
 | 
			
		||||
     * number of assignments. E.g. if we have a tree on 2 binary variables with
 | 
			
		||||
     * all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
 | 
			
		||||
     *
 | 
			
		||||
     * Example:
 | 
			
		||||
     *   int sum = 0;
 | 
			
		||||
| 
						 | 
				
			
			@ -250,13 +251,15 @@ namespace gtsam {
 | 
			
		|||
    /**
 | 
			
		||||
     * @brief Visit all leaves in depth-first fashion.
 | 
			
		||||
     *
 | 
			
		||||
     * @param f side-effect taking an assignment and a value.
 | 
			
		||||
     * @param f (side-effect) Function taking an assignment and a value.
 | 
			
		||||
     *
 | 
			
		||||
     * @note Due to pruning, leaves might not exhaust choices.
 | 
			
		||||
     * @note Due to pruning, the number of leaves may not be the same as the
 | 
			
		||||
     * number of assignments. E.g. if we have a tree on 2 binary variables with
 | 
			
		||||
     * all values being 1, then there are 2^2=4 assignments, but only 1 leaf.
 | 
			
		||||
     *
 | 
			
		||||
     * Example:
 | 
			
		||||
     *   int sum = 0;
 | 
			
		||||
     *   auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
 | 
			
		||||
     *   auto visitor = [&](const Assignment<L>& assignment, int y) { sum += y; };
 | 
			
		||||
     *   tree.visitWith(visitor);
 | 
			
		||||
     */
 | 
			
		||||
    template <typename Func>
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue