Merge pull request #1056 from borglab/feature/dt_threshold

release/4.3a0
Frank Dellaert 2022-01-22 21:36:35 -05:00 committed by GitHub
commit 3d86bc7294
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 461 additions and 450 deletions

View File

@ -18,8 +18,13 @@
#pragma once #pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
namespace gtsam { namespace gtsam {
/** /**
@ -27,10 +32,11 @@ namespace gtsam {
* Just has some nice constructors and some syntactic sugar * Just has some nice constructors and some syntactic sugar
* TODO: consider eliminating this class altogether? * TODO: consider eliminating this class altogether?
*/ */
template<typename L> template <typename L>
class GTSAM_EXPORT AlgebraicDecisionTree: public DecisionTree<L, double> { class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
/** /**
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing. * @brief Default method used by `labelFormatter` or `valueFormatter` when
* printing.
* *
* @param x The value passed to format. * @param x The value passed to format.
* @return std::string * @return std::string
@ -42,17 +48,12 @@ namespace gtsam {
} }
public: public:
using Base = DecisionTree<L, double>; using Base = DecisionTree<L, double>;
/** The Real ring with addition and multiplication */ /** The Real ring with addition and multiplication */
struct Ring { struct Ring {
static inline double zero() { static inline double zero() { return 0.0; }
return 0.0; static inline double one() { return 1.0; }
}
static inline double one() {
return 1.0;
}
static inline double add(const double& a, const double& b) { static inline double add(const double& a, const double& b) {
return a + b; return a + b;
} }
@ -65,54 +66,50 @@ namespace gtsam {
static inline double div(const double& a, const double& b) { static inline double div(const double& a, const double& b) {
return a / b; return a / b;
} }
static inline double id(const double& x) { static inline double id(const double& x) { return x; }
return x;
}
}; };
AlgebraicDecisionTree() : AlgebraicDecisionTree() : Base(1.0) {}
Base(1.0) {
}
AlgebraicDecisionTree(const Base& add) : // Explicitly non-explicit constructor
Base(add) { AlgebraicDecisionTree(const Base& add) : Base(add) {}
}
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const L& label, double y1, double y2) : AlgebraicDecisionTree(const L& label, double y1, double y2)
Base(label, y1, y2) { : Base(label, y1, y2) {}
}
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) : AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
Base(labelC, y1, y2) { double y2)
} : Base(labelC, y1, y2) {}
/** Create from keys and vector table */ /** Create from keys and vector table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) { (const std::vector<typename Base::LabelC>& labelCs,
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), const std::vector<double>& ys) {
ys.end()); this->root_ =
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create from keys and string table */ /** Create from keys and string table */
AlgebraicDecisionTree // AlgebraicDecisionTree //
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) { (const std::vector<typename Base::LabelC>& labelCs,
const std::string& table) {
// Convert string to doubles // Convert string to doubles
std::vector<double> ys; std::vector<double> ys;
std::istringstream iss(table); std::istringstream iss(table);
std::copy(std::istream_iterator<double>(iss), std::copy(std::istream_iterator<double>(iss),
std::istream_iterator<double>(), std::back_inserter(ys)); std::istream_iterator<double>(), std::back_inserter(ys));
// now call recursive Create // now call recursive Create
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(), this->root_ =
ys.end()); Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/** Create a new function splitting on a variable */ /** Create a new function splitting on a variable */
template<typename Iterator> template <typename Iterator>
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
Base(nullptr) { : Base(nullptr) {
this->root_ = compose(begin, end, label); this->root_ = compose(begin, end, label);
} }
@ -122,7 +119,7 @@ namespace gtsam {
* @param other: The AlgebraicDecisionTree with label type M to convert. * @param other: The AlgebraicDecisionTree with label type M to convert.
* @param map: Map from label type M to label type L. * @param map: Map from label type M to label type L.
*/ */
template<typename M> template <typename M>
AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other, AlgebraicDecisionTree(const AlgebraicDecisionTree<M>& other,
const std::map<M, L>& map) { const std::map<M, L>& map) {
// Functor for label conversion so we can use `convertFrom`. // Functor for label conversion so we can use `convertFrom`.
@ -160,8 +157,8 @@ namespace gtsam {
/// print method customized to value type `double`. /// print method customized to value type `double`.
void print(const std::string& s, void print(const std::string& s,
const typename Base::LabelFormatter& labelFormatter = const typename Base::LabelFormatter& labelFormatter =
&DefaultFormatter) const { &DefaultFormatter) const {
auto valueFormatter = [](const double& v) { auto valueFormatter = [](const double& v) {
return (boost::format("%4.4g") % v).str(); return (boost::format("%4.4g") % v).str();
}; };
@ -177,8 +174,8 @@ namespace gtsam {
return Base::equals(other, compare); return Base::equals(other, compare);
} }
}; };
// AlgebraicDecisionTree
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {}; template <typename T>
} struct traits<AlgebraicDecisionTree<T>>
// namespace gtsam : public Testable<AlgebraicDecisionTree<T>> {};
} // namespace gtsam

View File

@ -21,42 +21,44 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <algorithm>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/format.hpp> #include <boost/format.hpp>
#include <boost/make_shared.hpp>
#include <boost/noncopyable.hpp> #include <boost/noncopyable.hpp>
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <boost/tuple/tuple.hpp> #include <boost/tuple/tuple.hpp>
#include <boost/type_traits/has_dereference.hpp> #include <boost/type_traits/has_dereference.hpp>
#include <boost/unordered_set.hpp> #include <boost/unordered_set.hpp>
#include <boost/make_shared.hpp>
#include <cmath> #include <cmath>
#include <fstream> #include <fstream>
#include <list> #include <list>
#include <map>
#include <set>
#include <sstream> #include <sstream>
#include <string>
#include <vector>
using boost::assign::operator+=; using boost::assign::operator+=;
namespace gtsam { namespace gtsam {
/*********************************************************************************/ /****************************************************************************/
// Node // Node
/*********************************************************************************/ /****************************************************************************/
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
template<typename L, typename Y> template<typename L, typename Y>
int DecisionTree<L, Y>::Node::nrNodes = 0; int DecisionTree<L, Y>::Node::nrNodes = 0;
#endif #endif
/*********************************************************************************/ /****************************************************************************/
// Leaf // Leaf
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template <typename L, typename Y>
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
/** constant stored in this leaf */ /** constant stored in this leaf */
Y constant_; Y constant_;
public:
/** Constructor from constant */ /** Constructor from constant */
Leaf(const Y& constant) : Leaf(const Y& constant) :
constant_(constant) {} constant_(constant) {}
@ -96,7 +98,7 @@ namespace gtsam {
std::string value = valueFormatter(constant_); std::string value = valueFormatter(constant_);
if (showZero || value.compare("0")) if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
} }
/** evaluate */ /** evaluate */
@ -121,13 +123,13 @@ namespace gtsam {
// Applying binary operator to two leaves results in a leaf // Applying binary operator to two leaves results in a leaf
NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override { NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const override {
NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL
return h; return h;
} }
// If second argument is a Choice node, call it's apply with leaf as second // If second argument is a Choice node, call it's apply with leaf as second
NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override { NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const override {
return fC.apply_fC_op_gL(*this, op); // operand order back to normal return fC.apply_fC_op_gL(*this, op); // operand order back to normal
} }
/** choose a branch, create new memory ! */ /** choose a branch, create new memory ! */
@ -136,32 +138,30 @@ namespace gtsam {
} }
bool isLeaf() const override { return true; } bool isLeaf() const override { return true; }
}; // Leaf
}; // Leaf /****************************************************************************/
/*********************************************************************************/
// Choice // Choice
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node { struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
/** the label of the variable on which we split */ /** the label of the variable on which we split */
L label_; L label_;
/** The children of this Choice node. */ /** The children of this Choice node. */
std::vector<NodePtr> branches_; std::vector<NodePtr> branches_;
private: private:
/** incremental allSame */ /** incremental allSame */
size_t allSame_; size_t allSame_;
using ChoicePtr = boost::shared_ptr<const Choice>; using ChoicePtr = boost::shared_ptr<const Choice>;
public: public:
~Choice() override { ~Choice() override {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl; std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
<< std::std::endl;
#endif #endif
} }
@ -172,7 +172,8 @@ namespace gtsam {
assert(f->branches().size() > 0); assert(f->branches().size() > 0);
NodePtr f0 = f->branches_[0]; NodePtr f0 = f->branches_[0];
assert(f0->isLeaf()); assert(f0->isLeaf());
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant())); NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
return newLeaf; return newLeaf;
} else } else
#endif #endif
@ -192,7 +193,6 @@ namespace gtsam {
*/ */
Choice(const Choice& f, const Choice& g, const Binary& op) : Choice(const Choice& f, const Choice& g, const Binary& op) :
allSame_(true) { allSame_(true) {
// Choose what to do based on label // Choose what to do based on label
if (f.label() > g.label()) { if (f.label() > g.label()) {
// f higher than g // f higher than g
@ -318,10 +318,8 @@ namespace gtsam {
*/ */
Choice(const L& label, const Choice& f, const Unary& op) : Choice(const L& label, const Choice& f, const Unary& op) :
label_(label), allSame_(true) { label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space
branches_.reserve(f.branches_.size()); // reserve space for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
for (const NodePtr& branch: f.branches_)
push_back(branch->apply(op));
} }
/** apply unary operator */ /** apply unary operator */
@ -364,8 +362,7 @@ namespace gtsam {
/** choose a branch, recursively */ /** choose a branch, recursively */
NodePtr choose(const L& label, size_t index) const override { NodePtr choose(const L& label, size_t index) const override {
if (label_ == label) if (label_ == label) return branches_[index]; // choose branch
return branches_[index]; // choose branch
// second case, not label of interest, just recurse // second case, not label of interest, just recurse
auto r = boost::make_shared<Choice>(label_, branches_.size()); auto r = boost::make_shared<Choice>(label_, branches_.size());
@ -373,12 +370,11 @@ namespace gtsam {
r->push_back(branch->choose(label, index)); r->push_back(branch->choose(label, index));
return Unique(r); return Unique(r);
} }
}; // Choice
}; // Choice /****************************************************************************/
/*********************************************************************************/
// DecisionTree // DecisionTree
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree() { DecisionTree<L, Y>::DecisionTree() {
} }
@ -388,13 +384,13 @@ namespace gtsam {
root_(root) { root_(root) {
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const Y& y) { DecisionTree<L, Y>::DecisionTree(const Y& y) {
root_ = NodePtr(new Leaf(y)); root_ = NodePtr(new Leaf(y));
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) { DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
auto a = boost::make_shared<Choice>(label, 2); auto a = boost::make_shared<Choice>(label, 2);
@ -404,7 +400,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1, DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
const Y& y2) { const Y& y2) {
@ -417,7 +413,7 @@ namespace gtsam {
root_ = Choice::Unique(a); root_ = Choice::Unique(a);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::vector<Y>& ys) { const std::vector<Y>& ys) {
@ -425,29 +421,28 @@ namespace gtsam {
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs, DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
const std::string& table) { const std::string& table) {
// Convert std::string to values of type Y // Convert std::string to values of type Y
std::vector<Y> ys; std::vector<Y> ys;
std::istringstream iss(table); std::istringstream iss(table);
copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(), copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
back_inserter(ys)); back_inserter(ys));
// now call recursive Create // now call recursive Create
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
template<typename Iterator> DecisionTree<L, Y>::DecisionTree( template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
Iterator begin, Iterator end, const L& label) { Iterator begin, Iterator end, const L& label) {
root_ = compose(begin, end, label); root_ = compose(begin, end, label);
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y>::DecisionTree(const L& label, DecisionTree<L, Y>::DecisionTree(const L& label,
const DecisionTree& f0, const DecisionTree& f1) { const DecisionTree& f0, const DecisionTree& f1) {
@ -456,17 +451,17 @@ namespace gtsam {
root_ = compose(functions.begin(), functions.end(), label); root_ = compose(functions.begin(), functions.end(), label);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename X, typename Func> template <typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
Func Y_of_X) { Func Y_of_X) {
// Define functor for identity mapping of node label. // Define functor for identity mapping of node label.
auto L_of_L = [](const L& label) { return label; }; auto L_of_L = [](const L& label) { return label; };
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X); root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X, typename Func> template <typename M, typename X, typename Func>
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other, DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
@ -475,16 +470,16 @@ namespace gtsam {
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X); root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
} }
/*********************************************************************************/ /****************************************************************************/
// Called by two constructors above. // Called by two constructors above.
// Takes a label and a corresponding range of decision trees, and creates a new // Takes a label and a corresponding range of decision trees, and creates a
// decision tree. However, the order of the labels needs to be respected, so we // new decision tree. However, the order of the labels needs to be respected,
// cannot just create a root Choice node on the label: if the label is not the // so we cannot just create a root Choice node on the label: if the label is
// highest label, we need to do a complicated and expensive recursive call. // not the highest label, we need a complicated/ expensive recursive call.
template<typename L, typename Y> template<typename Iterator> template <typename L, typename Y>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin, template <typename Iterator>
Iterator end, const L& label) const { typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
Iterator begin, Iterator end, const L& label) const {
// find highest label among branches // find highest label among branches
boost::optional<L> highestLabel; boost::optional<L> highestLabel;
size_t nrChoices = 0; size_t nrChoices = 0;
@ -527,7 +522,7 @@ namespace gtsam {
} }
} }
/*********************************************************************************/ /****************************************************************************/
// "create" is a bit of a complicated thing, but very useful. // "create" is a bit of a complicated thing, but very useful.
// It takes a range of labels and a corresponding range of values, // It takes a range of labels and a corresponding range of values,
// and creates a decision tree, as follows: // and creates a decision tree, as follows:
@ -552,7 +547,6 @@ namespace gtsam {
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
It begin, It end, ValueIt beginY, ValueIt endY) const { It begin, It end, ValueIt beginY, ValueIt endY) const {
// get crucial counts // get crucial counts
size_t nrChoices = begin->second; size_t nrChoices = begin->second;
size_t size = endY - beginY; size_t size = endY - beginY;
@ -564,7 +558,11 @@ namespace gtsam {
// Create a simple choice node with values as leaves. // Create a simple choice node with values as leaves.
if (size != nrChoices) { if (size != nrChoices) {
std::cout << "Trying to create DD on " << begin->first << std::endl; std::cout << "Trying to create DD on " << begin->first << std::endl;
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; std::cout << boost::format(
"DecisionTree::create: expected %d values but got %d "
"instead") %
nrChoices % size
<< std::endl;
throw std::invalid_argument("DecisionTree::create invalid argument"); throw std::invalid_argument("DecisionTree::create invalid argument");
} }
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY); auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
@ -585,7 +583,7 @@ namespace gtsam {
return compose(functions.begin(), functions.end(), begin->first); return compose(functions.begin(), functions.end(), begin->first);
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
template <typename M, typename X> template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
@ -594,11 +592,11 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const { std::function<Y(const X&)> Y_of_X) const {
using LY = DecisionTree<L, Y>; using LY = DecisionTree<L, Y>;
// ugliness below because apparently we can't have templated virtual functions // ugliness below because apparently we can't have templated virtual
// If leaf, apply unary conversion "op" and create a unique leaf // functions If leaf, apply unary conversion "op" and create a unique leaf
using MXLeaf = typename DecisionTree<M, X>::Leaf; using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f)) if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
return NodePtr(new Leaf(Y_of_X(leaf->constant()))); return NodePtr(new Leaf(Y_of_X(leaf->constant())));
// Check if Choice // Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice; using MXChoice = typename DecisionTree<M, X>::Choice;
@ -612,19 +610,19 @@ namespace gtsam {
// put together via Shannon expansion otherwise not sorted. // put together via Shannon expansion otherwise not sorted.
std::vector<LY> functions; std::vector<LY> functions;
for(auto && branch: choice->branches()) { for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X)); functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
} }
return LY::compose(functions.begin(), functions.end(), newLabel); return LY::compose(functions.begin(), functions.end(), newLabel);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit without Assignment<L> argument. // Functor performing depth-first visit without Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct Visit { struct Visit {
using F = std::function<void(const Y&)>; using F = std::function<void(const Y&)>;
Visit(F f) : f(f) {} ///< Construct from folding function. explicit Visit(F f) : f(f) {} ///< Construct from folding function.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) const {
@ -647,15 +645,15 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// Functor performing depth-first visit with Assignment<L> argument. // Functor performing depth-first visit with Assignment<L> argument.
template <typename L, typename Y> template <typename L, typename Y>
struct VisitWith { struct VisitWith {
using Choices = Assignment<L>; using Choices = Assignment<L>;
using F = std::function<void(const Choices&, const Y&)>; using F = std::function<void(const Choices&, const Y&)>;
VisitWith(F f) : f(f) {} ///< Construct from folding function. explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
Choices choices; ///< Assignment, mutating through recursion. Choices choices; ///< Assignment, mutating through recursion.
F f; ///< folding function object. F f; ///< folding function object.
/// Do a depth-first visit on the tree rooted at node. /// Do a depth-first visit on the tree rooted at node.
void operator()(const typename DecisionTree<L, Y>::NodePtr& node) { void operator()(const typename DecisionTree<L, Y>::NodePtr& node) {
@ -681,7 +679,7 @@ namespace gtsam {
visit(root_); visit(root_);
} }
/*********************************************************************************/ /****************************************************************************/
// fold is just done with a visit // fold is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
template <typename Func, typename X> template <typename Func, typename X>
@ -690,7 +688,7 @@ namespace gtsam {
return x0; return x0;
} }
/*********************************************************************************/ /****************************************************************************/
// labels is just done with a visit // labels is just done with a visit
template <typename L, typename Y> template <typename L, typename Y>
std::set<L> DecisionTree<L, Y>::labels() const { std::set<L> DecisionTree<L, Y>::labels() const {
@ -702,7 +700,7 @@ namespace gtsam {
return unique; return unique;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
bool DecisionTree<L, Y>::equals(const DecisionTree& other, bool DecisionTree<L, Y>::equals(const DecisionTree& other,
const CompareFunc& compare) const { const CompareFunc& compare) const {
@ -736,7 +734,7 @@ namespace gtsam {
return DecisionTree(root_->apply(op)); return DecisionTree(root_->apply(op));
} }
/*********************************************************************************/ /****************************************************************************/
template<typename L, typename Y> template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g, DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
const Binary& op) const { const Binary& op) const {
@ -752,7 +750,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
// The way this works: // The way this works:
// We have an ADT, picture it as a tree. // We have an ADT, picture it as a tree.
// At a certain depth, we have a branch on "label". // At a certain depth, we have a branch on "label".
@ -772,7 +770,7 @@ namespace gtsam {
return result; return result;
} }
/*********************************************************************************/ /****************************************************************************/
template <typename L, typename Y> template <typename L, typename Y>
void DecisionTree<L, Y>::dot(std::ostream& os, void DecisionTree<L, Y>::dot(std::ostream& os,
const LabelFormatter& labelFormatter, const LabelFormatter& labelFormatter,
@ -790,9 +788,11 @@ namespace gtsam {
bool showZero) const { bool showZero) const {
std::ofstream os((name + ".dot").c_str()); std::ofstream os((name + ".dot").c_str());
dot(os, labelFormatter, valueFormatter, showZero); dot(os, labelFormatter, valueFormatter, showZero);
int result = system( int result =
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed"); .c_str());
if (result == -1)
throw std::runtime_error("DecisionTree::dot system call failed");
} }
template <typename L, typename Y> template <typename L, typename Y>
@ -804,8 +804,6 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
/*********************************************************************************/ /******************************************************************************/
} // namespace gtsam
} // namespace gtsam

View File

@ -26,9 +26,11 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <sstream>
#include <vector>
#include <set> #include <set>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
namespace gtsam { namespace gtsam {
@ -39,15 +41,13 @@ namespace gtsam {
*/ */
template<typename L, typename Y> template<typename L, typename Y>
class DecisionTree { class DecisionTree {
protected: protected:
/// Default method for comparison of two objects of type Y. /// Default method for comparison of two objects of type Y.
static bool DefaultCompare(const Y& a, const Y& b) { static bool DefaultCompare(const Y& a, const Y& b) {
return a == b; return a == b;
} }
public: public:
using LabelFormatter = std::function<std::string(L)>; using LabelFormatter = std::function<std::string(L)>;
using ValueFormatter = std::function<std::string(Y)>; using ValueFormatter = std::function<std::string(Y)>;
using CompareFunc = std::function<bool(const Y&, const Y&)>; using CompareFunc = std::function<bool(const Y&, const Y&)>;
@ -57,15 +57,14 @@ namespace gtsam {
using Binary = std::function<Y(const Y&, const Y&)>; using Binary = std::function<Y(const Y&, const Y&)>;
/** A label annotated with cardinality */ /** A label annotated with cardinality */
using LabelC = std::pair<L,size_t>; using LabelC = std::pair<L, size_t>;
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */ /** DTs consist of Leaf and Choice nodes, both subclasses of Node */
class Leaf; struct Leaf;
class Choice; struct Choice;
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
class Node { struct Node {
public:
using Ptr = boost::shared_ptr<const Node>; using Ptr = boost::shared_ptr<const Node>;
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
@ -75,14 +74,16 @@ namespace gtsam {
// Constructor // Constructor
Node() { Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); std::cout << ++nrNodes << " constructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
// Destructor // Destructor
virtual ~Node() { virtual ~Node() {
#ifdef DT_DEBUG_MEMORY #ifdef DT_DEBUG_MEMORY
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); std::cout << --nrNodes << " destructed " << id() << std::endl;
std::cout.flush();
#endif #endif
} }
@ -110,17 +111,17 @@ namespace gtsam {
}; };
/** ------------------------ Node base class --------------------------- */ /** ------------------------ Node base class --------------------------- */
public: public:
/** A function is a shared pointer to the root of a DT */ /** A function is a shared pointer to the root of a DT */
using NodePtr = typename Node::Ptr; using NodePtr = typename Node::Ptr;
/// A DecisionTree just contains the root. TODO(dellaert): make protected. /// A DecisionTree just contains the root. TODO(dellaert): make protected.
NodePtr root_; NodePtr root_;
protected: protected:
/** Internal recursive function to create from keys, cardinalities,
/** Internal recursive function to create from keys, cardinalities, and Y values */ * and Y values
*/
template<typename It, typename ValueIt> template<typename It, typename ValueIt>
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
@ -140,7 +141,6 @@ namespace gtsam {
std::function<Y(const X&)> Y_of_X) const; std::function<Y(const X&)> Y_of_X) const;
public: public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -148,7 +148,7 @@ namespace gtsam {
DecisionTree(); DecisionTree();
/** Create a constant */ /** Create a constant */
DecisionTree(const Y& y); explicit DecisionTree(const Y& y);
/** Create a new leaf function splitting on a variable */ /** Create a new leaf function splitting on a variable */
DecisionTree(const L& label, const Y& y1, const Y& y2); DecisionTree(const L& label, const Y& y1, const Y& y2);
@ -167,8 +167,8 @@ namespace gtsam {
DecisionTree(Iterator begin, Iterator end, const L& label); DecisionTree(Iterator begin, Iterator end, const L& label);
/** Create DecisionTree from two others */ /** Create DecisionTree from two others */
DecisionTree(const L& label, // DecisionTree(const L& label, const DecisionTree& f0,
const DecisionTree& f0, const DecisionTree& f1); const DecisionTree& f1);
/** /**
* @brief Convert from a different value type. * @brief Convert from a different value type.
@ -234,6 +234,8 @@ namespace gtsam {
* *
* @param f side-effect taking a value. * @param f side-effect taking a value.
* *
* @note Due to pruning, leaves might not exhaust choices.
*
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](int y) { sum += y; }; * auto visitor = [&](int y) { sum += y; };
@ -247,6 +249,8 @@ namespace gtsam {
* *
* @param f side-effect taking an assignment and a value. * @param f side-effect taking an assignment and a value.
* *
* @note Due to pruning, leaves might not exhaust choices.
*
* Example: * Example:
* int sum = 0; * int sum = 0;
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; }; * auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
@ -264,6 +268,7 @@ namespace gtsam {
* @return X final value for accumulator. * @return X final value for accumulator.
* *
* @note X is always passed by value. * @note X is always passed by value.
* @note Due to pruning, leaves might not exhaust choices.
* *
* Example: * Example:
* auto add = [](const double& y, double x) { return y + x; }; * auto add = [](const double& y, double x) { return y + x; };
@ -289,7 +294,8 @@ namespace gtsam {
} }
/** combine subtrees on key with binary operation "op" */ /** combine subtrees on key with binary operation "op" */
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; DecisionTree combine(const L& label, size_t cardinality,
const Binary& op) const;
/** combine with LabelC for convenience */ /** combine with LabelC for convenience */
DecisionTree combine(const LabelC& labelC, const Binary& op) const { DecisionTree combine(const LabelC& labelC, const Binary& op) const {
@ -313,15 +319,14 @@ namespace gtsam {
/// @{ /// @{
// internal use only // internal use only
DecisionTree(const NodePtr& root); explicit DecisionTree(const NodePtr& root);
// internal use only // internal use only
template<typename Iterator> NodePtr template<typename Iterator> NodePtr
compose(Iterator begin, Iterator end, const L& label) const; compose(Iterator begin, Iterator end, const L& label) const;
/// @} /// @}
}; // DecisionTree
}; // DecisionTree
/** free versions of apply */ /** free versions of apply */
@ -340,11 +345,19 @@ namespace gtsam {
return f.apply(g, op); return f.apply(g, op);
} }
/// unzip a DecisionTree if its leaves are `std::pair` /**
template<typename L, typename T1, typename T2> * @brief unzip a DecisionTree with `std::pair` values.
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(const DecisionTree<L, std::pair<T1, T2> > &input) { *
return std::make_pair(DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }), * @param input the DecisionTree with `(T1,T2)` values.
DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; })); * @return a pair of DecisionTree on T1 and T2, respectively.
*/
template <typename L, typename T1, typename T2>
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
const DecisionTree<L, std::pair<T1, T2> >& input) {
return std::make_pair(
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
DecisionTree<L, T2>(input,
[](std::pair<T1, T2> i) { return i.second; }));
} }
} // namespace gtsam } // namespace gtsam

View File

@ -17,9 +17,9 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <gtsam/base/FastSet.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/base/FastSet.h>
#include <boost/make_shared.hpp> #include <boost/make_shared.hpp>
#include <boost/format.hpp> #include <boost/format.hpp>
@ -29,42 +29,42 @@ using namespace std;
namespace gtsam { namespace gtsam {
/* ******************************************************************************** */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor() { DecisionTreeFactor::DecisionTreeFactor() {}
}
/* ******************************************************************************** */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) : const ADT& potentials)
DiscreteFactor(keys.indices()), ADT(potentials), : DiscreteFactor(keys.indices()),
cardinalities_(keys.cardinalities()) { ADT(potentials),
} cardinalities_(keys.cardinalities()) {}
/* *************************************************************************/ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) { : DiscreteFactor(c.keys()),
} AlgebraicDecisionTree<Key>(c),
cardinalities_(c.cardinalities_) {}
/* ************************************************************************* */ /* ************************************************************************ */
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const { bool DecisionTreeFactor::equals(const DiscreteFactor& other,
if(!dynamic_cast<const DecisionTreeFactor*>(&other)) { double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
return false; return false;
} } else {
else {
const auto& f(static_cast<const DecisionTreeFactor&>(other)); const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol); return ADT::equals(f, tol);
} }
} }
/* ************************************************************************* */ /* ************************************************************************ */
double DecisionTreeFactor::safe_div(const double &a, const double &b) { double DecisionTreeFactor::safe_div(const double& a, const double& b) {
// The use for safe_div is when we divide the product factor by the sum // The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the // factor. If the product or sum is zero, we accord zero probability to the
// event. // event.
return (a == 0 || b == 0) ? 0 : (a / b); return (a == 0 || b == 0) ? 0 : (a / b);
} }
/* ************************************************************************* */ /* ************************************************************************ */
void DecisionTreeFactor::print(const string& s, void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
cout << s; cout << s;
@ -75,31 +75,32 @@ namespace gtsam {
ADT::print("", formatter); ADT::print("", formatter);
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f, DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
ADT::Binary op) const { ADT::Binary op) const {
map<Key,size_t> cs; // new cardinalities map<Key, size_t> cs; // new cardinalities
// make unique key-cardinality map // make unique key-cardinality map
for(Key j: keys()) cs[j] = cardinality(j); for (Key j : keys()) cs[j] = cardinality(j);
for(Key j: f.keys()) cs[j] = f.cardinality(j); for (Key j : f.keys()) cs[j] = f.cardinality(j);
// Convert map into keys // Convert map into keys
DiscreteKeys keys; DiscreteKeys keys;
for(const std::pair<const Key,size_t>& key: cs) for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
keys.push_back(key);
// apply operand // apply operand
ADT result = ADT::apply(f, op); ADT result = ADT::apply(f, op);
// Make a new factor // Make a new factor
return DecisionTreeFactor(keys, result); return DecisionTreeFactor(keys, result);
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals, DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
ADT::Binary op) const { size_t nrFrontals, ADT::Binary op) const {
if (nrFrontals > size())
if (nrFrontals > size()) throw invalid_argument( throw invalid_argument(
(boost::format( (boost::format(
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") "DecisionTreeFactor::combine: invalid number of frontal "
% nrFrontals % size()).str()); "keys %d, nr.keys=%d") %
nrFrontals % size())
.str());
// sum over nrFrontals keys // sum over nrFrontals keys
size_t i; size_t i;
@ -113,20 +114,21 @@ namespace gtsam {
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (; i < keys().size(); i++) { for (; i < keys().size(); i++) {
Key j = keys()[i]; Key j = keys()[i];
dkeys.push_back(DiscreteKey(j,cardinality(j))); dkeys.push_back(DiscreteKey(j, cardinality(j)));
} }
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************ */
/* ************************************************************************* */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys, const Ordering& frontalKeys, ADT::Binary op) const {
ADT::Binary op) const { if (frontalKeys.size() > size())
throw invalid_argument(
if (frontalKeys.size() > size()) throw invalid_argument( (boost::format(
(boost::format( "DecisionTreeFactor::combine: invalid number of frontal "
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") "keys %d, nr.keys=%d") %
% frontalKeys.size() % size()).str()); frontalKeys.size() % size())
.str());
// sum over nrFrontals keys // sum over nrFrontals keys
size_t i; size_t i;
@ -137,20 +139,22 @@ namespace gtsam {
} }
// create new factor, note we collect keys that are not in frontalKeys // create new factor, note we collect keys that are not in frontalKeys
// TODO: why do we need this??? result should contain correct keys!!! // TODO(frank): why do we need this??? result should contain correct keys!!!
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (i = 0; i < keys().size(); i++) { for (i = 0; i < keys().size(); i++) {
Key j = keys()[i]; Key j = keys()[i];
// TODO: inefficient! // TODO(frank): inefficient!
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end()) if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
frontalKeys.end())
continue; continue;
dkeys.push_back(DiscreteKey(j,cardinality(j))); dkeys.push_back(DiscreteKey(j, cardinality(j)));
} }
return boost::make_shared<DecisionTreeFactor>(dkeys, result); return boost::make_shared<DecisionTreeFactor>(dkeys, result);
} }
/* ************************************************************************* */ /* ************************************************************************ */
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const { std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
const {
// Get all possible assignments // Get all possible assignments
std::vector<std::pair<Key, size_t>> pairs; std::vector<std::pair<Key, size_t>> pairs;
for (auto& key : keys()) { for (auto& key : keys()) {
@ -168,7 +172,7 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************ */
DiscreteKeys DecisionTreeFactor::discreteKeys() const { DiscreteKeys DecisionTreeFactor::discreteKeys() const {
DiscreteKeys result; DiscreteKeys result;
for (auto&& key : keys()) { for (auto&& key : keys()) {
@ -180,7 +184,7 @@ namespace gtsam {
return result; return result;
} }
/* ************************************************************************* */ /* ************************************************************************ */
static std::string valueFormatter(const double& v) { static std::string valueFormatter(const double& v) {
return (boost::format("%4.2g") % v).str(); return (boost::format("%4.2g") % v).str();
} }
@ -194,8 +198,8 @@ namespace gtsam {
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name, void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
bool showZero) const { bool showZero) const {
ADT::dot(name, keyFormatter, valueFormatter, showZero); ADT::dot(name, keyFormatter, valueFormatter, showZero);
} }
@ -205,8 +209,8 @@ namespace gtsam {
return ADT::dot(keyFormatter, valueFormatter, showZero); return ADT::dot(keyFormatter, valueFormatter, showZero);
} }
// Print out header. // Print out header.
/* ************************************************************************* */ /* ************************************************************************ */
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter, string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
const Names& names) const { const Names& names) const {
stringstream ss; stringstream ss;
@ -271,17 +275,19 @@ namespace gtsam {
return ss.str(); return ss.str();
} }
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table), const vector<double>& table)
cardinalities_(keys.cardinalities()) { : DiscreteFactor(keys.indices()),
} AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */ /* ************************************************************************ */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) : DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table), const string& table)
cardinalities_(keys.cardinalities()) { : DiscreteFactor(keys.indices()),
} AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */ /* ************************************************************************ */
} // namespace gtsam } // namespace gtsam

View File

@ -18,16 +18,18 @@
#pragma once #pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteFactor.h> #include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h> #include <gtsam/inference/Ordering.h>
#include <algorithm>
#include <boost/shared_ptr.hpp> #include <boost/shared_ptr.hpp>
#include <map>
#include <vector>
#include <exception>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace gtsam { namespace gtsam {
@ -36,21 +38,19 @@ namespace gtsam {
/** /**
* A discrete probabilistic factor * A discrete probabilistic factor
*/ */
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> { class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
public AlgebraicDecisionTree<Key> {
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam
typedef DecisionTreeFactor This; typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr; typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
protected: protected:
std::map<Key,size_t> cardinalities_; std::map<Key, size_t> cardinalities_;
public:
public:
/// @name Standard Constructors /// @name Standard Constructors
/// @{ /// @{
@ -61,7 +61,8 @@ namespace gtsam {
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from doubles */ /** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table); DecisionTreeFactor(const DiscreteKeys& keys,
const std::vector<double>& table);
/** Constructor from string */ /** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table); DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
@ -86,7 +87,8 @@ namespace gtsam {
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
// print // print
void print(const std::string& s = "DecisionTreeFactor:\n", void print(
const std::string& s = "DecisionTreeFactor:\n",
const KeyFormatter& formatter = DefaultKeyFormatter) const override; const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @} /// @}
@ -105,7 +107,7 @@ namespace gtsam {
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j);} size_t cardinality(Key j) const { return cardinalities_.at(j); }
/// divide by factor f (safely) /// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
@ -113,9 +115,7 @@ namespace gtsam {
} }
/// Convert into a decisiontree /// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override { DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
return *this;
}
/// Create new factor by summing all values with the same separator values /// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const { shared_ptr sum(size_t nrFrontals) const {
@ -164,27 +164,6 @@ namespace gtsam {
*/ */
shared_ptr combine(const Ordering& keys, ADT::Binary op) const; shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
// /**
// * @brief Permutes the keys in Potentials and DiscreteFactor
// *
// * This re-implements the permuteWithInverse() in both Potentials
// * and DiscreteFactor by doing both of them together.
// */
//
// void permuteWithInverse(const Permutation& inversePermutation){
// DiscreteFactor::permuteWithInverse(inversePermutation);
// Potentials::permuteWithInverse(inversePermutation);
// }
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
// DiscreteFactor::reduceWithInverse(inverseReduction);
// Potentials::reduceWithInverse(inverseReduction);
// }
/// Enumerate all values into a map from values to double. /// Enumerate all values into a map from values to double.
std::vector<std::pair<DiscreteValues, double>> enumerate() const; std::vector<std::pair<DiscreteValues, double>> enumerate() const;
@ -197,13 +176,13 @@ namespace gtsam {
/** output to graphviz format, stream version */ /** output to graphviz format, stream version */
void dot(std::ostream& os, void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const; bool showZero = true) const;
/** output to graphviz format, open a file */ /** output to graphviz format, open a file */
void dot(const std::string& name, void dot(const std::string& name,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
bool showZero = true) const; bool showZero = true) const;
/** output to graphviz format string */ /** output to graphviz format string */
std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
@ -217,7 +196,7 @@ namespace gtsam {
* @return std::string a markdown string. * @return std::string a markdown string.
*/ */
std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/** /**
* @brief Render as html table * @brief Render as html table
@ -227,14 +206,13 @@ namespace gtsam {
* @return std::string a html string. * @return std::string a html string.
*/ */
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const Names& names = {}) const override; const Names& names = {}) const override;
/// @} /// @}
};
};
// DecisionTreeFactor
// traits // traits
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {}; template <>
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
}// namespace gtsam } // namespace gtsam

View File

@ -17,38 +17,39 @@
*/ */
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
//#define DT_NO_PRUNING //#define DT_NO_PRUNING
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING #define DISABLE_TIMING
#include <boost/tokenizer.hpp>
#include <boost/assign/std/map.hpp> #include <boost/assign/std/map.hpp>
#include <boost/assign/std/vector.hpp> #include <boost/assign/std/vector.hpp>
#include <boost/tokenizer.hpp>
using namespace boost::assign; using namespace boost::assign;
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/base/timing.h> #include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
/* ******************************************************************************** */ /* ************************************************************************** */
typedef AlgebraicDecisionTree<Key> ADT; typedef AlgebraicDecisionTree<Key> ADT;
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<ADT> : public Testable<ADT> {}; template <>
} struct traits<ADT> : public Testable<ADT> {};
} // namespace gtsam
#define DISABLE_DOT #define DISABLE_DOT
template<typename T> template <typename T>
void dot(const T&f, const string& filename) { void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT #ifndef DISABLE_DOT
f.dot(filename); f.dot(filename);
#endif #endif
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
// If second argument of binary op is Leaf // If second argument of binary op is Leaf
template<typename L> template<typename L>
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL( typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
Cache& cache, const Leaf& gL, Mul op) const { double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
Ptr h(new Choice(label(), cardinality())); Ptr h(new Choice(label(), cardinality()));
for(const NodePtr& branch: branches_) for(const NodePtr& branch: branches_)
h->push_back(branch->apply_f_op_g(cache, gL, op)); h->push_back(branch->apply_f_op_g(cache, gL, op));
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
} }
*/ */
/* ******************************************************************************** */ /* ************************************************************************** */
// instrumented operators // instrumented operators
/* ******************************************************************************** */ /* ************************************************************************** */
size_t muls = 0, adds = 0; size_t muls = 0, adds = 0;
double elapsed; double elapsed;
void resetCounts() { void resetCounts() {
@ -83,8 +84,9 @@ void resetCounts() {
} }
void printCounts(const string& s) { void printCounts(const string& s) {
#ifndef DISABLE_TIMING #ifndef DISABLE_TIMING
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
% (1000 * elapsed) << endl; (1000 * elapsed)
<< endl;
#endif #endif
resetCounts(); resetCounts();
} }
@ -97,12 +99,11 @@ double add_(const double& a, const double& b) {
return a + b; return a + b;
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test ADT // test ADT
TEST(ADT, example3) TEST(ADT, example3) {
{
// Create labels // Create labels
DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
// Literals // Literals
ADT a(A, 0.5, 0.5); ADT a(A, 0.5, 0.5);
@ -114,22 +115,21 @@ TEST(ADT, example3)
ADT cnotb = c * notb; ADT cnotb = c * notb;
dot(cnotb, "ADT-cnotb"); dot(cnotb, "ADT-cnotb");
// a.print("a: "); // a.print("a: ");
// cnotb.print("cnotb: "); // cnotb.print("cnotb: ");
ADT acnotb = a * cnotb; ADT acnotb = a * cnotb;
// acnotb.print("acnotb: "); // acnotb.print("acnotb: ");
// acnotb.printCache("acnotb Cache:"); // acnotb.printCache("acnotb Cache:");
dot(acnotb, "ADT-acnotb"); dot(acnotb, "ADT-acnotb");
ADT big = apply(apply(d, note, &mul), acnotb, &add_); ADT big = apply(apply(d, note, &mul), acnotb, &add_);
dot(big, "ADT-big"); dot(big, "ADT-big");
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Asia Bayes Network // Asia Bayes Network
/* ******************************************************************************** */ /* ************************************************************************** */
/** Convert Signature into CPT */ /** Convert Signature into CPT */
ADT create(const Signature& signature) { ADT create(const Signature& signature) {
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
/* ************************************************************************* */ /* ************************************************************************* */
// test Asia Joint // test Asia Joint
TEST(ADT, joint) TEST(ADT, joint) {
{ DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); D(7, 2);
resetCounts(); resetCounts();
gttic_(asiaCPTs); gttic_(asiaCPTs);
@ -204,10 +204,9 @@ TEST(ADT, joint)
/* ************************************************************************* */ /* ************************************************************************* */
// test Inference with joint // test Inference with joint
TEST(ADT, inference) TEST(ADT, inference) {
{ DiscreteKey A(0, 2), D(1, 2), //
DiscreteKey A(0,2), D(1,2),// B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2);
resetCounts(); resetCounts();
gttic_(infCPTs); gttic_(infCPTs);
@ -244,7 +243,7 @@ TEST(ADT, inference)
dot(joint, "Joint-Product-ASTLBEX"); dot(joint, "Joint-Product-ASTLBEX");
joint = apply(joint, pD, &mul); joint = apply(joint, pD, &mul);
dot(joint, "Joint-Product-ASTLBEXD"); dot(joint, "Joint-Product-ASTLBEXD");
EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering EXPECT_LONGS_EQUAL(370, (long)muls); // different ordering
gttoc_(asiaProd); gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd); tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall(); elapsed = asiaProdNode->secs() + asiaProdNode->wall();
@ -271,9 +270,8 @@ TEST(ADT, inference)
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(ADT, factor_graph) TEST(ADT, factor_graph) {
{ DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2);
resetCounts(); resetCounts();
gttic_(createCPTs); gttic_(createCPTs);
@ -403,18 +401,19 @@ TEST(ADT, factor_graph)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_noparser) TEST(ADT, equality_noparser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
Signature::Table tableA, tableB; Signature::Table tableA, tableB;
Signature::Row rA, rB; Signature::Row rA, rB;
rA += 80, 20; rB += 60, 40; rA += 80, 20;
tableA += rA; tableB += rB; rB += 60, 40;
tableA += rA;
tableB += rB;
// Check straight equality // Check straight equality
ADT pA1 = create(A % tableA); ADT pA1 = create(A % tableA);
ADT pA2 = create(A % tableA); ADT pA2 = create(A % tableA);
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % tableB); ADT pB = create(B % tableB);
@ -425,13 +424,12 @@ TEST(ADT, equality_noparser)
/* ************************************************************************* */ /* ************************************************************************* */
// test equality // test equality
TEST(ADT, equality_parser) TEST(ADT, equality_parser) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Check straight equality // Check straight equality
ADT pA1 = create(A % "80/20"); ADT pA1 = create(A % "80/20");
ADT pA2 = create(A % "80/20"); ADT pA2 = create(A % "80/20");
EXPECT(pA1.equals(pA2)); // should be equal EXPECT(pA1.equals(pA2)); // should be equal
// Check equality after apply // Check equality after apply
ADT pB = create(B % "60/40"); ADT pB = create(B % "60/40");
@ -440,12 +438,11 @@ TEST(ADT, equality_parser)
EXPECT(pAB2.equals(pAB1)); EXPECT(pAB2.equals(pAB1));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Factor graph construction // Factor graph construction
// test constructor from strings // test constructor from strings
TEST(ADT, constructor) TEST(ADT, constructor) {
{ DiscreteKey v0(0, 2), v1(1, 3);
DiscreteKey v0(0,2), v1(1,3);
DiscreteValues x00, x01, x02, x10, x11, x12; DiscreteValues x00, x01, x02, x10, x11, x12;
x00[0] = 0, x00[1] = 0; x00[0] = 0, x00[1] = 0;
x01[0] = 0, x01[1] = 1; x01[0] = 0, x01[1] = 1;
@ -470,11 +467,10 @@ TEST(ADT, constructor)
EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9);
EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9);
DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
vector<double> table(5 * 4 * 3 * 2); vector<double> table(5 * 4 * 3 * 2);
double x = 0; double x = 0;
for(double& t: table) for (double& t : table) t = x++;
t = x++;
ADT f3(z0 & z1 & z2 & z3, table); ADT f3(z0 & z1 & z2 & z3, table);
DiscreteValues assignment; DiscreteValues assignment;
assignment[0] = 0; assignment[0] = 0;
@ -487,9 +483,8 @@ TEST(ADT, constructor)
/* ************************************************************************* */ /* ************************************************************************* */
// test conversion to integer indices // test conversion to integer indices
// Only works if DiscreteKeys are binary, as size_t has binary cardinality! // Only works if DiscreteKeys are binary, as size_t has binary cardinality!
TEST(ADT, conversion) TEST(ADT, conversion) {
{ DiscreteKey X(0, 2), Y(1, 2);
DiscreteKey X(0,2), Y(1,2);
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
dot(fDiscreteKey, "conversion-f1"); dot(fDiscreteKey, "conversion-f1");
@ -513,11 +508,10 @@ TEST(ADT, conversion)
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test operations in elimination // test operations in elimination
TEST(ADT, elimination) TEST(ADT, elimination) {
{ DiscreteKey A(0, 2), B(1, 3), C(2, 2);
DiscreteKey A(0,2), B(1,3), C(2,2);
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
dot(f1, "elimination-f1"); dot(f1, "elimination-f1");
@ -525,53 +519,51 @@ TEST(ADT, elimination)
// sum out lower key // sum out lower key
ADT actualSum = f1.sum(C); ADT actualSum = f1.sum(C);
ADT expectedSum(A & B, "3 7 11 9 6 10"); ADT expectedSum(A & B, "3 7 11 9 6 10");
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, //
1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
{ {
// sum out lower 2 keys // sum out lower 2 keys
ADT actualSum = f1.sum(C).sum(B); ADT actualSum = f1.sum(C).sum(B);
ADT expectedSum(A, 21, 25); ADT expectedSum(A, 21, 25);
CHECK(assert_equal(expectedSum,actualSum)); CHECK(assert_equal(expectedSum, actualSum));
// normalize // normalize
ADT actual = f1 / actualSum; ADT actual = f1 / actualSum;
vector<double> cpt; vector<double> cpt;
cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, //
1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25;
ADT expected(A & B & C, cpt); ADT expected(A & B & C, cpt);
CHECK(assert_equal(expected,actual)); CHECK(assert_equal(expected, actual));
} }
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test non-commutative op // Test non-commutative op
TEST(ADT, div) TEST(ADT, div) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 8, 16); ADT a(A, 8, 16);
ADT b(B, 2, 4); ADT b(B, 2, 4);
ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4
ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16
EXPECT(assert_equal(expected_a_div_b, a / b)); EXPECT(assert_equal(expected_a_div_b, a / b));
EXPECT(assert_equal(expected_b_div_a, b / a)); EXPECT(assert_equal(expected_b_div_a, b / a));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test zero shortcut // test zero shortcut
TEST(ADT, zero) TEST(ADT, zero) {
{ DiscreteKey A(0, 2), B(1, 2);
DiscreteKey A(0,2), B(1,2);
// Literals // Literals
ADT a(A, 0, 1); ADT a(A, 0, 1);

View File

@ -24,21 +24,21 @@ using namespace boost::assign;
#include <gtsam/base/Testable.h> #include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
//#define DT_DEBUG_MEMORY // #define DT_DEBUG_MEMORY
//#define DT_NO_PRUNING // #define DT_NO_PRUNING
#define DISABLE_DOT #define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h> #include <gtsam/discrete/DecisionTree-inl.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
template<typename T> template <typename T>
void dot(const T&f, const string& filename) { void dot(const T& f, const string& filename) {
#ifndef DISABLE_DOT #ifndef DISABLE_DOT
f.dot(filename); f.dot(filename);
#endif #endif
} }
#define DOT(x)(dot(x,#x)) #define DOT(x) (dot(x, #x))
struct Crazy { struct Crazy {
int a; int a;
@ -65,14 +65,15 @@ struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {}; template <>
} struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree) GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
/* ******************************************************************************** */ /* ************************************************************************** */
// Test string labels and int range // Test string labels and int range
/* ******************************************************************************** */ /* ************************************************************************** */
struct DT : public DecisionTree<string, int> { struct DT : public DecisionTree<string, int> {
using Base = DecisionTree<string, int>; using Base = DecisionTree<string, int>;
@ -98,30 +99,21 @@ struct DT : public DecisionTree<string, int> {
// traits // traits
namespace gtsam { namespace gtsam {
template<> struct traits<DT> : public Testable<DT> {}; template <>
} struct traits<DT> : public Testable<DT> {};
} // namespace gtsam
GTSAM_CONCEPT_TESTABLE_INST(DT) GTSAM_CONCEPT_TESTABLE_INST(DT)
struct Ring { struct Ring {
static inline int zero() { static inline int zero() { return 0; }
return 0; static inline int one() { return 1; }
} static inline int id(const int& a) { return a; }
static inline int one() { static inline int add(const int& a, const int& b) { return a + b; }
return 1; static inline int mul(const int& a, const int& b) { return a * b; }
}
static inline int id(const int& a) {
return a;
}
static inline int add(const int& a, const int& b) {
return a + b;
}
static inline int mul(const int& a, const int& b) {
return a * b;
}
}; };
/* ******************************************************************************** */ /* ************************************************************************** */
// test DT // test DT
TEST(DecisionTree, example) { TEST(DecisionTree, example) {
// Create labels // Create labels
@ -139,20 +131,20 @@ TEST(DecisionTree, example) {
// A // A
DT a(A, 0, 5); DT a(A, 0, 5);
LONGS_EQUAL(0,a(x00)) LONGS_EQUAL(0, a(x00))
LONGS_EQUAL(5,a(x10)) LONGS_EQUAL(5, a(x10))
DOT(a); DOT(a);
// pruned // pruned
DT p(A, 2, 2); DT p(A, 2, 2);
LONGS_EQUAL(2,p(x00)) LONGS_EQUAL(2, p(x00))
LONGS_EQUAL(2,p(x10)) LONGS_EQUAL(2, p(x10))
DOT(p); DOT(p);
// \neg B // \neg B
DT notb(B, 5, 0); DT notb(B, 5, 0);
LONGS_EQUAL(5,notb(x00)) LONGS_EQUAL(5, notb(x00))
LONGS_EQUAL(5,notb(x10)) LONGS_EQUAL(5, notb(x10))
DOT(notb); DOT(notb);
// Check supplying empty trees yields an exception // Check supplying empty trees yields an exception
@ -162,34 +154,34 @@ TEST(DecisionTree, example) {
// apply, two nodes, in natural order // apply, two nodes, in natural order
DT anotb = apply(a, notb, &Ring::mul); DT anotb = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,anotb(x00)) LONGS_EQUAL(0, anotb(x00))
LONGS_EQUAL(0,anotb(x01)) LONGS_EQUAL(0, anotb(x01))
LONGS_EQUAL(25,anotb(x10)) LONGS_EQUAL(25, anotb(x10))
LONGS_EQUAL(0,anotb(x11)) LONGS_EQUAL(0, anotb(x11))
DOT(anotb); DOT(anotb);
// check pruning // check pruning
DT pnotb = apply(p, notb, &Ring::mul); DT pnotb = apply(p, notb, &Ring::mul);
LONGS_EQUAL(10,pnotb(x00)) LONGS_EQUAL(10, pnotb(x00))
LONGS_EQUAL( 0,pnotb(x01)) LONGS_EQUAL(0, pnotb(x01))
LONGS_EQUAL(10,pnotb(x10)) LONGS_EQUAL(10, pnotb(x10))
LONGS_EQUAL( 0,pnotb(x11)) LONGS_EQUAL(0, pnotb(x11))
DOT(pnotb); DOT(pnotb);
// check pruning // check pruning
DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul);
LONGS_EQUAL(0,zeros(x00)) LONGS_EQUAL(0, zeros(x00))
LONGS_EQUAL(0,zeros(x01)) LONGS_EQUAL(0, zeros(x01))
LONGS_EQUAL(0,zeros(x10)) LONGS_EQUAL(0, zeros(x10))
LONGS_EQUAL(0,zeros(x11)) LONGS_EQUAL(0, zeros(x11))
DOT(zeros); DOT(zeros);
// apply, two nodes, in switched order // apply, two nodes, in switched order
DT notba = apply(a, notb, &Ring::mul); DT notba = apply(a, notb, &Ring::mul);
LONGS_EQUAL(0,notba(x00)) LONGS_EQUAL(0, notba(x00))
LONGS_EQUAL(0,notba(x01)) LONGS_EQUAL(0, notba(x01))
LONGS_EQUAL(25,notba(x10)) LONGS_EQUAL(25, notba(x10))
LONGS_EQUAL(0,notba(x11)) LONGS_EQUAL(0, notba(x11))
DOT(notba); DOT(notba);
// Test choose 0 // Test choose 0
@ -204,10 +196,10 @@ TEST(DecisionTree, example) {
// apply, two nodes at same level // apply, two nodes at same level
DT a_and_a = apply(a, a, &Ring::mul); DT a_and_a = apply(a, a, &Ring::mul);
LONGS_EQUAL(0,a_and_a(x00)) LONGS_EQUAL(0, a_and_a(x00))
LONGS_EQUAL(0,a_and_a(x01)) LONGS_EQUAL(0, a_and_a(x01))
LONGS_EQUAL(25,a_and_a(x10)) LONGS_EQUAL(25, a_and_a(x10))
LONGS_EQUAL(25,a_and_a(x11)) LONGS_EQUAL(25, a_and_a(x11))
DOT(a_and_a); DOT(a_and_a);
// create a function on C // create a function on C
@ -219,16 +211,16 @@ TEST(DecisionTree, example) {
// mul notba with C // mul notba with C
DT notbac = apply(notba, c, &Ring::mul); DT notbac = apply(notba, c, &Ring::mul);
LONGS_EQUAL(125,notbac(x101)) LONGS_EQUAL(125, notbac(x101))
DOT(notbac); DOT(notbac);
// mul now in different order // mul now in different order
DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul);
LONGS_EQUAL(125,acnotb(x101)) LONGS_EQUAL(125, acnotb(x101))
DOT(acnotb); DOT(acnotb);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Conversion of values // test Conversion of values
bool bool_of_int(const int& y) { return y != 0; }; bool bool_of_int(const int& y) { return y != 0; };
typedef DecisionTree<string, bool> StringBoolTree; typedef DecisionTree<string, bool> StringBoolTree;
@ -249,11 +241,9 @@ TEST(DecisionTree, ConvertValuesOnly) {
EXPECT(!f2(x00)); EXPECT(!f2(x00));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Conversion of both values and labels. // test Conversion of both values and labels.
enum Label { enum Label { U, V, X, Y, Z };
U, V, X, Y, Z
};
typedef DecisionTree<Label, bool> LabelBoolTree; typedef DecisionTree<Label, bool> LabelBoolTree;
TEST(DecisionTree, ConvertBoth) { TEST(DecisionTree, ConvertBoth) {
@ -281,7 +271,7 @@ TEST(DecisionTree, ConvertBoth) {
EXPECT(!f2(x11)); EXPECT(!f2(x11));
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// test Compose expansion // test Compose expansion
TEST(DecisionTree, Compose) { TEST(DecisionTree, Compose) {
// Create labels // Create labels
@ -292,7 +282,7 @@ TEST(DecisionTree, Compose) {
// Create from string // Create from string
vector<DT::LabelC> keys; vector<DT::LabelC> keys;
keys += DT::LabelC(A,2), DT::LabelC(B,2); keys += DT::LabelC(A, 2), DT::LabelC(B, 2);
DT f2(keys, "0 2 1 3"); DT f2(keys, "0 2 1 3");
EXPECT(assert_equal(f2, f1, 1e-9)); EXPECT(assert_equal(f2, f1, 1e-9));
@ -302,13 +292,13 @@ TEST(DecisionTree, Compose) {
DOT(f4); DOT(f4);
// a bigger tree // a bigger tree
keys += DT::LabelC(C,2); keys += DT::LabelC(C, 2);
DT f5(keys, "0 4 2 6 1 5 3 7"); DT f5(keys, "0 4 2 6 1 5 3 7");
EXPECT(assert_equal(f5, f4, 1e-9)); EXPECT(assert_equal(f5, f4, 1e-9));
DOT(f5); DOT(f5);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Check we can create a decision tree of containers. // Check we can create a decision tree of containers.
TEST(DecisionTree, Containers) { TEST(DecisionTree, Containers) {
using Container = std::vector<double>; using Container = std::vector<double>;
@ -318,7 +308,7 @@ TEST(DecisionTree, Containers) {
StringContainerTree tree; StringContainerTree tree;
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3)); DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
// Check conversion // Check conversion
@ -330,11 +320,11 @@ TEST(DecisionTree, Containers) {
StringContainerTree converted(stringIntTree, container_of_int); StringContainerTree converted(stringIntTree, container_of_int);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test visit. // Test visit.
TEST(DecisionTree, visit) { TEST(DecisionTree, visit) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0; double sum = 0.0;
auto visitor = [&](int y) { sum += y; }; auto visitor = [&](int y) { sum += y; };
@ -342,11 +332,11 @@ TEST(DecisionTree, visit) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test visit, with Choices argument. // Test visit, with Choices argument.
TEST(DecisionTree, visitWith) { TEST(DecisionTree, visitWith) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
double sum = 0.0; double sum = 0.0;
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; }; auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
@ -354,29 +344,29 @@ TEST(DecisionTree, visitWith) {
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test fold. // Test fold.
TEST(DecisionTree, fold) { TEST(DecisionTree, fold) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
auto add = [](const int& y, double x) { return y + x; }; auto add = [](const int& y, double x) { return y + x; };
double sum = tree.fold(add, 0.0); double sum = tree.fold(add, 0.0);
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test retrieving all labels. // Test retrieving all labels.
TEST(DecisionTree, labels) { TEST(DecisionTree, labels) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B");
DT tree(B, DT(A, 0, 1), DT(A, 2, 3)); DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
auto labels = tree.labels(); auto labels = tree.labels();
EXPECT_LONGS_EQUAL(2, labels.size()); EXPECT_LONGS_EQUAL(2, labels.size());
} }
/* ******************************************************************************** */ /* ************************************************************************** */
// Test retrieving all labels. // Test unzip method.
TEST(DecisionTree, unzip) { TEST(DecisionTree, unzip) {
using DTP = DecisionTree<string, std::pair<int, string>>; using DTP = DecisionTree<string, std::pair<int, string>>;
using DT1 = DecisionTree<string, int>; using DT1 = DecisionTree<string, int>;
@ -384,10 +374,8 @@ TEST(DecisionTree, unzip) {
// Create small two-level tree // Create small two-level tree
string A("A"), B("B"), C("C"); string A("A"), B("B"), C("C");
DTP tree(B, DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
DTP(A, {0, "zero"}, {1, "one"}), DTP(A, {2, "two"}, {1337, "l33t"}));
DTP(A, {2, "two"}, {1337, "l33t"})
);
DT1 dt1; DT1 dt1;
DT2 dt2; DT2 dt2;
@ -400,6 +388,29 @@ TEST(DecisionTree, unzip) {
EXPECT(tree2.equals(dt2)); EXPECT(tree2.equals(dt2));
} }
/* ************************************************************************** */
// Test thresholding.
TEST(DecisionTree, threshold) {
// Create three level tree
vector<DT::LabelC> keys;
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
DT tree(keys, "0 1 2 3 4 5 6 7");
// Check number of leaves equal to zero
auto count = [](const int& value, int count) {
return value == 0 ? count + 1 : count;
};
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
// Now threshold
auto threshold = [](int value) { return value < 5 ? 0 : value; };
DT thresholded(tree, threshold);
// Check number of leaves equal to zero now = 2
// Note: it is 2, because the pruned branches are counted as 1!
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) {
DiscreteConditional prior(B % "1/2"); DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional; DiscreteConditional pAB = prior * conditional;
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first); DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "5/4"); DiscreteConditional pA(A % "5/4");
EXPECT(assert_equal(pA, actualA)); EXPECT(assert_equal(pA, actualA));
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); EXPECT(actualA.frontals() == KeyVector{1});
EXPECT_LONGS_EQUAL(0, actualA.nrParents()); EXPECT_LONGS_EQUAL(0, actualA.nrParents());
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
EXPECT((frontalsA == KeyVector{1}));
DiscreteConditional actualB = pAB.marginal(B.first); DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB)); EXPECT(assert_equal(prior, actualB));
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); EXPECT(actualB.frontals() == KeyVector{0});
EXPECT_LONGS_EQUAL(0, actualB.nrParents()); EXPECT_LONGS_EQUAL(0, actualB.nrParents());
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); }
EXPECT((frontalsB == KeyVector{0}));
/* ************************************************************************* */
// Check calculation of marginals in case branches are pruned
TEST(DiscreteConditional, marginals2) {
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
DiscreteConditional pA(A % "8/4");
EXPECT(assert_equal(pA, actualA));
DiscreteConditional actualB = pAB.marginal(B.first);
EXPECT(assert_equal(prior, actualB));
} }
/* ************************************************************************* */ /* ************************************************************************* */