Merge pull request #1056 from borglab/feature/dt_threshold
commit
3d86bc7294
|
|
@ -18,8 +18,13 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -30,7 +35,8 @@ namespace gtsam {
|
||||||
template <typename L>
|
template <typename L>
|
||||||
class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
|
class GTSAM_EXPORT AlgebraicDecisionTree : public DecisionTree<L, double> {
|
||||||
/**
|
/**
|
||||||
* @brief Default method used by `labelFormatter` or `valueFormatter` when printing.
|
* @brief Default method used by `labelFormatter` or `valueFormatter` when
|
||||||
|
* printing.
|
||||||
*
|
*
|
||||||
* @param x The value passed to format.
|
* @param x The value passed to format.
|
||||||
* @return std::string
|
* @return std::string
|
||||||
|
|
@ -42,17 +48,12 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using Base = DecisionTree<L, double>;
|
using Base = DecisionTree<L, double>;
|
||||||
|
|
||||||
/** The Real ring with addition and multiplication */
|
/** The Real ring with addition and multiplication */
|
||||||
struct Ring {
|
struct Ring {
|
||||||
static inline double zero() {
|
static inline double zero() { return 0.0; }
|
||||||
return 0.0;
|
static inline double one() { return 1.0; }
|
||||||
}
|
|
||||||
static inline double one() {
|
|
||||||
return 1.0;
|
|
||||||
}
|
|
||||||
static inline double add(const double& a, const double& b) {
|
static inline double add(const double& a, const double& b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
@ -65,39 +66,35 @@ namespace gtsam {
|
||||||
static inline double div(const double& a, const double& b) {
|
static inline double div(const double& a, const double& b) {
|
||||||
return a / b;
|
return a / b;
|
||||||
}
|
}
|
||||||
static inline double id(const double& x) {
|
static inline double id(const double& x) { return x; }
|
||||||
return x;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
AlgebraicDecisionTree() :
|
AlgebraicDecisionTree() : Base(1.0) {}
|
||||||
Base(1.0) {
|
|
||||||
}
|
|
||||||
|
|
||||||
AlgebraicDecisionTree(const Base& add) :
|
// Explicitly non-explicit constructor
|
||||||
Base(add) {
|
AlgebraicDecisionTree(const Base& add) : Base(add) {}
|
||||||
}
|
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
AlgebraicDecisionTree(const L& label, double y1, double y2) :
|
AlgebraicDecisionTree(const L& label, double y1, double y2)
|
||||||
Base(label, y1, y2) {
|
: Base(label, y1, y2) {}
|
||||||
}
|
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1, double y2) :
|
AlgebraicDecisionTree(const typename Base::LabelC& labelC, double y1,
|
||||||
Base(labelC, y1, y2) {
|
double y2)
|
||||||
}
|
: Base(labelC, y1, y2) {}
|
||||||
|
|
||||||
/** Create from keys and vector table */
|
/** Create from keys and vector table */
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs, const std::vector<double>& ys) {
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
const std::vector<double>& ys) {
|
||||||
ys.end());
|
this->root_ =
|
||||||
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create from keys and string table */
|
/** Create from keys and string table */
|
||||||
AlgebraicDecisionTree //
|
AlgebraicDecisionTree //
|
||||||
(const std::vector<typename Base::LabelC>& labelCs, const std::string& table) {
|
(const std::vector<typename Base::LabelC>& labelCs,
|
||||||
|
const std::string& table) {
|
||||||
// Convert string to doubles
|
// Convert string to doubles
|
||||||
std::vector<double> ys;
|
std::vector<double> ys;
|
||||||
std::istringstream iss(table);
|
std::istringstream iss(table);
|
||||||
|
|
@ -105,14 +102,14 @@ namespace gtsam {
|
||||||
std::istream_iterator<double>(), std::back_inserter(ys));
|
std::istream_iterator<double>(), std::back_inserter(ys));
|
||||||
|
|
||||||
// now call recursive Create
|
// now call recursive Create
|
||||||
this->root_ = Base::create(labelCs.begin(), labelCs.end(), ys.begin(),
|
this->root_ =
|
||||||
ys.end());
|
Base::create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Create a new function splitting on a variable */
|
/** Create a new function splitting on a variable */
|
||||||
template <typename Iterator>
|
template <typename Iterator>
|
||||||
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) :
|
AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label)
|
||||||
Base(nullptr) {
|
: Base(nullptr) {
|
||||||
this->root_ = compose(begin, end, label);
|
this->root_ = compose(begin, end, label);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -177,8 +174,8 @@ namespace gtsam {
|
||||||
return Base::equals(other, compare);
|
return Base::equals(other, compare);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
// AlgebraicDecisionTree
|
|
||||||
|
|
||||||
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
|
template <typename T>
|
||||||
}
|
struct traits<AlgebraicDecisionTree<T>>
|
||||||
// namespace gtsam
|
: public Testable<AlgebraicDecisionTree<T>> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -21,42 +21,44 @@
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
|
#include <boost/make_shared.hpp>
|
||||||
#include <boost/noncopyable.hpp>
|
#include <boost/noncopyable.hpp>
|
||||||
#include <boost/optional.hpp>
|
#include <boost/optional.hpp>
|
||||||
#include <boost/tuple/tuple.hpp>
|
#include <boost/tuple/tuple.hpp>
|
||||||
#include <boost/type_traits/has_dereference.hpp>
|
#include <boost/type_traits/has_dereference.hpp>
|
||||||
#include <boost/unordered_set.hpp>
|
#include <boost/unordered_set.hpp>
|
||||||
#include <boost/make_shared.hpp>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <list>
|
#include <list>
|
||||||
|
#include <map>
|
||||||
|
#include <set>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
using boost::assign::operator+=;
|
using boost::assign::operator+=;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Node
|
// Node
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
int DecisionTree<L, Y>::Node::nrNodes = 0;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Leaf
|
// Leaf
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
class DecisionTree<L, Y>::Leaf: public DecisionTree<L, Y>::Node {
|
struct DecisionTree<L, Y>::Leaf : public DecisionTree<L, Y>::Node {
|
||||||
|
|
||||||
/** constant stored in this leaf */
|
/** constant stored in this leaf */
|
||||||
Y constant_;
|
Y constant_;
|
||||||
|
|
||||||
public:
|
|
||||||
|
|
||||||
/** Constructor from constant */
|
/** Constructor from constant */
|
||||||
Leaf(const Y& constant) :
|
Leaf(const Y& constant) :
|
||||||
constant_(constant) {}
|
constant_(constant) {}
|
||||||
|
|
@ -96,7 +98,7 @@ namespace gtsam {
|
||||||
std::string value = valueFormatter(constant_);
|
std::string value = valueFormatter(constant_);
|
||||||
if (showZero || value.compare("0"))
|
if (showZero || value.compare("0"))
|
||||||
os << "\"" << this->id() << "\" [label=\"" << value
|
os << "\"" << this->id() << "\" [label=\"" << value
|
||||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55,
|
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
/** evaluate */
|
/** evaluate */
|
||||||
|
|
@ -136,15 +138,13 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLeaf() const override { return true; }
|
bool isLeaf() const override { return true; }
|
||||||
|
|
||||||
}; // Leaf
|
}; // Leaf
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Choice
|
// Choice
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
struct DecisionTree<L, Y>::Choice: public DecisionTree<L, Y>::Node {
|
||||||
|
|
||||||
/** the label of the variable on which we split */
|
/** the label of the variable on which we split */
|
||||||
L label_;
|
L label_;
|
||||||
|
|
||||||
|
|
@ -158,10 +158,10 @@ namespace gtsam {
|
||||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
~Choice() override {
|
~Choice() override {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl;
|
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||||
|
<< std::std::endl;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -172,7 +172,8 @@ namespace gtsam {
|
||||||
assert(f->branches().size() > 0);
|
assert(f->branches().size() > 0);
|
||||||
NodePtr f0 = f->branches_[0];
|
NodePtr f0 = f->branches_[0];
|
||||||
assert(f0->isLeaf());
|
assert(f0->isLeaf());
|
||||||
NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
NodePtr newLeaf(
|
||||||
|
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
|
||||||
return newLeaf;
|
return newLeaf;
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
|
|
@ -192,7 +193,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
Choice(const Choice& f, const Choice& g, const Binary& op) :
|
||||||
allSame_(true) {
|
allSame_(true) {
|
||||||
|
|
||||||
// Choose what to do based on label
|
// Choose what to do based on label
|
||||||
if (f.label() > g.label()) {
|
if (f.label() > g.label()) {
|
||||||
// f higher than g
|
// f higher than g
|
||||||
|
|
@ -318,10 +318,8 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
Choice(const L& label, const Choice& f, const Unary& op) :
|
Choice(const L& label, const Choice& f, const Unary& op) :
|
||||||
label_(label), allSame_(true) {
|
label_(label), allSame_(true) {
|
||||||
|
|
||||||
branches_.reserve(f.branches_.size()); // reserve space
|
branches_.reserve(f.branches_.size()); // reserve space
|
||||||
for (const NodePtr& branch: f.branches_)
|
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
|
||||||
push_back(branch->apply(op));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** apply unary operator */
|
/** apply unary operator */
|
||||||
|
|
@ -364,8 +362,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/** choose a branch, recursively */
|
/** choose a branch, recursively */
|
||||||
NodePtr choose(const L& label, size_t index) const override {
|
NodePtr choose(const L& label, size_t index) const override {
|
||||||
if (label_ == label)
|
if (label_ == label) return branches_[index]; // choose branch
|
||||||
return branches_[index]; // choose branch
|
|
||||||
|
|
||||||
// second case, not label of interest, just recurse
|
// second case, not label of interest, just recurse
|
||||||
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
auto r = boost::make_shared<Choice>(label_, branches_.size());
|
||||||
|
|
@ -373,12 +370,11 @@ namespace gtsam {
|
||||||
r->push_back(branch->choose(label, index));
|
r->push_back(branch->choose(label, index));
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
}; // Choice
|
}; // Choice
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// DecisionTree
|
// DecisionTree
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree() {
|
DecisionTree<L, Y>::DecisionTree() {
|
||||||
}
|
}
|
||||||
|
|
@ -388,13 +384,13 @@ namespace gtsam {
|
||||||
root_(root) {
|
root_(root) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
DecisionTree<L, Y>::DecisionTree(const Y& y) {
|
||||||
root_ = NodePtr(new Leaf(y));
|
root_ = NodePtr(new Leaf(y));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
DecisionTree<L, Y>::DecisionTree(const L& label, const Y& y1, const Y& y2) {
|
||||||
auto a = boost::make_shared<Choice>(label, 2);
|
auto a = boost::make_shared<Choice>(label, 2);
|
||||||
|
|
@ -404,7 +400,7 @@ namespace gtsam {
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
DecisionTree<L, Y>::DecisionTree(const LabelC& labelC, const Y& y1,
|
||||||
const Y& y2) {
|
const Y& y2) {
|
||||||
|
|
@ -417,7 +413,7 @@ namespace gtsam {
|
||||||
root_ = Choice::Unique(a);
|
root_ = Choice::Unique(a);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||||
const std::vector<Y>& ys) {
|
const std::vector<Y>& ys) {
|
||||||
|
|
@ -425,11 +421,10 @@ namespace gtsam {
|
||||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
DecisionTree<L, Y>::DecisionTree(const std::vector<LabelC>& labelCs,
|
||||||
const std::string& table) {
|
const std::string& table) {
|
||||||
|
|
||||||
// Convert std::string to values of type Y
|
// Convert std::string to values of type Y
|
||||||
std::vector<Y> ys;
|
std::vector<Y> ys;
|
||||||
std::istringstream iss(table);
|
std::istringstream iss(table);
|
||||||
|
|
@ -440,14 +435,14 @@ namespace gtsam {
|
||||||
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
template<typename Iterator> DecisionTree<L, Y>::DecisionTree(
|
||||||
Iterator begin, Iterator end, const L& label) {
|
Iterator begin, Iterator end, const L& label) {
|
||||||
root_ = compose(begin, end, label);
|
root_ = compose(begin, end, label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y>::DecisionTree(const L& label,
|
DecisionTree<L, Y>::DecisionTree(const L& label,
|
||||||
const DecisionTree& f0, const DecisionTree& f1) {
|
const DecisionTree& f0, const DecisionTree& f1) {
|
||||||
|
|
@ -456,7 +451,7 @@ namespace gtsam {
|
||||||
root_ = compose(functions.begin(), functions.end(), label);
|
root_ = compose(functions.begin(), functions.end(), label);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename X, typename Func>
|
template <typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<L, X>& other,
|
||||||
|
|
@ -466,7 +461,7 @@ namespace gtsam {
|
||||||
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
root_ = convertFrom<L, X>(other.root_, L_of_L, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X, typename Func>
|
template <typename M, typename X, typename Func>
|
||||||
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
DecisionTree<L, Y>::DecisionTree(const DecisionTree<M, X>& other,
|
||||||
|
|
@ -475,16 +470,16 @@ namespace gtsam {
|
||||||
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
root_ = convertFrom<M, X>(other.root_, L_of_M, Y_of_X);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Called by two constructors above.
|
// Called by two constructors above.
|
||||||
// Takes a label and a corresponding range of decision trees, and creates a new
|
// Takes a label and a corresponding range of decision trees, and creates a
|
||||||
// decision tree. However, the order of the labels needs to be respected, so we
|
// new decision tree. However, the order of the labels needs to be respected,
|
||||||
// cannot just create a root Choice node on the label: if the label is not the
|
// so we cannot just create a root Choice node on the label: if the label is
|
||||||
// highest label, we need to do a complicated and expensive recursive call.
|
// not the highest label, we need a complicated/ expensive recursive call.
|
||||||
template<typename L, typename Y> template<typename Iterator>
|
template <typename L, typename Y>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(Iterator begin,
|
template <typename Iterator>
|
||||||
Iterator end, const L& label) const {
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
||||||
|
Iterator begin, Iterator end, const L& label) const {
|
||||||
// find highest label among branches
|
// find highest label among branches
|
||||||
boost::optional<L> highestLabel;
|
boost::optional<L> highestLabel;
|
||||||
size_t nrChoices = 0;
|
size_t nrChoices = 0;
|
||||||
|
|
@ -527,7 +522,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// "create" is a bit of a complicated thing, but very useful.
|
// "create" is a bit of a complicated thing, but very useful.
|
||||||
// It takes a range of labels and a corresponding range of values,
|
// It takes a range of labels and a corresponding range of values,
|
||||||
// and creates a decision tree, as follows:
|
// and creates a decision tree, as follows:
|
||||||
|
|
@ -552,7 +547,6 @@ namespace gtsam {
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::create(
|
||||||
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
It begin, It end, ValueIt beginY, ValueIt endY) const {
|
||||||
|
|
||||||
// get crucial counts
|
// get crucial counts
|
||||||
size_t nrChoices = begin->second;
|
size_t nrChoices = begin->second;
|
||||||
size_t size = endY - beginY;
|
size_t size = endY - beginY;
|
||||||
|
|
@ -564,7 +558,11 @@ namespace gtsam {
|
||||||
// Create a simple choice node with values as leaves.
|
// Create a simple choice node with values as leaves.
|
||||||
if (size != nrChoices) {
|
if (size != nrChoices) {
|
||||||
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
std::cout << "Trying to create DD on " << begin->first << std::endl;
|
||||||
std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
|
std::cout << boost::format(
|
||||||
|
"DecisionTree::create: expected %d values but got %d "
|
||||||
|
"instead") %
|
||||||
|
nrChoices % size
|
||||||
|
<< std::endl;
|
||||||
throw std::invalid_argument("DecisionTree::create invalid argument");
|
throw std::invalid_argument("DecisionTree::create invalid argument");
|
||||||
}
|
}
|
||||||
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
auto choice = boost::make_shared<Choice>(begin->first, endY - beginY);
|
||||||
|
|
@ -585,7 +583,7 @@ namespace gtsam {
|
||||||
return compose(functions.begin(), functions.end(), begin->first);
|
return compose(functions.begin(), functions.end(), begin->first);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
|
|
@ -594,8 +592,8 @@ namespace gtsam {
|
||||||
std::function<Y(const X&)> Y_of_X) const {
|
std::function<Y(const X&)> Y_of_X) const {
|
||||||
using LY = DecisionTree<L, Y>;
|
using LY = DecisionTree<L, Y>;
|
||||||
|
|
||||||
// ugliness below because apparently we can't have templated virtual functions
|
// ugliness below because apparently we can't have templated virtual
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf
|
// functions If leaf, apply unary conversion "op" and create a unique leaf
|
||||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
if (auto leaf = boost::dynamic_pointer_cast<const MXLeaf>(f))
|
||||||
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
return NodePtr(new Leaf(Y_of_X(leaf->constant())));
|
||||||
|
|
@ -618,12 +616,12 @@ namespace gtsam {
|
||||||
return LY::compose(functions.begin(), functions.end(), newLabel);
|
return LY::compose(functions.begin(), functions.end(), newLabel);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Functor performing depth-first visit without Assignment<L> argument.
|
// Functor performing depth-first visit without Assignment<L> argument.
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
struct Visit {
|
struct Visit {
|
||||||
using F = std::function<void(const Y&)>;
|
using F = std::function<void(const Y&)>;
|
||||||
Visit(F f) : f(f) {} ///< Construct from folding function.
|
explicit Visit(F f) : f(f) {} ///< Construct from folding function.
|
||||||
F f; ///< folding function object.
|
F f; ///< folding function object.
|
||||||
|
|
||||||
/// Do a depth-first visit on the tree rooted at node.
|
/// Do a depth-first visit on the tree rooted at node.
|
||||||
|
|
@ -647,13 +645,13 @@ namespace gtsam {
|
||||||
visit(root_);
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// Functor performing depth-first visit with Assignment<L> argument.
|
// Functor performing depth-first visit with Assignment<L> argument.
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
struct VisitWith {
|
struct VisitWith {
|
||||||
using Choices = Assignment<L>;
|
using Choices = Assignment<L>;
|
||||||
using F = std::function<void(const Choices&, const Y&)>;
|
using F = std::function<void(const Choices&, const Y&)>;
|
||||||
VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
explicit VisitWith(F f) : f(f) {} ///< Construct from folding function.
|
||||||
Choices choices; ///< Assignment, mutating through recursion.
|
Choices choices; ///< Assignment, mutating through recursion.
|
||||||
F f; ///< folding function object.
|
F f; ///< folding function object.
|
||||||
|
|
||||||
|
|
@ -681,7 +679,7 @@ namespace gtsam {
|
||||||
visit(root_);
|
visit(root_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// fold is just done with a visit
|
// fold is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename Func, typename X>
|
template <typename Func, typename X>
|
||||||
|
|
@ -690,7 +688,7 @@ namespace gtsam {
|
||||||
return x0;
|
return x0;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// labels is just done with a visit
|
// labels is just done with a visit
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
std::set<L> DecisionTree<L, Y>::labels() const {
|
std::set<L> DecisionTree<L, Y>::labels() const {
|
||||||
|
|
@ -702,7 +700,7 @@ namespace gtsam {
|
||||||
return unique;
|
return unique;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
bool DecisionTree<L, Y>::equals(const DecisionTree& other,
|
||||||
const CompareFunc& compare) const {
|
const CompareFunc& compare) const {
|
||||||
|
|
@ -736,7 +734,7 @@ namespace gtsam {
|
||||||
return DecisionTree(root_->apply(op));
|
return DecisionTree(root_->apply(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
|
||||||
const Binary& op) const {
|
const Binary& op) const {
|
||||||
|
|
@ -752,7 +750,7 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
// The way this works:
|
// The way this works:
|
||||||
// We have an ADT, picture it as a tree.
|
// We have an ADT, picture it as a tree.
|
||||||
// At a certain depth, we have a branch on "label".
|
// At a certain depth, we have a branch on "label".
|
||||||
|
|
@ -772,7 +770,7 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/****************************************************************************/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
void DecisionTree<L, Y>::dot(std::ostream& os,
|
void DecisionTree<L, Y>::dot(std::ostream& os,
|
||||||
const LabelFormatter& labelFormatter,
|
const LabelFormatter& labelFormatter,
|
||||||
|
|
@ -790,9 +788,11 @@ namespace gtsam {
|
||||||
bool showZero) const {
|
bool showZero) const {
|
||||||
std::ofstream os((name + ".dot").c_str());
|
std::ofstream os((name + ".dot").c_str());
|
||||||
dot(os, labelFormatter, valueFormatter, showZero);
|
dot(os, labelFormatter, valueFormatter, showZero);
|
||||||
int result = system(
|
int result =
|
||||||
("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str());
|
system(("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null")
|
||||||
if (result==-1) throw std::runtime_error("DecisionTree::dot system call failed");
|
.c_str());
|
||||||
|
if (result == -1)
|
||||||
|
throw std::runtime_error("DecisionTree::dot system call failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
|
|
@ -804,8 +804,6 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
/*********************************************************************************/
|
/******************************************************************************/
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,9 +26,11 @@
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <sstream>
|
|
||||||
#include <vector>
|
|
||||||
#include <set>
|
#include <set>
|
||||||
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -39,7 +41,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
template<typename L, typename Y>
|
template<typename L, typename Y>
|
||||||
class DecisionTree {
|
class DecisionTree {
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// Default method for comparison of two objects of type Y.
|
/// Default method for comparison of two objects of type Y.
|
||||||
static bool DefaultCompare(const Y& a, const Y& b) {
|
static bool DefaultCompare(const Y& a, const Y& b) {
|
||||||
|
|
@ -47,7 +48,6 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
using LabelFormatter = std::function<std::string(L)>;
|
using LabelFormatter = std::function<std::string(L)>;
|
||||||
using ValueFormatter = std::function<std::string(Y)>;
|
using ValueFormatter = std::function<std::string(Y)>;
|
||||||
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
using CompareFunc = std::function<bool(const Y&, const Y&)>;
|
||||||
|
|
@ -60,12 +60,11 @@ namespace gtsam {
|
||||||
using LabelC = std::pair<L, size_t>;
|
using LabelC = std::pair<L, size_t>;
|
||||||
|
|
||||||
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
/** DTs consist of Leaf and Choice nodes, both subclasses of Node */
|
||||||
class Leaf;
|
struct Leaf;
|
||||||
class Choice;
|
struct Choice;
|
||||||
|
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
class Node {
|
struct Node {
|
||||||
public:
|
|
||||||
using Ptr = boost::shared_ptr<const Node>;
|
using Ptr = boost::shared_ptr<const Node>;
|
||||||
|
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
|
|
@ -75,14 +74,16 @@ namespace gtsam {
|
||||||
// Constructor
|
// Constructor
|
||||||
Node() {
|
Node() {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush();
|
std::cout << ++nrNodes << " constructed " << id() << std::endl;
|
||||||
|
std::cout.flush();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Destructor
|
// Destructor
|
||||||
virtual ~Node() {
|
virtual ~Node() {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush();
|
std::cout << --nrNodes << " destructed " << id() << std::endl;
|
||||||
|
std::cout.flush();
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -111,7 +112,6 @@ namespace gtsam {
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/** A function is a shared pointer to the root of a DT */
|
/** A function is a shared pointer to the root of a DT */
|
||||||
using NodePtr = typename Node::Ptr;
|
using NodePtr = typename Node::Ptr;
|
||||||
|
|
||||||
|
|
@ -119,8 +119,9 @@ namespace gtsam {
|
||||||
NodePtr root_;
|
NodePtr root_;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
/** Internal recursive function to create from keys, cardinalities,
|
||||||
/** Internal recursive function to create from keys, cardinalities, and Y values */
|
* and Y values
|
||||||
|
*/
|
||||||
template<typename It, typename ValueIt>
|
template<typename It, typename ValueIt>
|
||||||
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const;
|
||||||
|
|
||||||
|
|
@ -140,7 +141,6 @@ namespace gtsam {
|
||||||
std::function<Y(const X&)> Y_of_X) const;
|
std::function<Y(const X&)> Y_of_X) const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
@ -148,7 +148,7 @@ namespace gtsam {
|
||||||
DecisionTree();
|
DecisionTree();
|
||||||
|
|
||||||
/** Create a constant */
|
/** Create a constant */
|
||||||
DecisionTree(const Y& y);
|
explicit DecisionTree(const Y& y);
|
||||||
|
|
||||||
/** Create a new leaf function splitting on a variable */
|
/** Create a new leaf function splitting on a variable */
|
||||||
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
DecisionTree(const L& label, const Y& y1, const Y& y2);
|
||||||
|
|
@ -167,8 +167,8 @@ namespace gtsam {
|
||||||
DecisionTree(Iterator begin, Iterator end, const L& label);
|
DecisionTree(Iterator begin, Iterator end, const L& label);
|
||||||
|
|
||||||
/** Create DecisionTree from two others */
|
/** Create DecisionTree from two others */
|
||||||
DecisionTree(const L& label, //
|
DecisionTree(const L& label, const DecisionTree& f0,
|
||||||
const DecisionTree& f0, const DecisionTree& f1);
|
const DecisionTree& f1);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert from a different value type.
|
* @brief Convert from a different value type.
|
||||||
|
|
@ -234,6 +234,8 @@ namespace gtsam {
|
||||||
*
|
*
|
||||||
* @param f side-effect taking a value.
|
* @param f side-effect taking a value.
|
||||||
*
|
*
|
||||||
|
* @note Due to pruning, leaves might not exhaust choices.
|
||||||
|
*
|
||||||
* Example:
|
* Example:
|
||||||
* int sum = 0;
|
* int sum = 0;
|
||||||
* auto visitor = [&](int y) { sum += y; };
|
* auto visitor = [&](int y) { sum += y; };
|
||||||
|
|
@ -247,6 +249,8 @@ namespace gtsam {
|
||||||
*
|
*
|
||||||
* @param f side-effect taking an assignment and a value.
|
* @param f side-effect taking an assignment and a value.
|
||||||
*
|
*
|
||||||
|
* @note Due to pruning, leaves might not exhaust choices.
|
||||||
|
*
|
||||||
* Example:
|
* Example:
|
||||||
* int sum = 0;
|
* int sum = 0;
|
||||||
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
* auto visitor = [&](const Assignment<L>& choices, int y) { sum += y; };
|
||||||
|
|
@ -264,6 +268,7 @@ namespace gtsam {
|
||||||
* @return X final value for accumulator.
|
* @return X final value for accumulator.
|
||||||
*
|
*
|
||||||
* @note X is always passed by value.
|
* @note X is always passed by value.
|
||||||
|
* @note Due to pruning, leaves might not exhaust choices.
|
||||||
*
|
*
|
||||||
* Example:
|
* Example:
|
||||||
* auto add = [](const double& y, double x) { return y + x; };
|
* auto add = [](const double& y, double x) { return y + x; };
|
||||||
|
|
@ -289,7 +294,8 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** combine subtrees on key with binary operation "op" */
|
/** combine subtrees on key with binary operation "op" */
|
||||||
DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const;
|
DecisionTree combine(const L& label, size_t cardinality,
|
||||||
|
const Binary& op) const;
|
||||||
|
|
||||||
/** combine with LabelC for convenience */
|
/** combine with LabelC for convenience */
|
||||||
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
DecisionTree combine(const LabelC& labelC, const Binary& op) const {
|
||||||
|
|
@ -313,14 +319,13 @@ namespace gtsam {
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
DecisionTree(const NodePtr& root);
|
explicit DecisionTree(const NodePtr& root);
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
template<typename Iterator> NodePtr
|
template<typename Iterator> NodePtr
|
||||||
compose(Iterator begin, Iterator end, const L& label) const;
|
compose(Iterator begin, Iterator end, const L& label) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
}; // DecisionTree
|
}; // DecisionTree
|
||||||
|
|
||||||
/** free versions of apply */
|
/** free versions of apply */
|
||||||
|
|
@ -340,11 +345,19 @@ namespace gtsam {
|
||||||
return f.apply(g, op);
|
return f.apply(g, op);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// unzip a DecisionTree if its leaves are `std::pair`
|
/**
|
||||||
|
* @brief unzip a DecisionTree with `std::pair` values.
|
||||||
|
*
|
||||||
|
* @param input the DecisionTree with `(T1,T2)` values.
|
||||||
|
* @return a pair of DecisionTree on T1 and T2, respectively.
|
||||||
|
*/
|
||||||
template <typename L, typename T1, typename T2>
|
template <typename L, typename T1, typename T2>
|
||||||
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(const DecisionTree<L, std::pair<T1, T2> > &input) {
|
std::pair<DecisionTree<L, T1>, DecisionTree<L, T2> > unzip(
|
||||||
return std::make_pair(DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
|
const DecisionTree<L, std::pair<T1, T2> >& input) {
|
||||||
DecisionTree<L, T2>(input, [](std::pair<T1, T2> i) { return i.second; }));
|
return std::make_pair(
|
||||||
|
DecisionTree<L, T1>(input, [](std::pair<T1, T2> i) { return i.first; }),
|
||||||
|
DecisionTree<L, T2>(input,
|
||||||
|
[](std::pair<T1, T2> i) { return i.second; }));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -17,9 +17,9 @@
|
||||||
* @author Frank Dellaert
|
* @author Frank Dellaert
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/FastSet.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/base/FastSet.h>
|
|
||||||
|
|
||||||
#include <boost/make_shared.hpp>
|
#include <boost/make_shared.hpp>
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
|
|
@ -29,34 +29,34 @@ using namespace std;
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor() {
|
DecisionTreeFactor::DecisionTreeFactor() {}
|
||||||
}
|
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
|
||||||
const ADT& potentials) :
|
|
||||||
DiscreteFactor(keys.indices()), ADT(potentials),
|
|
||||||
cardinalities_(keys.cardinalities()) {
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
|
const ADT& potentials)
|
||||||
}
|
: DiscreteFactor(keys.indices()),
|
||||||
|
ADT(potentials),
|
||||||
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
bool DecisionTreeFactor::equals(const DiscreteFactor& other, double tol) const {
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c)
|
||||||
|
: DiscreteFactor(c.keys()),
|
||||||
|
AlgebraicDecisionTree<Key>(c),
|
||||||
|
cardinalities_(c.cardinalities_) {}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
|
bool DecisionTreeFactor::equals(const DiscreteFactor& other,
|
||||||
|
double tol) const {
|
||||||
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
const auto& f(static_cast<const DecisionTreeFactor&>(other));
|
||||||
return ADT::equals(f, tol);
|
return ADT::equals(f, tol);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
double DecisionTreeFactor::safe_div(const double& a, const double& b) {
|
||||||
// The use for safe_div is when we divide the product factor by the sum
|
// The use for safe_div is when we divide the product factor by the sum
|
||||||
// factor. If the product or sum is zero, we accord zero probability to the
|
// factor. If the product or sum is zero, we accord zero probability to the
|
||||||
|
|
@ -64,7 +64,7 @@ namespace gtsam {
|
||||||
return (a == 0 || b == 0) ? 0 : (a / b);
|
return (a == 0 || b == 0) ? 0 : (a / b);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
void DecisionTreeFactor::print(const string& s,
|
void DecisionTreeFactor::print(const string& s,
|
||||||
const KeyFormatter& formatter) const {
|
const KeyFormatter& formatter) const {
|
||||||
cout << s;
|
cout << s;
|
||||||
|
|
@ -75,7 +75,7 @@ namespace gtsam {
|
||||||
ADT::print("", formatter);
|
ADT::print("", formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
DecisionTreeFactor DecisionTreeFactor::apply(const DecisionTreeFactor& f,
|
||||||
ADT::Binary op) const {
|
ADT::Binary op) const {
|
||||||
map<Key, size_t> cs; // new cardinalities
|
map<Key, size_t> cs; // new cardinalities
|
||||||
|
|
@ -84,22 +84,23 @@ namespace gtsam {
|
||||||
for (Key j : f.keys()) cs[j] = f.cardinality(j);
|
for (Key j : f.keys()) cs[j] = f.cardinality(j);
|
||||||
// Convert map into keys
|
// Convert map into keys
|
||||||
DiscreteKeys keys;
|
DiscreteKeys keys;
|
||||||
for(const std::pair<const Key,size_t>& key: cs)
|
for (const std::pair<const Key, size_t>& key : cs) keys.push_back(key);
|
||||||
keys.push_back(key);
|
|
||||||
// apply operand
|
// apply operand
|
||||||
ADT result = ADT::apply(f, op);
|
ADT result = ADT::apply(f, op);
|
||||||
// Make a new factor
|
// Make a new factor
|
||||||
return DecisionTreeFactor(keys, result);
|
return DecisionTreeFactor(keys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(size_t nrFrontals,
|
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||||
ADT::Binary op) const {
|
size_t nrFrontals, ADT::Binary op) const {
|
||||||
|
if (nrFrontals > size())
|
||||||
if (nrFrontals > size()) throw invalid_argument(
|
throw invalid_argument(
|
||||||
(boost::format(
|
(boost::format(
|
||||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||||
% nrFrontals % size()).str());
|
"keys %d, nr.keys=%d") %
|
||||||
|
nrFrontals % size())
|
||||||
|
.str());
|
||||||
|
|
||||||
// sum over nrFrontals keys
|
// sum over nrFrontals keys
|
||||||
size_t i;
|
size_t i;
|
||||||
|
|
@ -118,15 +119,16 @@ namespace gtsam {
|
||||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
/* ************************************************************************* */
|
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
|
||||||
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(const Ordering& frontalKeys,
|
const Ordering& frontalKeys, ADT::Binary op) const {
|
||||||
ADT::Binary op) const {
|
if (frontalKeys.size() > size())
|
||||||
|
throw invalid_argument(
|
||||||
if (frontalKeys.size() > size()) throw invalid_argument(
|
|
||||||
(boost::format(
|
(boost::format(
|
||||||
"DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d")
|
"DecisionTreeFactor::combine: invalid number of frontal "
|
||||||
% frontalKeys.size() % size()).str());
|
"keys %d, nr.keys=%d") %
|
||||||
|
frontalKeys.size() % size())
|
||||||
|
.str());
|
||||||
|
|
||||||
// sum over nrFrontals keys
|
// sum over nrFrontals keys
|
||||||
size_t i;
|
size_t i;
|
||||||
|
|
@ -137,20 +139,22 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
// create new factor, note we collect keys that are not in frontalKeys
|
// create new factor, note we collect keys that are not in frontalKeys
|
||||||
// TODO: why do we need this??? result should contain correct keys!!!
|
// TODO(frank): why do we need this??? result should contain correct keys!!!
|
||||||
DiscreteKeys dkeys;
|
DiscreteKeys dkeys;
|
||||||
for (i = 0; i < keys().size(); i++) {
|
for (i = 0; i < keys().size(); i++) {
|
||||||
Key j = keys()[i];
|
Key j = keys()[i];
|
||||||
// TODO: inefficient!
|
// TODO(frank): inefficient!
|
||||||
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) != frontalKeys.end())
|
if (std::find(frontalKeys.begin(), frontalKeys.end(), j) !=
|
||||||
|
frontalKeys.end())
|
||||||
continue;
|
continue;
|
||||||
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||||
}
|
}
|
||||||
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
return boost::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate() const {
|
std::vector<std::pair<DiscreteValues, double>> DecisionTreeFactor::enumerate()
|
||||||
|
const {
|
||||||
// Get all possible assignments
|
// Get all possible assignments
|
||||||
std::vector<std::pair<Key, size_t>> pairs;
|
std::vector<std::pair<Key, size_t>> pairs;
|
||||||
for (auto& key : keys()) {
|
for (auto& key : keys()) {
|
||||||
|
|
@ -168,7 +172,7 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
DiscreteKeys DecisionTreeFactor::discreteKeys() const {
|
||||||
DiscreteKeys result;
|
DiscreteKeys result;
|
||||||
for (auto&& key : keys()) {
|
for (auto&& key : keys()) {
|
||||||
|
|
@ -180,7 +184,7 @@ namespace gtsam {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
static std::string valueFormatter(const double& v) {
|
static std::string valueFormatter(const double& v) {
|
||||||
return (boost::format("%4.2g") % v).str();
|
return (boost::format("%4.2g") % v).str();
|
||||||
}
|
}
|
||||||
|
|
@ -206,7 +210,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Print out header.
|
// Print out header.
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
string DecisionTreeFactor::markdown(const KeyFormatter& keyFormatter,
|
||||||
const Names& names) const {
|
const Names& names) const {
|
||||||
stringstream ss;
|
stringstream ss;
|
||||||
|
|
@ -271,17 +275,19 @@ namespace gtsam {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
|
const vector<double>& table)
|
||||||
cardinalities_(keys.cardinalities()) {
|
: DiscreteFactor(keys.indices()),
|
||||||
}
|
AlgebraicDecisionTree<Key>(keys, table),
|
||||||
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
|
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
|
const string& table)
|
||||||
cardinalities_(keys.cardinalities()) {
|
: DiscreteFactor(keys.indices()),
|
||||||
}
|
AlgebraicDecisionTree<Key>(keys, table),
|
||||||
|
cardinalities_(keys.cardinalities()) {}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************ */
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -18,16 +18,18 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteFactor.h>
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
|
||||||
#include <gtsam/inference/Ordering.h>
|
#include <gtsam/inference/Ordering.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
|
#include <map>
|
||||||
#include <vector>
|
|
||||||
#include <exception>
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -36,10 +38,9 @@ namespace gtsam {
|
||||||
/**
|
/**
|
||||||
* A discrete probabilistic factor
|
* A discrete probabilistic factor
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {
|
class GTSAM_EXPORT DecisionTreeFactor : public DiscreteFactor,
|
||||||
|
public AlgebraicDecisionTree<Key> {
|
||||||
public:
|
public:
|
||||||
|
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
typedef DecisionTreeFactor This;
|
typedef DecisionTreeFactor This;
|
||||||
typedef DiscreteFactor Base; ///< Typedef to base class
|
typedef DiscreteFactor Base; ///< Typedef to base class
|
||||||
|
|
@ -50,7 +51,6 @@ namespace gtsam {
|
||||||
std::map<Key, size_t> cardinalities_;
|
std::map<Key, size_t> cardinalities_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
@ -61,7 +61,8 @@ namespace gtsam {
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
|
||||||
|
|
||||||
/** Constructor from doubles */
|
/** Constructor from doubles */
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);
|
DecisionTreeFactor(const DiscreteKeys& keys,
|
||||||
|
const std::vector<double>& table);
|
||||||
|
|
||||||
/** Constructor from string */
|
/** Constructor from string */
|
||||||
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
|
||||||
|
|
@ -86,7 +87,8 @@ namespace gtsam {
|
||||||
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||||
|
|
||||||
// print
|
// print
|
||||||
void print(const std::string& s = "DecisionTreeFactor:\n",
|
void print(
|
||||||
|
const std::string& s = "DecisionTreeFactor:\n",
|
||||||
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
@ -113,9 +115,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert into a decisiontree
|
/// Convert into a decisiontree
|
||||||
DecisionTreeFactor toDecisionTreeFactor() const override {
|
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create new factor by summing all values with the same separator values
|
/// Create new factor by summing all values with the same separator values
|
||||||
shared_ptr sum(size_t nrFrontals) const {
|
shared_ptr sum(size_t nrFrontals) const {
|
||||||
|
|
@ -164,27 +164,6 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
shared_ptr combine(const Ordering& keys, ADT::Binary op) const;
|
||||||
|
|
||||||
|
|
||||||
// /**
|
|
||||||
// * @brief Permutes the keys in Potentials and DiscreteFactor
|
|
||||||
// *
|
|
||||||
// * This re-implements the permuteWithInverse() in both Potentials
|
|
||||||
// * and DiscreteFactor by doing both of them together.
|
|
||||||
// */
|
|
||||||
//
|
|
||||||
// void permuteWithInverse(const Permutation& inversePermutation){
|
|
||||||
// DiscreteFactor::permuteWithInverse(inversePermutation);
|
|
||||||
// Potentials::permuteWithInverse(inversePermutation);
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// /**
|
|
||||||
// * Apply a reduction, which is a remapping of variable indices.
|
|
||||||
// */
|
|
||||||
// virtual void reduceWithInverse(const internal::Reduction& inverseReduction) {
|
|
||||||
// DiscreteFactor::reduceWithInverse(inverseReduction);
|
|
||||||
// Potentials::reduceWithInverse(inverseReduction);
|
|
||||||
// }
|
|
||||||
|
|
||||||
/// Enumerate all values into a map from values to double.
|
/// Enumerate all values into a map from values to double.
|
||||||
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
std::vector<std::pair<DiscreteValues, double>> enumerate() const;
|
||||||
|
|
||||||
|
|
@ -230,11 +209,10 @@ namespace gtsam {
|
||||||
const Names& names = {}) const override;
|
const Names& names = {}) const override;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
};
|
};
|
||||||
// DecisionTreeFactor
|
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template<> struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
template <>
|
||||||
|
struct traits<DecisionTreeFactor> : public Testable<DecisionTreeFactor> {};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -25,25 +25,26 @@
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||||
#define DISABLE_TIMING
|
#define DISABLE_TIMING
|
||||||
|
|
||||||
#include <boost/tokenizer.hpp>
|
|
||||||
#include <boost/assign/std/map.hpp>
|
#include <boost/assign/std/map.hpp>
|
||||||
#include <boost/assign/std/vector.hpp>
|
#include <boost/assign/std/vector.hpp>
|
||||||
|
#include <boost/tokenizer.hpp>
|
||||||
using namespace boost::assign;
|
using namespace boost::assign;
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
|
||||||
#include <gtsam/base/timing.h>
|
#include <gtsam/base/timing.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
typedef AlgebraicDecisionTree<Key> ADT;
|
typedef AlgebraicDecisionTree<Key> ADT;
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<ADT> : public Testable<ADT> {};
|
template <>
|
||||||
}
|
struct traits<ADT> : public Testable<ADT> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
|
|
||||||
|
|
@ -63,8 +64,8 @@ void dot(const T&f, const string& filename) {
|
||||||
|
|
||||||
// If second argument of binary op is Leaf
|
// If second argument of binary op is Leaf
|
||||||
template<typename L>
|
template<typename L>
|
||||||
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L, double>::Choice::apply_fC_op_gL(
|
typename DecisionTree<L, double>::Node::Ptr DecisionTree<L,
|
||||||
Cache& cache, const Leaf& gL, Mul op) const {
|
double>::Choice::apply_fC_op_gL( Cache& cache, const Leaf& gL, Mul op) const {
|
||||||
Ptr h(new Choice(label(), cardinality()));
|
Ptr h(new Choice(label(), cardinality()));
|
||||||
for(const NodePtr& branch: branches_)
|
for(const NodePtr& branch: branches_)
|
||||||
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
h->push_back(branch->apply_f_op_g(cache, gL, op));
|
||||||
|
|
@ -72,9 +73,9 @@ void dot(const T&f, const string& filename) {
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// instrumented operators
|
// instrumented operators
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t muls = 0, adds = 0;
|
size_t muls = 0, adds = 0;
|
||||||
double elapsed;
|
double elapsed;
|
||||||
void resetCounts() {
|
void resetCounts() {
|
||||||
|
|
@ -83,8 +84,9 @@ void resetCounts() {
|
||||||
}
|
}
|
||||||
void printCounts(const string& s) {
|
void printCounts(const string& s) {
|
||||||
#ifndef DISABLE_TIMING
|
#ifndef DISABLE_TIMING
|
||||||
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds
|
cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds %
|
||||||
% (1000 * elapsed) << endl;
|
(1000 * elapsed)
|
||||||
|
<< endl;
|
||||||
#endif
|
#endif
|
||||||
resetCounts();
|
resetCounts();
|
||||||
}
|
}
|
||||||
|
|
@ -97,10 +99,9 @@ double add_(const double& a, const double& b) {
|
||||||
return a + b;
|
return a + b;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test ADT
|
// test ADT
|
||||||
TEST(ADT, example3)
|
TEST(ADT, example3) {
|
||||||
{
|
|
||||||
// Create labels
|
// Create labels
|
||||||
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
|
DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(3, 2), E(4, 2);
|
||||||
|
|
||||||
|
|
@ -122,14 +123,13 @@ TEST(ADT, example3)
|
||||||
|
|
||||||
dot(acnotb, "ADT-acnotb");
|
dot(acnotb, "ADT-acnotb");
|
||||||
|
|
||||||
|
|
||||||
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
ADT big = apply(apply(d, note, &mul), acnotb, &add_);
|
||||||
dot(big, "ADT-big");
|
dot(big, "ADT-big");
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Asia Bayes Network
|
// Asia Bayes Network
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
||||||
/** Convert Signature into CPT */
|
/** Convert Signature into CPT */
|
||||||
ADT create(const Signature& signature) {
|
ADT create(const Signature& signature) {
|
||||||
|
|
@ -143,9 +143,9 @@ ADT create(const Signature& signature) {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Asia Joint
|
// test Asia Joint
|
||||||
TEST(ADT, joint)
|
TEST(ADT, joint) {
|
||||||
{
|
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2);
|
D(7, 2);
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(asiaCPTs);
|
gttic_(asiaCPTs);
|
||||||
|
|
@ -204,8 +204,7 @@ TEST(ADT, joint)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Inference with joint
|
// test Inference with joint
|
||||||
TEST(ADT, inference)
|
TEST(ADT, inference) {
|
||||||
{
|
|
||||||
DiscreteKey A(0, 2), D(1, 2), //
|
DiscreteKey A(0, 2), D(1, 2), //
|
||||||
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
||||||
|
|
||||||
|
|
@ -271,8 +270,7 @@ TEST(ADT, inference)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(ADT, factor_graph)
|
TEST(ADT, factor_graph) {
|
||||||
{
|
|
||||||
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
|
|
@ -403,13 +401,14 @@ TEST(ADT, factor_graph)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test equality
|
// test equality
|
||||||
TEST(ADT, equality_noparser)
|
TEST(ADT, equality_noparser) {
|
||||||
{
|
|
||||||
DiscreteKey A(0, 2), B(1, 2);
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
Signature::Table tableA, tableB;
|
Signature::Table tableA, tableB;
|
||||||
Signature::Row rA, rB;
|
Signature::Row rA, rB;
|
||||||
rA += 80, 20; rB += 60, 40;
|
rA += 80, 20;
|
||||||
tableA += rA; tableB += rB;
|
rB += 60, 40;
|
||||||
|
tableA += rA;
|
||||||
|
tableB += rB;
|
||||||
|
|
||||||
// Check straight equality
|
// Check straight equality
|
||||||
ADT pA1 = create(A % tableA);
|
ADT pA1 = create(A % tableA);
|
||||||
|
|
@ -425,8 +424,7 @@ TEST(ADT, equality_noparser)
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test equality
|
// test equality
|
||||||
TEST(ADT, equality_parser)
|
TEST(ADT, equality_parser) {
|
||||||
{
|
|
||||||
DiscreteKey A(0, 2), B(1, 2);
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
// Check straight equality
|
// Check straight equality
|
||||||
ADT pA1 = create(A % "80/20");
|
ADT pA1 = create(A % "80/20");
|
||||||
|
|
@ -440,11 +438,10 @@ TEST(ADT, equality_parser)
|
||||||
EXPECT(pAB2.equals(pAB1));
|
EXPECT(pAB2.equals(pAB1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Factor graph construction
|
// Factor graph construction
|
||||||
// test constructor from strings
|
// test constructor from strings
|
||||||
TEST(ADT, constructor)
|
TEST(ADT, constructor) {
|
||||||
{
|
|
||||||
DiscreteKey v0(0, 2), v1(1, 3);
|
DiscreteKey v0(0, 2), v1(1, 3);
|
||||||
DiscreteValues x00, x01, x02, x10, x11, x12;
|
DiscreteValues x00, x01, x02, x10, x11, x12;
|
||||||
x00[0] = 0, x00[1] = 0;
|
x00[0] = 0, x00[1] = 0;
|
||||||
|
|
@ -473,8 +470,7 @@ TEST(ADT, constructor)
|
||||||
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
|
DiscreteKey z0(0, 5), z1(1, 4), z2(2, 3), z3(3, 2);
|
||||||
vector<double> table(5 * 4 * 3 * 2);
|
vector<double> table(5 * 4 * 3 * 2);
|
||||||
double x = 0;
|
double x = 0;
|
||||||
for(double& t: table)
|
for (double& t : table) t = x++;
|
||||||
t = x++;
|
|
||||||
ADT f3(z0 & z1 & z2 & z3, table);
|
ADT f3(z0 & z1 & z2 & z3, table);
|
||||||
DiscreteValues assignment;
|
DiscreteValues assignment;
|
||||||
assignment[0] = 0;
|
assignment[0] = 0;
|
||||||
|
|
@ -487,8 +483,7 @@ TEST(ADT, constructor)
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test conversion to integer indices
|
// test conversion to integer indices
|
||||||
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
// Only works if DiscreteKeys are binary, as size_t has binary cardinality!
|
||||||
TEST(ADT, conversion)
|
TEST(ADT, conversion) {
|
||||||
{
|
|
||||||
DiscreteKey X(0, 2), Y(1, 2);
|
DiscreteKey X(0, 2), Y(1, 2);
|
||||||
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6");
|
||||||
dot(fDiscreteKey, "conversion-f1");
|
dot(fDiscreteKey, "conversion-f1");
|
||||||
|
|
@ -513,10 +508,9 @@ TEST(ADT, conversion)
|
||||||
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test operations in elimination
|
// test operations in elimination
|
||||||
TEST(ADT, elimination)
|
TEST(ADT, elimination) {
|
||||||
{
|
|
||||||
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
DiscreteKey A(0, 2), B(1, 3), C(2, 2);
|
||||||
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5");
|
||||||
dot(f1, "elimination-f1");
|
dot(f1, "elimination-f1");
|
||||||
|
|
@ -552,10 +546,9 @@ TEST(ADT, elimination)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test non-commutative op
|
// Test non-commutative op
|
||||||
TEST(ADT, div)
|
TEST(ADT, div) {
|
||||||
{
|
|
||||||
DiscreteKey A(0, 2), B(1, 2);
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
|
|
@ -567,10 +560,9 @@ TEST(ADT, div)
|
||||||
EXPECT(assert_equal(expected_b_div_a, b / a));
|
EXPECT(assert_equal(expected_b_div_a, b / a));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test zero shortcut
|
// test zero shortcut
|
||||||
TEST(ADT, zero)
|
TEST(ADT, zero) {
|
||||||
{
|
|
||||||
DiscreteKey A(0, 2), B(1, 2);
|
DiscreteKey A(0, 2), B(1, 2);
|
||||||
|
|
||||||
// Literals
|
// Literals
|
||||||
|
|
|
||||||
|
|
@ -65,14 +65,15 @@ struct CrazyDecisionTree : public DecisionTree<string, Crazy> {
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
template <>
|
||||||
}
|
struct traits<CrazyDecisionTree> : public Testable<CrazyDecisionTree> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
GTSAM_CONCEPT_TESTABLE_INST(CrazyDecisionTree)
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test string labels and int range
|
// Test string labels and int range
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
|
||||||
struct DT : public DecisionTree<string, int> {
|
struct DT : public DecisionTree<string, int> {
|
||||||
using Base = DecisionTree<string, int>;
|
using Base = DecisionTree<string, int>;
|
||||||
|
|
@ -98,30 +99,21 @@ struct DT : public DecisionTree<string, int> {
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
template<> struct traits<DT> : public Testable<DT> {};
|
template <>
|
||||||
}
|
struct traits<DT> : public Testable<DT> {};
|
||||||
|
} // namespace gtsam
|
||||||
|
|
||||||
GTSAM_CONCEPT_TESTABLE_INST(DT)
|
GTSAM_CONCEPT_TESTABLE_INST(DT)
|
||||||
|
|
||||||
struct Ring {
|
struct Ring {
|
||||||
static inline int zero() {
|
static inline int zero() { return 0; }
|
||||||
return 0;
|
static inline int one() { return 1; }
|
||||||
}
|
static inline int id(const int& a) { return a; }
|
||||||
static inline int one() {
|
static inline int add(const int& a, const int& b) { return a + b; }
|
||||||
return 1;
|
static inline int mul(const int& a, const int& b) { return a * b; }
|
||||||
}
|
|
||||||
static inline int id(const int& a) {
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
static inline int add(const int& a, const int& b) {
|
|
||||||
return a + b;
|
|
||||||
}
|
|
||||||
static inline int mul(const int& a, const int& b) {
|
|
||||||
return a * b;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test DT
|
// test DT
|
||||||
TEST(DecisionTree, example) {
|
TEST(DecisionTree, example) {
|
||||||
// Create labels
|
// Create labels
|
||||||
|
|
@ -228,7 +220,7 @@ TEST(DecisionTree, example) {
|
||||||
DOT(acnotb);
|
DOT(acnotb);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Conversion of values
|
// test Conversion of values
|
||||||
bool bool_of_int(const int& y) { return y != 0; };
|
bool bool_of_int(const int& y) { return y != 0; };
|
||||||
typedef DecisionTree<string, bool> StringBoolTree;
|
typedef DecisionTree<string, bool> StringBoolTree;
|
||||||
|
|
@ -249,11 +241,9 @@ TEST(DecisionTree, ConvertValuesOnly) {
|
||||||
EXPECT(!f2(x00));
|
EXPECT(!f2(x00));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Conversion of both values and labels.
|
// test Conversion of both values and labels.
|
||||||
enum Label {
|
enum Label { U, V, X, Y, Z };
|
||||||
U, V, X, Y, Z
|
|
||||||
};
|
|
||||||
typedef DecisionTree<Label, bool> LabelBoolTree;
|
typedef DecisionTree<Label, bool> LabelBoolTree;
|
||||||
|
|
||||||
TEST(DecisionTree, ConvertBoth) {
|
TEST(DecisionTree, ConvertBoth) {
|
||||||
|
|
@ -281,7 +271,7 @@ TEST(DecisionTree, ConvertBoth) {
|
||||||
EXPECT(!f2(x11));
|
EXPECT(!f2(x11));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// test Compose expansion
|
// test Compose expansion
|
||||||
TEST(DecisionTree, Compose) {
|
TEST(DecisionTree, Compose) {
|
||||||
// Create labels
|
// Create labels
|
||||||
|
|
@ -308,7 +298,7 @@ TEST(DecisionTree, Compose) {
|
||||||
DOT(f5);
|
DOT(f5);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Check we can create a decision tree of containers.
|
// Check we can create a decision tree of containers.
|
||||||
TEST(DecisionTree, Containers) {
|
TEST(DecisionTree, Containers) {
|
||||||
using Container = std::vector<double>;
|
using Container = std::vector<double>;
|
||||||
|
|
@ -318,7 +308,7 @@ TEST(DecisionTree, Containers) {
|
||||||
StringContainerTree tree;
|
StringContainerTree tree;
|
||||||
|
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT stringIntTree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
|
|
||||||
// Check conversion
|
// Check conversion
|
||||||
|
|
@ -330,11 +320,11 @@ TEST(DecisionTree, Containers) {
|
||||||
StringContainerTree converted(stringIntTree, container_of_int);
|
StringContainerTree converted(stringIntTree, container_of_int);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test visit.
|
// Test visit.
|
||||||
TEST(DecisionTree, visit) {
|
TEST(DecisionTree, visit) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
auto visitor = [&](int y) { sum += y; };
|
auto visitor = [&](int y) { sum += y; };
|
||||||
|
|
@ -342,11 +332,11 @@ TEST(DecisionTree, visit) {
|
||||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test visit, with Choices argument.
|
// Test visit, with Choices argument.
|
||||||
TEST(DecisionTree, visitWith) {
|
TEST(DecisionTree, visitWith) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
auto visitor = [&](const Assignment<string>& choices, int y) { sum += y; };
|
||||||
|
|
@ -354,29 +344,29 @@ TEST(DecisionTree, visitWith) {
|
||||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test fold.
|
// Test fold.
|
||||||
TEST(DecisionTree, fold) {
|
TEST(DecisionTree, fold) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
||||||
auto add = [](const int& y, double x) { return y + x; };
|
auto add = [](const int& y, double x) { return y + x; };
|
||||||
double sum = tree.fold(add, 0.0);
|
double sum = tree.fold(add, 0.0);
|
||||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test retrieving all labels.
|
// Test retrieving all labels.
|
||||||
TEST(DecisionTree, labels) {
|
TEST(DecisionTree, labels) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
||||||
auto labels = tree.labels();
|
auto labels = tree.labels();
|
||||||
EXPECT_LONGS_EQUAL(2, labels.size());
|
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test retrieving all labels.
|
// Test unzip method.
|
||||||
TEST(DecisionTree, unzip) {
|
TEST(DecisionTree, unzip) {
|
||||||
using DTP = DecisionTree<string, std::pair<int, string>>;
|
using DTP = DecisionTree<string, std::pair<int, string>>;
|
||||||
using DT1 = DecisionTree<string, int>;
|
using DT1 = DecisionTree<string, int>;
|
||||||
|
|
@ -384,10 +374,8 @@ TEST(DecisionTree, unzip) {
|
||||||
|
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
DTP tree(B,
|
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
|
||||||
DTP(A, {0, "zero"}, {1, "one"}),
|
DTP(A, {2, "two"}, {1337, "l33t"}));
|
||||||
DTP(A, {2, "two"}, {1337, "l33t"})
|
|
||||||
);
|
|
||||||
|
|
||||||
DT1 dt1;
|
DT1 dt1;
|
||||||
DT2 dt2;
|
DT2 dt2;
|
||||||
|
|
@ -400,6 +388,29 @@ TEST(DecisionTree, unzip) {
|
||||||
EXPECT(tree2.equals(dt2));
|
EXPECT(tree2.equals(dt2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Test thresholding.
|
||||||
|
TEST(DecisionTree, threshold) {
|
||||||
|
// Create three level tree
|
||||||
|
vector<DT::LabelC> keys;
|
||||||
|
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
||||||
|
DT tree(keys, "0 1 2 3 4 5 6 7");
|
||||||
|
|
||||||
|
// Check number of leaves equal to zero
|
||||||
|
auto count = [](const int& value, int count) {
|
||||||
|
return value == 0 ? count + 1 : count;
|
||||||
|
};
|
||||||
|
EXPECT_LONGS_EQUAL(1, tree.fold(count, 0));
|
||||||
|
|
||||||
|
// Now threshold
|
||||||
|
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
||||||
|
DT thresholded(tree, threshold);
|
||||||
|
|
||||||
|
// Check number of leaves equal to zero now = 2
|
||||||
|
// Note: it is 2, because the pruned branches are counted as 1!
|
||||||
|
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
|
|
@ -191,20 +191,36 @@ TEST(DiscreteConditional, marginals) {
|
||||||
DiscreteConditional prior(B % "1/2");
|
DiscreteConditional prior(B % "1/2");
|
||||||
DiscreteConditional pAB = prior * conditional;
|
DiscreteConditional pAB = prior * conditional;
|
||||||
|
|
||||||
|
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 1*1 + 2*2 = 5
|
||||||
|
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||||
DiscreteConditional actualA = pAB.marginal(A.first);
|
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||||
DiscreteConditional pA(A % "5/4");
|
DiscreteConditional pA(A % "5/4");
|
||||||
EXPECT(assert_equal(pA, actualA));
|
EXPECT(assert_equal(pA, actualA));
|
||||||
EXPECT_LONGS_EQUAL(1, actualA.nrFrontals());
|
EXPECT(actualA.frontals() == KeyVector{1});
|
||||||
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
EXPECT_LONGS_EQUAL(0, actualA.nrParents());
|
||||||
KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals());
|
|
||||||
EXPECT((frontalsA == KeyVector{1}));
|
|
||||||
|
|
||||||
DiscreteConditional actualB = pAB.marginal(B.first);
|
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||||
EXPECT(assert_equal(prior, actualB));
|
EXPECT(assert_equal(prior, actualB));
|
||||||
EXPECT_LONGS_EQUAL(1, actualB.nrFrontals());
|
EXPECT(actualB.frontals() == KeyVector{0});
|
||||||
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
EXPECT_LONGS_EQUAL(0, actualB.nrParents());
|
||||||
KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals());
|
}
|
||||||
EXPECT((frontalsB == KeyVector{0}));
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// Check calculation of marginals in case branches are pruned
|
||||||
|
TEST(DiscreteConditional, marginals2) {
|
||||||
|
DiscreteKey A(0, 2), B(1, 2); // changing keys need to make pruning happen!
|
||||||
|
DiscreteConditional conditional(A | B = "2/2 3/1");
|
||||||
|
DiscreteConditional prior(B % "1/2");
|
||||||
|
DiscreteConditional pAB = prior * conditional;
|
||||||
|
GTSAM_PRINT(pAB);
|
||||||
|
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
|
||||||
|
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
|
||||||
|
DiscreteConditional actualA = pAB.marginal(A.first);
|
||||||
|
DiscreteConditional pA(A % "8/4");
|
||||||
|
EXPECT(assert_equal(pA, actualA));
|
||||||
|
|
||||||
|
DiscreteConditional actualB = pAB.marginal(B.first);
|
||||||
|
EXPECT(assert_equal(prior, actualB));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue