Merge pull request #1137 from borglab/decisiontree/apply-with-assignment
commit
e3d68e772e
|
@ -112,6 +112,13 @@ namespace gtsam {
|
|||
return f;
|
||||
}
|
||||
|
||||
/// Apply unary operator with assignment
|
||||
NodePtr apply(const UnaryAssignment& op,
|
||||
const Assignment<L>& choices) const override {
|
||||
NodePtr f(new Leaf(op(choices, constant_)));
|
||||
return f;
|
||||
}
|
||||
|
||||
// Apply binary operator "h = f op g" on Leaf node
|
||||
// Note op is not assumed commutative so we need to keep track of order
|
||||
// Simply calls apply on argument to call correct virtual method:
|
||||
|
@ -322,12 +329,48 @@ namespace gtsam {
|
|||
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Constructor which accepts a UnaryAssignment op and the
|
||||
* corresponding assignment.
|
||||
*
|
||||
* @param label The label for this node.
|
||||
* @param f The original choice node to apply the op on.
|
||||
* @param op Function to apply on the choice node. Takes Assignment and
|
||||
* value as arguments.
|
||||
* @param choices The Assignment that will go to op.
|
||||
*/
|
||||
Choice(const L& label, const Choice& f, const UnaryAssignment& op,
|
||||
const Assignment<L>& choices)
|
||||
: label_(label), allSame_(true) {
|
||||
branches_.reserve(f.branches_.size()); // reserve space
|
||||
|
||||
Assignment<L> choices_ = choices;
|
||||
|
||||
for (size_t i = 0; i < f.branches_.size(); i++) {
|
||||
choices_[label_] = i; // Set assignment for label to i
|
||||
|
||||
const NodePtr branch = f.branches_[i];
|
||||
push_back(branch->apply(op, choices_));
|
||||
|
||||
// Remove the choice so we are backtracking
|
||||
auto choice_it = choices_.find(label_);
|
||||
choices_.erase(choice_it);
|
||||
}
|
||||
}
|
||||
|
||||
/** apply unary operator */
|
||||
NodePtr apply(const Unary& op) const override {
|
||||
auto r = boost::make_shared<Choice>(label_, *this, op);
|
||||
return Unique(r);
|
||||
}
|
||||
|
||||
/// Apply unary operator with assignment
|
||||
NodePtr apply(const UnaryAssignment& op,
|
||||
const Assignment<L>& choices) const override {
|
||||
auto r = boost::make_shared<Choice>(label_, *this, op, choices);
|
||||
return Unique(r);
|
||||
}
|
||||
|
||||
// Apply binary operator "h = f op g" on Choice node
|
||||
// Note op is not assumed commutative so we need to keep track of order
|
||||
// Simply calls apply on argument to call correct virtual method:
|
||||
|
@ -739,6 +782,20 @@ namespace gtsam {
|
|||
return DecisionTree(root_->apply(op));
|
||||
}
|
||||
|
||||
/// Apply unary operator with assignment
|
||||
template <typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
|
||||
const UnaryAssignment& op) const {
|
||||
std::cout << "Calling the correct apply" << std::endl;
|
||||
// It is unclear what should happen if tree is empty:
|
||||
if (empty()) {
|
||||
throw std::runtime_error(
|
||||
"DecisionTree::apply(unary op) undefined for empty tree.");
|
||||
}
|
||||
Assignment<L> choices;
|
||||
return DecisionTree(root_->apply(op, choices));
|
||||
}
|
||||
|
||||
/****************************************************************************/
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
||||
|
|
|
@ -54,6 +54,7 @@ namespace gtsam {
|
|||
|
||||
/** Handy typedefs for unary and binary function types */
|
||||
using Unary = std::function<Y(const Y&)>;
|
||||
using UnaryAssignment = std::function<Y(const Assignment<L>&, const Y&)>;
|
||||
using Binary = std::function<Y(const Y&, const Y&)>;
|
||||
|
||||
/** A label annotated with cardinality */
|
||||
|
@ -103,6 +104,8 @@ namespace gtsam {
|
|||
&DefaultCompare) const = 0;
|
||||
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
||||
virtual Ptr apply(const Unary& op) const = 0;
|
||||
virtual Ptr apply(const UnaryAssignment& op,
|
||||
const Assignment<L>& choices) const = 0;
|
||||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
||||
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
|
||||
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
||||
|
@ -283,6 +286,16 @@ namespace gtsam {
|
|||
/** apply Unary operation "op" to f */
|
||||
DecisionTree apply(const Unary& op) const;
|
||||
|
||||
/**
|
||||
* @brief Apply Unary operation "op" to f while also providing the
|
||||
* corresponding assignment.
|
||||
*
|
||||
* @param op Function which takes Assignment<L> and Y as input and returns
|
||||
* object of type Y.
|
||||
* @return DecisionTree
|
||||
*/
|
||||
DecisionTree apply(const UnaryAssignment& op) const;
|
||||
|
||||
/** apply binary operation "op" to f and g */
|
||||
DecisionTree apply(const DecisionTree& g, const Binary& op) const;
|
||||
|
||||
|
@ -337,6 +350,13 @@ namespace gtsam {
|
|||
return f.apply(op);
|
||||
}
|
||||
|
||||
/// Apply unary operator `op` with Assignment to DecisionTree `f`.
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||
const typename DecisionTree<L, Y>::UnaryAssignment& op) {
|
||||
return f.apply(op);
|
||||
}
|
||||
|
||||
/// Apply binary operator `op` to DecisionTree `f`.
|
||||
template<typename L, typename Y>
|
||||
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||
|
|
|
@ -90,6 +90,7 @@ struct DT : public DecisionTree<string, int> {
|
|||
auto valueFormatter = [](const int& v) {
|
||||
return (boost::format("%d") % v).str();
|
||||
};
|
||||
std::cout << s;
|
||||
Base::print("", keyFormatter, valueFormatter);
|
||||
}
|
||||
/// Equality method customized to int node type
|
||||
|
@ -451,6 +452,33 @@ TEST(DecisionTree, threshold) {
|
|||
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||
}
|
||||
|
||||
/* ************************************************************************** */
|
||||
// Test apply with assignment.
|
||||
TEST(DecisionTree, ApplyWithAssignment) {
|
||||
// Create three level tree
|
||||
vector<DT::LabelC> keys;
|
||||
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
||||
DT tree(keys, "1 2 3 4 5 6 7 8");
|
||||
|
||||
DecisionTree<string, double> probTree(
|
||||
keys, "0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08");
|
||||
double threshold = 0.035;
|
||||
|
||||
// We test pruning one tree by indexing into another.
|
||||
auto pruner = [&](const Assignment<string>& choices, const int& x) {
|
||||
// Prune out all the leaves with even numbers
|
||||
if (probTree(choices) < threshold) {
|
||||
return 0;
|
||||
} else {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
DT prunedTree = tree.apply(pruner);
|
||||
|
||||
DT expectedTree(keys, "0 0 0 4 5 6 7 8");
|
||||
EXPECT(assert_equal(expectedTree, prunedTree));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue