Merge pull request #1137 from borglab/decisiontree/apply-with-assignment

release/4.3a0
Varun Agrawal 2022-03-19 18:32:23 -04:00 committed by GitHub
commit e3d68e772e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 0 deletions

View File

@ -112,6 +112,13 @@ namespace gtsam {
return f;
}
/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
NodePtr f(new Leaf(op(choices, constant_)));
return f;
}
// Apply binary operator "h = f op g" on Leaf node
// Note op is not assumed commutative so we need to keep track of order
// Simply calls apply on argument to call correct virtual method:
@ -322,12 +329,48 @@ namespace gtsam {
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
}
/**
* @brief Constructor which accepts a UnaryAssignment op and the
* corresponding assignment.
*
* @param label The label for this node.
* @param f The original choice node to apply the op on.
* @param op Function to apply on the choice node. Takes Assignment and
* value as arguments.
* @param choices The Assignment that will go to op.
*/
Choice(const L& label, const Choice& f, const UnaryAssignment& op,
const Assignment<L>& choices)
: label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
Assignment<L> choices_ = choices;
for (size_t i = 0; i < f.branches_.size(); i++) {
choices_[label_] = i; // Set assignment for label to i
const NodePtr branch = f.branches_[i];
push_back(branch->apply(op, choices_));
// Remove the choice so we are backtracking
auto choice_it = choices_.find(label_);
choices_.erase(choice_it);
}
}
/** apply unary operator */
NodePtr apply(const Unary& op) const override {
auto r = boost::make_shared<Choice>(label_, *this, op);
return Unique(r);
}
/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, choices);
return Unique(r);
}
// Apply binary operator "h = f op g" on Choice node
// Note op is not assumed commutative so we need to keep track of order
// Simply calls apply on argument to call correct virtual method:
@ -739,6 +782,20 @@ namespace gtsam {
return DecisionTree(root_->apply(op));
}
/// Apply unary operator with assignment
template <typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
const UnaryAssignment& op) const {
std::cout << "Calling the correct apply" << std::endl;
// It is unclear what should happen if tree is empty:
if (empty()) {
throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree.");
}
Assignment<L> choices;
return DecisionTree(root_->apply(op, choices));
}
/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,

View File

@ -54,6 +54,7 @@ namespace gtsam {
/** Handy typedefs for unary and binary function types */
using Unary = std::function<Y(const Y&)>;
using UnaryAssignment = std::function<Y(const Assignment<L>&, const Y&)>;
using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */
@ -103,6 +104,8 @@ namespace gtsam {
&DefaultCompare) const = 0;
virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
@ -283,6 +286,16 @@ namespace gtsam {
/** apply Unary operation "op" to f */
DecisionTree apply(const Unary& op) const;
/**
* @brief Apply Unary operation "op" to f while also providing the
* corresponding assignment.
*
* @param op Function which takes Assignment<L> and Y as input and returns
* object of type Y.
* @return DecisionTree
*/
DecisionTree apply(const UnaryAssignment& op) const;
/** apply binary operation "op" to f and g */
DecisionTree apply(const DecisionTree& g, const Binary& op) const;
@ -337,6 +350,13 @@ namespace gtsam {
return f.apply(op);
}
/// Apply unary operator `op` with Assignment to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const typename DecisionTree<L, Y>::UnaryAssignment& op) {
return f.apply(op);
}
/// Apply binary operator `op` to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,

View File

@ -90,6 +90,7 @@ struct DT : public DecisionTree<string, int> {
auto valueFormatter = [](const int& v) {
return (boost::format("%d") % v).str();
};
std::cout << s;
Base::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to int node type
@ -451,6 +452,33 @@ TEST(DecisionTree, threshold) {
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
}
/* ************************************************************************** */
// Test apply with assignment.
TEST(DecisionTree, ApplyWithAssignment) {
// Create three level tree
vector<DT::LabelC> keys;
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
DT tree(keys, "1 2 3 4 5 6 7 8");
DecisionTree<string, double> probTree(
keys, "0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08");
double threshold = 0.035;
// We test pruning one tree by indexing into another.
auto pruner = [&](const Assignment<string>& choices, const int& x) {
// Prune out all the leaves with even numbers
if (probTree(choices) < threshold) {
return 0;
} else {
return x;
}
};
DT prunedTree = tree.apply(pruner);
DT expectedTree(keys, "0 0 0 4 5 6 7 8");
EXPECT(assert_equal(expectedTree, prunedTree));
}
/* ************************************************************************* */
int main() {
TestResult tr;