Merge pull request #1875 from borglab/feature/fasterDT
						commit
						db353a58a4
					
				|  | @ -22,10 +22,12 @@ | |||
| #include <gtsam/discrete/DecisionTree-inl.h> | ||||
| 
 | ||||
| #include <algorithm> | ||||
| #include <limits> | ||||
| #include <map> | ||||
| #include <string> | ||||
| #include <iomanip> | ||||
| #include <vector> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
|   /**
 | ||||
|  |  | |||
|  | @ -22,18 +22,15 @@ | |||
| #include <gtsam/discrete/DecisionTree.h> | ||||
| 
 | ||||
| #include <algorithm> | ||||
| 
 | ||||
| #include <cmath> | ||||
| #include <cassert> | ||||
| #include <fstream> | ||||
| #include <list> | ||||
| #include <iterator> | ||||
| #include <map> | ||||
| #include <optional> | ||||
| #include <set> | ||||
| #include <sstream> | ||||
| #include <string> | ||||
| #include <vector> | ||||
| #include <optional> | ||||
| #include <cassert> | ||||
| #include <iterator> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
|  | @ -251,22 +248,28 @@ namespace gtsam { | |||
|         label_ = f.label(); | ||||
|         size_t count = f.nrChoices(); | ||||
|         branches_.reserve(count); | ||||
|         for (size_t i = 0; i < count; i++) | ||||
|           push_back(f.branches_[i]->apply_f_op_g(g, op)); | ||||
|         for (size_t i = 0; i < count; i++) { | ||||
|           NodePtr newBranch = f.branches_[i]->apply_f_op_g(g, op); | ||||
|           push_back(std::move(newBranch)); | ||||
|         } | ||||
|       } else if (g.label() > f.label()) { | ||||
|         // f lower than g
 | ||||
|         label_ = g.label(); | ||||
|         size_t count = g.nrChoices(); | ||||
|         branches_.reserve(count); | ||||
|         for (size_t i = 0; i < count; i++) | ||||
|           push_back(g.branches_[i]->apply_g_op_fC(f, op)); | ||||
|         for (size_t i = 0; i < count; i++) { | ||||
|           NodePtr newBranch = g.branches_[i]->apply_g_op_fC(f, op); | ||||
|           push_back(std::move(newBranch)); | ||||
|         } | ||||
|       } else { | ||||
|         // f same level as g
 | ||||
|         label_ = f.label(); | ||||
|         size_t count = f.nrChoices(); | ||||
|         branches_.reserve(count); | ||||
|         for (size_t i = 0; i < count; i++) | ||||
|           push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op)); | ||||
|         for (size_t i = 0; i < count; i++) { | ||||
|           NodePtr newBranch = f.branches_[i]->apply_f_op_g(*g.branches_[i], op); | ||||
|           push_back(std::move(newBranch)); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|  | @ -284,12 +287,12 @@ namespace gtsam { | |||
|     } | ||||
| 
 | ||||
|     /** add a branch: TODO merge into constructor */ | ||||
|     void push_back(const NodePtr& node) { | ||||
|     void push_back(NodePtr&& node) { | ||||
|       // allSame_ is restricted to leaf nodes in a decision tree
 | ||||
|       if (allSame_ && !branches_.empty()) { | ||||
|         allSame_ = node->sameLeaf(*branches_.back()); | ||||
|       } | ||||
|       branches_.push_back(node); | ||||
|       branches_.push_back(std::move(node)); | ||||
|     } | ||||
| 
 | ||||
|     /// print (as a tree).
 | ||||
|  | @ -497,9 +500,9 @@ namespace gtsam { | |||
|   DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) { | ||||
|     auto a = std::make_shared<Choice>(label, 2); | ||||
|     NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); | ||||
|     a->push_back(l1); | ||||
|     a->push_back(l2); | ||||
|     root_ = Choice::Unique(a); | ||||
|     a->push_back(std::move(l1)); | ||||
|     a->push_back(std::move(l2)); | ||||
|     root_ = Choice::Unique(std::move(a)); | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|  | @ -510,11 +513,10 @@ namespace gtsam { | |||
|         "DecisionTree: binary constructor called with non-binary label"); | ||||
|     auto a = std::make_shared<Choice>(labelC.first, 2); | ||||
|     NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); | ||||
|     a->push_back(l1); | ||||
|     a->push_back(l2); | ||||
|     root_ = Choice::Unique(a); | ||||
|     a->push_back(std::move(l1)); | ||||
|     a->push_back(std::move(l2)); | ||||
|     root_ = Choice::Unique(std::move(a)); | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   template<typename L, typename Y> | ||||
|   DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, | ||||
|  | @ -557,9 +559,7 @@ namespace gtsam { | |||
|   template <typename X, typename Func> | ||||
|   DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, | ||||
|                                    Func Y_of_X) { | ||||
|     // Define functor for identity mapping of node label.
 | ||||
|     auto L_of_L = [](const L& label) { return label; }; | ||||
|     root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X); | ||||
|     root_ = convertFrom<X>(other.root_, Y_of_X); | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|  | @ -580,7 +580,7 @@ namespace gtsam { | |||
|   template <typename L, typename Y> | ||||
|   template <typename Iterator> | ||||
|   typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose( | ||||
|       Iterator begin, Iterator end, const L& label) const { | ||||
|       Iterator begin, Iterator end, const L& label) { | ||||
|     // find highest label among branches
 | ||||
|     std::optional<L> highestLabel; | ||||
|     size_t nrChoices = 0; | ||||
|  | @ -598,8 +598,10 @@ namespace gtsam { | |||
|     // if label is already in correct order, just put together a choice on label
 | ||||
|     if (!nrChoices || !highestLabel || label > *highestLabel) { | ||||
|       auto choiceOnLabel = std::make_shared<Choice>(label, end - begin); | ||||
|       for (Iterator it = begin; it != end; it++) | ||||
|         choiceOnLabel->push_back(it->root_); | ||||
|       for (Iterator it = begin; it != end; it++) { | ||||
|         NodePtr root = it->root_; | ||||
|         choiceOnLabel->push_back(std::move(root)); | ||||
|       } | ||||
|       // If no reordering, no need to call Choice::Unique
 | ||||
|       return choiceOnLabel; | ||||
|     } else { | ||||
|  | @ -618,7 +620,7 @@ namespace gtsam { | |||
|         } | ||||
|         // We then recurse, for all values of the highest label
 | ||||
|         NodePtr fi = compose(functions.begin(), functions.end(), label); | ||||
|         choiceOnHighestLabel->push_back(fi); | ||||
|         choiceOnHighestLabel->push_back(std::move(fi)); | ||||
|       } | ||||
|       return choiceOnHighestLabel; | ||||
|     } | ||||
|  | @ -648,7 +650,7 @@ namespace gtsam { | |||
|   template<typename L, typename Y> | ||||
|   template<typename It, typename ValueIt> | ||||
|   typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::build( | ||||
|       It begin, It end, ValueIt beginY, ValueIt endY) const { | ||||
|       It begin, It end, ValueIt beginY, ValueIt endY) { | ||||
|     // get crucial counts
 | ||||
|     size_t nrChoices = begin->second; | ||||
|     size_t size = endY - beginY; | ||||
|  | @ -675,6 +677,7 @@ namespace gtsam { | |||
|     // Creates one tree (i.e.,function) for each choice of current key
 | ||||
|     // by calling create recursively, and then puts them all together.
 | ||||
|     std::vector<DecisionTree> functions; | ||||
|     functions.reserve(nrChoices); | ||||
|     size_t split = size / nrChoices; | ||||
|     for (size_t i = 0; i < nrChoices; i++, beginY += split) { | ||||
|       NodePtr f = build<It, ValueIt>(labelC, end, beginY, beginY + split); | ||||
|  | @ -689,7 +692,7 @@ namespace gtsam { | |||
|   template<typename L, typename Y> | ||||
|   template<typename It, typename ValueIt> | ||||
|   typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create( | ||||
|       It begin, It end, ValueIt beginY, ValueIt endY) const { | ||||
|       It begin, It end, ValueIt beginY, ValueIt endY) { | ||||
|     auto node = build(begin, end, beginY, endY); | ||||
|     if (auto choice = std::dynamic_pointer_cast<const Choice>(node)) { | ||||
|       return Choice::Unique(choice); | ||||
|  | @ -698,17 +701,44 @@ namespace gtsam { | |||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   template <typename L, typename Y> | ||||
|   template <typename X> | ||||
|   typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom( | ||||
|       const typename DecisionTree<L, X>::NodePtr& f, | ||||
|       std::function<Y(const X&)> Y_of_X) { | ||||
| 
 | ||||
|     // If leaf, apply unary conversion "op" and create a unique leaf.
 | ||||
|     using LXLeaf = typename DecisionTree<L, X>::Leaf; | ||||
|     if (auto leaf = std::dynamic_pointer_cast<const LXLeaf>(f)) { | ||||
|       return NodePtr(new Leaf(Y_of_X(leaf->constant()))); | ||||
|     } | ||||
| 
 | ||||
|     // Check if Choice
 | ||||
|     using LXChoice = typename DecisionTree<L, X>::Choice; | ||||
|     auto choice = std::dynamic_pointer_cast<const LXChoice>(f); | ||||
|     if (!choice) throw std::invalid_argument( | ||||
|         "DecisionTree::convertFrom: Invalid NodePtr"); | ||||
| 
 | ||||
|     // Create a new Choice node with the same label
 | ||||
|     auto newChoice = std::make_shared<Choice>(choice->label(), choice->nrChoices()); | ||||
| 
 | ||||
|     // Convert each branch recursively
 | ||||
|     for (auto&& branch : choice->branches()) { | ||||
|       newChoice->push_back(convertFrom<X>(branch, Y_of_X)); | ||||
|     } | ||||
| 
 | ||||
|     return Choice::Unique(newChoice); | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|   template <typename L, typename Y> | ||||
|   template <typename M, typename X> | ||||
|   typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom( | ||||
|       const typename DecisionTree<M, X>::NodePtr& f, | ||||
|       std::function<L(const M&)> L_of_M, | ||||
|       std::function<Y(const X&)> Y_of_X) const { | ||||
|       std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) { | ||||
|     using LY = DecisionTree<L, Y>; | ||||
| 
 | ||||
|     // Ugliness below because apparently we can't have templated virtual
 | ||||
|     // functions.
 | ||||
|     // If leaf, apply unary conversion "op" and create a unique leaf.
 | ||||
|     using MXLeaf = typename DecisionTree<M, X>::Leaf; | ||||
|     if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) { | ||||
|  | @ -718,19 +748,27 @@ namespace gtsam { | |||
|     // Check if Choice
 | ||||
|     using MXChoice = typename DecisionTree<M, X>::Choice; | ||||
|     auto choice = std::dynamic_pointer_cast<const MXChoice>(f); | ||||
|     if (!choice) throw std::invalid_argument( | ||||
|         "DecisionTree::convertFrom: Invalid NodePtr"); | ||||
|     if (!choice) | ||||
|       throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr"); | ||||
| 
 | ||||
|     // get new label
 | ||||
|     const M oldLabel = choice->label(); | ||||
|     const L newLabel = L_of_M(oldLabel); | ||||
| 
 | ||||
|     // put together via Shannon expansion otherwise not sorted.
 | ||||
|     // Shannon expansion in this context involves:
 | ||||
|     // 1. Creating separate subtrees (functions) for each possible value of the new label.
 | ||||
|     // 2. Combining these subtrees using the 'compose' method, which implements the expansion.
 | ||||
|     // This approach guarantees that the resulting tree maintains the correct variable ordering
 | ||||
|     // based on the new labels (L) after translation from the old labels (M).
 | ||||
|     // Simply creating a Choice node here would not work because it wouldn't account for the
 | ||||
|     // potentially new ordering of variables resulting from the label translation,
 | ||||
|     // which is crucial for maintaining consistency and efficiency in the converted tree.
 | ||||
|     std::vector<LY> functions; | ||||
|     for (auto&& branch : choice->branches()) { | ||||
|       functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X)); | ||||
|     } | ||||
|     return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel)); | ||||
|     return Choice::Unique( | ||||
|         LY::compose(functions.begin(), functions.end(), newLabel)); | ||||
|   } | ||||
| 
 | ||||
|   /****************************************************************************/ | ||||
|  |  | |||
|  | @ -31,7 +31,6 @@ | |||
| #include <iostream> | ||||
| #include <map> | ||||
| #include <set> | ||||
| #include <sstream> | ||||
| #include <string> | ||||
| #include <utility> | ||||
| #include <vector> | ||||
|  | @ -155,7 +154,7 @@ namespace gtsam { | |||
|      * and Y values  | ||||
|      */ | ||||
|     template <typename It, typename ValueIt> | ||||
|     NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY) const; | ||||
|     static NodePtr build(It begin, It end, ValueIt beginY, ValueIt endY); | ||||
| 
 | ||||
|     /** Internal helper function to create from
 | ||||
|      * keys, cardinalities, and Y values. | ||||
|  | @ -163,7 +162,20 @@ namespace gtsam { | |||
|      * before we prune in a top-down fashion. | ||||
|      */ | ||||
|     template <typename It, typename ValueIt> | ||||
|     NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; | ||||
|     static NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY); | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Convert from a DecisionTree<L, X> to DecisionTree<L, Y>. | ||||
|      * | ||||
|      * @tparam M The previous label type. | ||||
|      * @tparam X The previous value type. | ||||
|      * @param f The node pointer to the root of the previous DecisionTree. | ||||
|      * @param Y_of_X Functor to convert from value type X to type Y. | ||||
|      * @return NodePtr | ||||
|      */ | ||||
|     template <typename X> | ||||
|     static NodePtr convertFrom(const typename DecisionTree<L, X>::NodePtr& f, | ||||
|                                std::function<Y(const X&)> Y_of_X); | ||||
| 
 | ||||
|     /**
 | ||||
|      * @brief Convert from a DecisionTree<M, X> to DecisionTree<L, Y>. | ||||
|  | @ -176,9 +188,9 @@ namespace gtsam { | |||
|      * @return NodePtr  | ||||
|      */ | ||||
|     template <typename M, typename X> | ||||
|     NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f, | ||||
|     static NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f, | ||||
|                                std::function<L(const M&)> L_of_M, | ||||
|                         std::function<Y(const X&)> Y_of_X) const; | ||||
|                                std::function<Y(const X&)> Y_of_X); | ||||
| 
 | ||||
|    public: | ||||
|     /// @name Standard Constructors
 | ||||
|  | @ -402,7 +414,7 @@ namespace gtsam { | |||
| 
 | ||||
|     // internal use only
 | ||||
|     template<typename Iterator> NodePtr | ||||
|     compose(Iterator begin, Iterator end, const L& label) const; | ||||
|     static compose(Iterator begin, Iterator end, const L& label); | ||||
| 
 | ||||
|     /// @}
 | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue