From f98b9223e8cbd8a39fee2d91d6459dae1542c3b7 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Tue, 15 Oct 2024 15:43:03 +0900 Subject: [PATCH] Make compose and convertFrom static --- gtsam/discrete/DecisionTree-inl.h | 28 ++++++++++++++++------------ gtsam/discrete/DecisionTree.h | 8 ++++---- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 27e98fcde..8be5efaa6 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -580,7 +580,7 @@ namespace gtsam { template template typename DecisionTree::NodePtr DecisionTree::compose( - Iterator begin, Iterator end, const L& label) const { + Iterator begin, Iterator end, const L& label) { // find highest label among branches std::optional highestLabel; size_t nrChoices = 0; @@ -703,12 +703,9 @@ namespace gtsam { template typename DecisionTree::NodePtr DecisionTree::convertFrom( const typename DecisionTree::NodePtr& f, - std::function L_of_M, - std::function Y_of_X) const { + std::function L_of_M, std::function Y_of_X) { using LY = DecisionTree; - // Ugliness below because apparently we can't have templated virtual - // functions. // If leaf, apply unary conversion "op" and create a unique leaf. using MXLeaf = typename DecisionTree::Leaf; if (auto leaf = std::dynamic_pointer_cast(f)) { @@ -718,19 +715,27 @@ namespace gtsam { // Check if Choice using MXChoice = typename DecisionTree::Choice; auto choice = std::dynamic_pointer_cast(f); - if (!choice) throw std::invalid_argument( - "DecisionTree::convertFrom: Invalid NodePtr"); + if (!choice) + throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr"); // get new label const M oldLabel = choice->label(); const L newLabel = L_of_M(oldLabel); - // put together via Shannon expansion otherwise not sorted. + // Shannon expansion in this context involves: + // 1. Creating separate subtrees (functions) for each possible value of the new label. + // 2. Combining these subtrees using the 'compose' method, which implements the expansion. + // This approach guarantees that the resulting tree maintains the correct variable ordering + // based on the new labels (L) after translation from the old labels (M). + // Simply creating a Choice node here would not work because it wouldn't account for the + // potentially new ordering of variables resulting from the label translation, + // which is crucial for maintaining consistency and efficiency in the converted tree. std::vector functions; for (auto&& branch : choice->branches()) { functions.emplace_back(convertFrom(branch, L_of_M, Y_of_X)); } - return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel)); + return Choice::Unique( + LY::compose(functions.begin(), functions.end(), newLabel)); } /****************************************************************************/ @@ -740,9 +745,8 @@ namespace gtsam { * * NOTE: We differentiate between leaves and assignments. Concretely, a 3 * binary variable tree will have 2^3=8 assignments, but based on pruning, it - * can have less than 8 leaves. For example, if a tree has all assignment - * values as 1, then pruning will cause the tree to have only 1 leaf yet 8 - * assignments. + * can have <8 leaves. For example, if a tree has all assignment values as 1, + * then pruning will cause the tree to have only 1 leaf yet 8 assignments. */ template struct Visit { diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 6d6179a7e..0d9db1fce 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -176,9 +176,9 @@ namespace gtsam { * @return NodePtr */ template - NodePtr convertFrom(const typename DecisionTree::NodePtr& f, - std::function L_of_M, - std::function Y_of_X) const; + static NodePtr convertFrom(const typename DecisionTree::NodePtr& f, + std::function L_of_M, + std::function Y_of_X); public: /// @name Standard Constructors @@ -402,7 +402,7 @@ namespace gtsam { // internal use only template NodePtr - compose(Iterator begin, Iterator end, const L& label) const; + static compose(Iterator begin, Iterator end, const L& label); /// @}