diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h index acdbf63a3..72ea5e79f 100644 --- a/gtsam/discrete/AlgebraicDecisionTree.h +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -108,9 +108,12 @@ namespace gtsam { /** Convert */ template AlgebraicDecisionTree(const AlgebraicDecisionTree& other, - const std::map& map) { - this->root_ = this->template convert(other.root_, map, - Ring::id); + const std::map& map) { + std::function map_function = [&map](const M& label) -> L { + return map.at(label); + }; + std::function op = Ring::id; + this->root_ = this->template convert(other.root_, op, map_function); } /** sum */ diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 209c2ad80..96f1421ce 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -453,20 +453,24 @@ namespace gtsam { root_ = compose(functions.begin(), functions.end(), label); } - /*********************************************************************************/ - template - template - DecisionTree::DecisionTree(const DecisionTree& other, - const std::map& map, std::function op) { - root_ = convert(other.root_, map, op); - } - /*********************************************************************************/ template template DecisionTree::DecisionTree(const DecisionTree& other, std::function op) { - root_ = convert(other.root_, op); + auto map = [](const L& label) { return label; }; + root_ = convert(other.root_, op, map); + } + + /*********************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + const std::map& map, std::function op) { + std::function map_function = [&map](const M& label) -> L { + return map.at(label); + }; + root_ = convert(other.root_, op, map_function); } /*********************************************************************************/ @@ -579,12 +583,11 @@ namespace gtsam { } /*********************************************************************************/ - template - template + template + template typename DecisionTree::NodePtr DecisionTree::convert( - const typename DecisionTree::NodePtr& f, const std::map& map, - std::function op) { - + const typename DecisionTree::NodePtr& f, + std::function op, std::function map) { typedef DecisionTree MX; typedef typename MX::Leaf MXLeaf; typedef typename MX::Choice MXChoice; @@ -602,50 +605,18 @@ namespace gtsam { "DecisionTree::Convert: Invalid NodePtr"); // get new label - M oldLabel = choice->label(); - L newLabel = map.at(oldLabel); + const M oldLabel = choice->label(); + const L newLabel = map(oldLabel); // put together via Shannon expansion otherwise not sorted. std::vector functions; for(const MXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, map, op)); + LY converted(convert(branch, op, map)); functions += converted; } return LY::compose(functions.begin(), functions.end(), newLabel); } - /*********************************************************************************/ - template - template - typename DecisionTree::NodePtr DecisionTree::convert( - const typename DecisionTree::NodePtr& f, - std::function op) { - - typedef DecisionTree LX; - typedef typename LX::Leaf LXLeaf; - typedef typename LX::Choice LXChoice; - typedef typename LX::NodePtr LXNodePtr; - typedef DecisionTree LY; - - // ugliness below because apparently we can't have templated virtual functions - // If leaf, apply unary conversion "op" and create a unique leaf - const LXLeaf* leaf = dynamic_cast (f.get()); - if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); - - // Check if Choice - boost::shared_ptr choice = boost::dynamic_pointer_cast (f); - if (!choice) throw std::invalid_argument( - "DecisionTree::Convert: Invalid NodePtr"); - - // put together via Shannon expansion otherwise not sorted. - std::vector functions; - for(const LXNodePtr& branch: choice->branches()) { - LY converted(convert(branch, op)); - functions += converted; - } - return LY::compose(functions.begin(), functions.end(), choice->label()); - } - /*********************************************************************************/ template bool DecisionTree::equals(const DecisionTree& other, double tol, diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 26817bf79..baf2a79fa 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -127,15 +127,11 @@ namespace gtsam { template NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; - /** Convert to a different type */ - template NodePtr - convert(const typename DecisionTree::NodePtr& f, const std::map& map, std::function op); - - /** Convert only node to a different type */ - template - NodePtr convert(const typename DecisionTree::NodePtr& f, - const std::function op); + /// Convert to a different type, will not convert label if map empty. + template + NodePtr convert(const typename DecisionTree::NodePtr& f, + std::function op, + std::function map); public: @@ -168,16 +164,16 @@ namespace gtsam { DecisionTree(const L& label, // const DecisionTree& f0, const DecisionTree& f1); - /** Convert from a different type */ - template - DecisionTree(const DecisionTree& other, - const std::map& map, std::function op); - - /** Convert only nodes from a different type */ + /** Convert from a different type. */ template DecisionTree(const DecisionTree& other, std::function op); + /** Convert from a different type, also transate labels via map. */ + template + DecisionTree(const DecisionTree& other, + const std::map& map, std::function op); + /// @} /// @name Testable /// @{