Fixed thresholding and fold example
parent
fa1cde2f60
commit
8db7f25021
|
@ -24,8 +24,8 @@ using namespace boost::assign;
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
//#define DT_DEBUG_MEMORY
|
// #define DT_DEBUG_MEMORY
|
||||||
//#define DT_NO_PRUNING
|
// #define DT_NO_PRUNING
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
@ -349,10 +349,10 @@ TEST(DecisionTree, visitWith) {
|
||||||
TEST(DecisionTree, fold) {
|
TEST(DecisionTree, fold) {
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B");
|
string A("A"), B("B");
|
||||||
DT tree(B, DT(A, 0, 1), DT(A, 2, 3));
|
DT tree(B, DT(A, 1, 1), DT(A, 2, 3));
|
||||||
auto add = [](const int& y, double x) { return y + x; };
|
auto add = [](const int& y, double x) { return y + x; };
|
||||||
double sum = tree.fold(add, 0.0);
|
double sum = tree.fold(add, 0.0);
|
||||||
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9);
|
EXPECT_DOUBLES_EQUAL(6.0, sum, 1e-9); // Note, not 7, due to pruning!
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
|
@ -365,7 +365,7 @@ TEST(DecisionTree, labels) {
|
||||||
EXPECT_LONGS_EQUAL(2, labels.size());
|
EXPECT_LONGS_EQUAL(2, labels.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ******************************************************************************** */
|
/* ************************************************************************** */
|
||||||
// Test unzip method.
|
// Test unzip method.
|
||||||
TEST(DecisionTree, unzip) {
|
TEST(DecisionTree, unzip) {
|
||||||
using DTP = DecisionTree<string, std::pair<int, string>>;
|
using DTP = DecisionTree<string, std::pair<int, string>>;
|
||||||
|
@ -374,15 +374,13 @@ TEST(DecisionTree, unzip) {
|
||||||
|
|
||||||
// Create small two-level tree
|
// Create small two-level tree
|
||||||
string A("A"), B("B"), C("C");
|
string A("A"), B("B"), C("C");
|
||||||
DTP tree(B,
|
DTP tree(B, DTP(A, {0, "zero"}, {1, "one"}),
|
||||||
DTP(A, {0, "zero"}, {1, "one"}),
|
DTP(A, {2, "two"}, {1337, "l33t"}));
|
||||||
DTP(A, {2, "two"}, {1337, "l33t"})
|
|
||||||
);
|
|
||||||
|
|
||||||
DT1 dt1;
|
DT1 dt1;
|
||||||
DT2 dt2;
|
DT2 dt2;
|
||||||
std::tie(dt1, dt2) = unzip(tree);
|
std::tie(dt1, dt2) = unzip(tree);
|
||||||
|
|
||||||
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
|
DT1 tree1(B, DT1(A, 0, 1), DT1(A, 2, 1337));
|
||||||
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
|
DT2 tree2(B, DT2(A, "zero", "one"), DT2(A, "two", "l33t"));
|
||||||
|
|
||||||
|
@ -398,7 +396,7 @@ TEST(DecisionTree, threshold) {
|
||||||
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
|
||||||
DT tree(keys, "0 1 2 3 4 5 6 7");
|
DT tree(keys, "0 1 2 3 4 5 6 7");
|
||||||
|
|
||||||
// Check number of elements equal to zero
|
// Check number of leaves equal to zero
|
||||||
auto count = [](const int& value, int count) {
|
auto count = [](const int& value, int count) {
|
||||||
return value == 0 ? count + 1 : count;
|
return value == 0 ? count + 1 : count;
|
||||||
};
|
};
|
||||||
|
@ -408,9 +406,9 @@ TEST(DecisionTree, threshold) {
|
||||||
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
auto threshold = [](int value) { return value < 5 ? 0 : value; };
|
||||||
DT thresholded(tree, threshold);
|
DT thresholded(tree, threshold);
|
||||||
|
|
||||||
// Check number of elements equal to zero now = 5
|
// Check number of leaves equal to zero now = 2
|
||||||
// TODO(frank): it is 2, because the pruned branches are counted as 1!
|
// Note: it is 2, because the pruned branches are counted as 1!
|
||||||
EXPECT_LONGS_EQUAL(5, thresholded.fold(count, 0));
|
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue