diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 627c1a5aa..3f82ce9a6 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -112,6 +112,13 @@ namespace gtsam { return f; } + /// Apply unary operator with assignment + NodePtr apply(const UnaryAssignment& op, + const Assignment& choices) const override { + NodePtr f(new Leaf(op(choices, constant_))); + return f; + } + // Apply binary operator "h = f op g" on Leaf node // Note op is not assumed commutative so we need to keep track of order // Simply calls apply on argument to call correct virtual method: @@ -322,12 +329,48 @@ namespace gtsam { for (const NodePtr& branch : f.branches_) push_back(branch->apply(op)); } + /** + * @brief Constructor which accepts a UnaryAssignment op and the + * corresponding assignment. + * + * @param label The label for this node. + * @param f The original choice node to apply the op on. + * @param op Function to apply on the choice node. Takes Assignment and + * value as arguments. + * @param choices The Assignment that will go to op. + */ + Choice(const L& label, const Choice& f, const UnaryAssignment& op, + const Assignment& choices) + : label_(label), allSame_(true) { + branches_.reserve(f.branches_.size()); // reserve space + + Assignment choices_ = choices; + + for (size_t i = 0; i < f.branches_.size(); i++) { + choices_[label_] = i; // Set assignment for label to i + + const NodePtr branch = f.branches_[i]; + push_back(branch->apply(op, choices_)); + + // Remove the choice so we are backtracking + auto choice_it = choices_.find(label_); + choices_.erase(choice_it); + } + } + /** apply unary operator */ NodePtr apply(const Unary& op) const override { auto r = boost::make_shared(label_, *this, op); return Unique(r); } + /// Apply unary operator with assignment + NodePtr apply(const UnaryAssignment& op, + const Assignment& choices) const override { + auto r = boost::make_shared(label_, *this, op, choices); + return Unique(r); + } + // Apply binary operator "h = f op g" on Choice node // Note op is not assumed commutative so we need to keep track of order // Simply calls apply on argument to call correct virtual method: @@ -739,6 +782,20 @@ namespace gtsam { return DecisionTree(root_->apply(op)); } + /// Apply unary operator with assignment + template + DecisionTree DecisionTree::apply( + const UnaryAssignment& op) const { + std::cout << "Calling the correct apply" << std::endl; + // It is unclear what should happen if tree is empty: + if (empty()) { + throw std::runtime_error( + "DecisionTree::apply(unary op) undefined for empty tree."); + } + Assignment choices; + return DecisionTree(root_->apply(op, choices)); + } + /****************************************************************************/ template DecisionTree DecisionTree::apply(const DecisionTree& g, diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index d655756b8..13ff0a8c6 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -54,6 +54,7 @@ namespace gtsam { /** Handy typedefs for unary and binary function types */ using Unary = std::function; + using UnaryAssignment = std::function&, const Y&)>; using Binary = std::function; /** A label annotated with cardinality */ @@ -103,6 +104,8 @@ namespace gtsam { &DefaultCompare) const = 0; virtual const Y& operator()(const Assignment& x) const = 0; virtual Ptr apply(const Unary& op) const = 0; + virtual Ptr apply(const UnaryAssignment& op, + const Assignment& choices) const = 0; virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; @@ -283,6 +286,16 @@ namespace gtsam { /** apply Unary operation "op" to f */ DecisionTree apply(const Unary& op) const; + /** + * @brief Apply Unary operation "op" to f while also providing the + * corresponding assignment. + * + * @param op Function which takes Assignment and Y as input and returns + * object of type Y. + * @return DecisionTree + */ + DecisionTree apply(const UnaryAssignment& op) const; + /** apply binary operation "op" to f and g */ DecisionTree apply(const DecisionTree& g, const Binary& op) const; @@ -337,6 +350,13 @@ namespace gtsam { return f.apply(op); } + /// Apply unary operator `op` with Assignment to DecisionTree `f`. + template + DecisionTree apply(const DecisionTree& f, + const typename DecisionTree::UnaryAssignment& op) { + return f.apply(op); + } + /// Apply binary operator `op` to DecisionTree `f`. template DecisionTree apply(const DecisionTree& f, diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 91deed625..935d433c6 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -90,6 +90,7 @@ struct DT : public DecisionTree { auto valueFormatter = [](const int& v) { return (boost::format("%d") % v).str(); }; + std::cout << s; Base::print("", keyFormatter, valueFormatter); } /// Equality method customized to int node type @@ -451,6 +452,33 @@ TEST(DecisionTree, threshold) { EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0)); } +/* ************************************************************************** */ +// Test apply with assignment. +TEST(DecisionTree, ApplyWithAssignment) { + // Create three level tree + vector keys; + keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2); + DT tree(keys, "1 2 3 4 5 6 7 8"); + + DecisionTree probTree( + keys, "0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08"); + double threshold = 0.035; + + // We test pruning one tree by indexing into another. + auto pruner = [&](const Assignment& choices, const int& x) { + // Prune out all the leaves with even numbers + if (probTree(choices) < threshold) { + return 0; + } else { + return x; + } + }; + DT prunedTree = tree.apply(pruner); + + DT expectedTree(keys, "0 0 0 4 5 6 7 8"); + EXPECT(assert_equal(expectedTree, prunedTree)); +} + /* ************************************************************************* */ int main() { TestResult tr;