Merge pull request #1137 from borglab/decisiontree/apply-with-assignment
commit
e3d68e772e
|
@ -112,6 +112,13 @@ namespace gtsam {
|
||||||
return f;
|
return f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Apply unary operator with assignment
|
||||||
|
NodePtr apply(const UnaryAssignment& op,
|
||||||
|
const Assignment<L>& choices) const override {
|
||||||
|
NodePtr f(new Leaf(op(choices, constant_)));
|
||||||
|
return f;
|
||||||
|
}
|
||||||
|
|
||||||
// Apply binary operator "h = f op g" on Leaf node
|
// Apply binary operator "h = f op g" on Leaf node
|
||||||
// Note op is not assumed commutative so we need to keep track of order
|
// Note op is not assumed commutative so we need to keep track of order
|
||||||
// Simply calls apply on argument to call correct virtual method:
|
// Simply calls apply on argument to call correct virtual method:
|
||||||
|
@ -322,12 +329,48 @@ namespace gtsam {
|
||||||
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Constructor which accepts a UnaryAssignment op and the
|
||||||
|
* corresponding assignment.
|
||||||
|
*
|
||||||
|
* @param label The label for this node.
|
||||||
|
* @param f The original choice node to apply the op on.
|
||||||
|
* @param op Function to apply on the choice node. Takes Assignment and
|
||||||
|
* value as arguments.
|
||||||
|
* @param choices The Assignment that will go to op.
|
||||||
|
*/
|
||||||
|
Choice(const L& label, const Choice& f, const UnaryAssignment& op,
|
||||||
|
const Assignment<L>& choices)
|
||||||
|
: label_(label), allSame_(true) {
|
||||||
|
branches_.reserve(f.branches_.size()); // reserve space
|
||||||
|
|
||||||
|
Assignment<L> choices_ = choices;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < f.branches_.size(); i++) {
|
||||||
|
choices_[label_] = i; // Set assignment for label to i
|
||||||
|
|
||||||
|
const NodePtr branch = f.branches_[i];
|
||||||
|
push_back(branch->apply(op, choices_));
|
||||||
|
|
||||||
|
// Remove the choice so we are backtracking
|
||||||
|
auto choice_it = choices_.find(label_);
|
||||||
|
choices_.erase(choice_it);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/** apply unary operator */
|
/** apply unary operator */
|
||||||
NodePtr apply(const Unary& op) const override {
|
NodePtr apply(const Unary& op) const override {
|
||||||
auto r = boost::make_shared<Choice>(label_, *this, op);
|
auto r = boost::make_shared<Choice>(label_, *this, op);
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Apply unary operator with assignment
|
||||||
|
NodePtr apply(const UnaryAssignment& op,
|
||||||
|
const Assignment<L>& choices) const override {
|
||||||
|
auto r = boost::make_shared<Choice>(label_, *this, op, choices);
|
||||||
|
return Unique(r);
|
||||||
|
}
|
||||||
|
|
||||||
// Apply binary operator "h = f op g" on Choice node
|
// Apply binary operator "h = f op g" on Choice node
|
||||||
// Note op is not assumed commutative so we need to keep track of order
|
// Note op is not assumed commutative so we need to keep track of order
|
||||||
// Simply calls apply on argument to call correct virtual method:
|
// Simply calls apply on argument to call correct virtual method:
|
||||||
|
@ -739,6 +782,20 @@ namespace gtsam {
|
||||||
return DecisionTree(root_->apply(op));
|
return DecisionTree(root_->apply(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Apply unary operator with assignment
|
||||||
|
template <typename L, typename Y>
|
||||||
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
|
||||||
|
const UnaryAssignment& op) const {
|
||||||
|
std::cout << "Calling the correct apply" << std::endl;
|
||||||
|
// It is unclear what should happen if tree is empty:
|
||||||
|
if (empty()) {
|
||||||
|
throw std::runtime_error(
|
||||||
|
"DecisionTree::apply(unary op) undefined for empty tree.");
|
||||||
|
}
|
||||||
|
Assignment<L> choices;
|
||||||
|
return DecisionTree(root_->apply(op, choices));
|
||||||
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
||||||
|
|
|
@ -54,6 +54,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** Handy typedefs for unary and binary function types */
|
/** Handy typedefs for unary and binary function types */
|
||||||
using Unary = std::function<Y(const Y&)>;
|
using Unary = std::function<Y(const Y&)>;
|
||||||
|
using UnaryAssignment = std::function<Y(const Assignment<L>&, const Y&)>;
|
||||||
using Binary = std::function<Y(const Y&, const Y&)>;
|
using Binary = std::function<Y(const Y&, const Y&)>;
|
||||||
|
|
||||||
/** A label annotated with cardinality */
|
/** A label annotated with cardinality */
|
||||||
|
@ -103,6 +104,8 @@ namespace gtsam {
|
||||||
&DefaultCompare) const = 0;
|
&DefaultCompare) const = 0;
|
||||||
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
virtual const Y& operator()(const Assignment<L>& x) const = 0;
|
||||||
virtual Ptr apply(const Unary& op) const = 0;
|
virtual Ptr apply(const Unary& op) const = 0;
|
||||||
|
virtual Ptr apply(const UnaryAssignment& op,
|
||||||
|
const Assignment<L>& choices) const = 0;
|
||||||
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
|
||||||
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
|
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
|
||||||
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
||||||
|
@ -283,6 +286,16 @@ namespace gtsam {
|
||||||
/** apply Unary operation "op" to f */
|
/** apply Unary operation "op" to f */
|
||||||
DecisionTree apply(const Unary& op) const;
|
DecisionTree apply(const Unary& op) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Apply Unary operation "op" to f while also providing the
|
||||||
|
* corresponding assignment.
|
||||||
|
*
|
||||||
|
* @param op Function which takes Assignment<L> and Y as input and returns
|
||||||
|
* object of type Y.
|
||||||
|
* @return DecisionTree
|
||||||
|
*/
|
||||||
|
DecisionTree apply(const UnaryAssignment& op) const;
|
||||||
|
|
||||||
/** apply binary operation "op" to f and g */
|
/** apply binary operation "op" to f and g */
|
||||||
DecisionTree apply(const DecisionTree& g, const Binary& op) const;
|
DecisionTree apply(const DecisionTree& g, const Binary& op) const;
|
||||||
|
|
||||||
|
@ -337,6 +350,13 @@ namespace gtsam {
|
||||||
return f.apply(op);
|
return f.apply(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Apply unary operator `op` with Assignment to DecisionTree `f`.
|
||||||
|
template<typename L, typename Y>
|
||||||
|
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||||
|
const typename DecisionTree<L, Y>::UnaryAssignment& op) {
|
||||||
|
return f.apply(op);
|
||||||
|
}
|
||||||
|
|
||||||
/// Apply binary operator `op` to DecisionTree `f`.
|
/// Apply binary operator `op` to DecisionTree `f`.
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
|
||||||
|
|
|
@ -90,6 +90,7 @@ struct DT : public DecisionTree<string, int> {
|
||||||
auto valueFormatter = [](const int& v) {
|
auto valueFormatter = [](const int& v) {
|
||||||
return (boost::format("%d") % v).str();
|
return (boost::format("%d") % v).str();
|
||||||
};
|
};
|
||||||
|
std::cout << s;
|
||||||
Base::print("", keyFormatter, valueFormatter);
|
Base::print("", keyFormatter, valueFormatter);
|
||||||
}
|
}
|
||||||
/// Equality method customized to int node type
|
/// Equality method customized to int node type
|
||||||
|
@ -451,6 +452,33 @@ TEST(DecisionTree, threshold) {
|
||||||
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test apply with assignment.
|
||||||
|
TEST(DecisionTree, ApplyWithAssignment) {
|
||||||
|
// Create three level tree
|
||||||
|
vector<DT::LabelC> keys;
|
||||||
|
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
||||||
|
DT tree(keys, "1 2 3 4 5 6 7 8");
|
||||||
|
|
||||||
|
DecisionTree<string, double> probTree(
|
||||||
|
keys, "0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08");
|
||||||
|
double threshold = 0.035;
|
||||||
|
|
||||||
|
// We test pruning one tree by indexing into another.
|
||||||
|
auto pruner = [&](const Assignment<string>& choices, const int& x) {
|
||||||
|
// Prune out all the leaves with even numbers
|
||||||
|
if (probTree(choices) < threshold) {
|
||||||
|
return 0;
|
||||||
|
} else {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
DT prunedTree = tree.apply(pruner);
|
||||||
|
|
||||||
|
DT expectedTree(keys, "0 0 0 4 5 6 7 8");
|
||||||
|
EXPECT(assert_equal(expectedTree, prunedTree));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue