Merge pull request #1844 from borglab/feature/timeHybrid
commit
e4ec8d3b9c
|
|
@ -91,7 +91,7 @@ namespace gtsam {
|
||||||
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter,
|
const ValueFormatter& valueFormatter,
|
||||||
bool showZero) const override {
|
bool showZero) const override {
|
||||||
std::string value = valueFormatter(constant_);
|
const std::string value = valueFormatter(constant_);
|
||||||
if (showZero || value.compare("0"))
|
if (showZero || value.compare("0"))
|
||||||
os << "\"" << this->id() << "\" [label=\"" << value
|
os << "\"" << this->id() << "\" [label=\"" << value
|
||||||
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
|
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
|
||||||
|
|
@ -306,7 +306,8 @@ namespace gtsam {
|
||||||
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
void dot(std::ostream& os, const LabelFormatter& labelFormatter,
|
||||||
const ValueFormatter& valueFormatter,
|
const ValueFormatter& valueFormatter,
|
||||||
bool showZero) const override {
|
bool showZero) const override {
|
||||||
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_
|
const std::string label = labelFormatter(label_);
|
||||||
|
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label
|
||||||
<< "\"]\n";
|
<< "\"]\n";
|
||||||
size_t B = branches_.size();
|
size_t B = branches_.size();
|
||||||
for (size_t i = 0; i < B; i++) {
|
for (size_t i = 0; i < B; i++) {
|
||||||
|
|
|
||||||
|
|
@ -147,14 +147,14 @@ namespace gtsam {
|
||||||
size_t i;
|
size_t i;
|
||||||
ADT result(*this);
|
ADT result(*this);
|
||||||
for (i = 0; i < nrFrontals; i++) {
|
for (i = 0; i < nrFrontals; i++) {
|
||||||
Key j = keys()[i];
|
Key j = keys_[i];
|
||||||
result = result.combine(j, cardinality(j), op);
|
result = result.combine(j, cardinality(j), op);
|
||||||
}
|
}
|
||||||
|
|
||||||
// create new factor, note we start keys after nrFrontals
|
// Create new factor, note we start with keys after nrFrontals:
|
||||||
DiscreteKeys dkeys;
|
DiscreteKeys dkeys;
|
||||||
for (; i < keys().size(); i++) {
|
for (; i < keys_.size(); i++) {
|
||||||
Key j = keys()[i];
|
Key j = keys_[i];
|
||||||
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
dkeys.push_back(DiscreteKey(j, cardinality(j)));
|
||||||
}
|
}
|
||||||
return std::make_shared<DecisionTreeFactor>(dkeys, result);
|
return std::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
|
|
@ -179,24 +179,22 @@ namespace gtsam {
|
||||||
result = result.combine(j, cardinality(j), op);
|
result = result.combine(j, cardinality(j), op);
|
||||||
}
|
}
|
||||||
|
|
||||||
// create new factor, note we collect keys that are not in frontalKeys
|
|
||||||
/*
|
/*
|
||||||
Due to branch merging, the labels in `result` may be missing some keys
|
Create new factor, note we collect keys that are not in frontalKeys.
|
||||||
|
|
||||||
|
Due to branch merging, the labels in `result` may be missing some keys.
|
||||||
E.g. After branch merging, we may get a ADT like:
|
E.g. After branch merging, we may get a ADT like:
|
||||||
Leaf [2] 1.0204082
|
Leaf [2] 1.0204082
|
||||||
|
|
||||||
This is missing the key values used for branching.
|
Hence, code below traverses the original keys and omits those in
|
||||||
|
frontalKeys. We loop over cardinalities, which is O(n) even for a map, and
|
||||||
|
then "contains" is a binary search on a small vector.
|
||||||
*/
|
*/
|
||||||
KeyVector difference, frontalKeys_(frontalKeys), keys_(keys());
|
|
||||||
// Get the difference of the frontalKeys and the factor keys using set_difference
|
|
||||||
std::sort(keys_.begin(), keys_.end());
|
|
||||||
std::sort(frontalKeys_.begin(), frontalKeys_.end());
|
|
||||||
std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(),
|
|
||||||
frontalKeys_.end(), back_inserter(difference));
|
|
||||||
|
|
||||||
DiscreteKeys dkeys;
|
DiscreteKeys dkeys;
|
||||||
for (Key key : difference) {
|
for (auto&& [key, cardinality] : cardinalities_) {
|
||||||
dkeys.push_back(DiscreteKey(key, cardinality(key)));
|
if (!frontalKeys.contains(key)) {
|
||||||
|
dkeys.push_back(DiscreteKey(key, cardinality));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return std::make_shared<DecisionTreeFactor>(dkeys, result);
|
return std::make_shared<DecisionTreeFactor>(dkeys, result);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,9 @@
|
||||||
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
|
||||||
#include <gtsam/discrete/DiscreteValues.h>
|
#include <gtsam/discrete/DiscreteValues.h>
|
||||||
// headers first to make sure no missing headers
|
// headers first to make sure no missing headers
|
||||||
|
#include <CppUnitLite/TestHarness.h>
|
||||||
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
#include <gtsam/discrete/AlgebraicDecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only
|
||||||
#define DISABLE_TIMING
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
|
||||||
#include <gtsam/base/timing.h>
|
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
@ -71,16 +68,14 @@ void dot(const T& f, const string& filename) {
|
||||||
// instrumented operators
|
// instrumented operators
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
size_t muls = 0, adds = 0;
|
size_t muls = 0, adds = 0;
|
||||||
double elapsed;
|
|
||||||
void resetCounts() {
|
void resetCounts() {
|
||||||
muls = 0;
|
muls = 0;
|
||||||
adds = 0;
|
adds = 0;
|
||||||
}
|
}
|
||||||
void printCounts(const string& s) {
|
void printCounts(const string& s) {
|
||||||
#ifndef DISABLE_TIMING
|
#ifndef DISABLE_TIMING
|
||||||
cout << s << ": " << std::setw(3) << muls << " muls, " <<
|
cout << s << ": " << std::setw(3) << muls << " muls, " << std::setw(3) << adds
|
||||||
std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms."
|
<< " adds" << endl;
|
||||||
<< endl;
|
|
||||||
#endif
|
#endif
|
||||||
resetCounts();
|
resetCounts();
|
||||||
}
|
}
|
||||||
|
|
@ -131,37 +126,35 @@ ADT create(const Signature& signature) {
|
||||||
static size_t count = 0;
|
static size_t count = 0;
|
||||||
const DiscreteKey& key = signature.key();
|
const DiscreteKey& key = signature.key();
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first;
|
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-"
|
||||||
|
<< key.first;
|
||||||
string DOTfile = ss.str();
|
string DOTfile = ss.str();
|
||||||
dot(p, DOTfile);
|
dot(p, DOTfile);
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
namespace asiaCPTs {
|
||||||
|
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||||
|
D(7, 2);
|
||||||
|
|
||||||
|
ADT pA = create(A % "99/1");
|
||||||
|
ADT pS = create(S % "50/50");
|
||||||
|
ADT pT = create(T | A = "99/1 95/5");
|
||||||
|
ADT pL = create(L | S = "99/1 90/10");
|
||||||
|
ADT pB = create(B | S = "70/30 40/60");
|
||||||
|
ADT pE = create((E | T, L) = "F T T T");
|
||||||
|
ADT pX = create(X | E = "95/5 2/98");
|
||||||
|
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||||
|
} // namespace asiaCPTs
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Asia Joint
|
// test Asia Joint
|
||||||
TEST(ADT, joint) {
|
TEST(ADT, joint) {
|
||||||
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
using namespace asiaCPTs;
|
||||||
D(7, 2);
|
|
||||||
|
|
||||||
resetCounts();
|
|
||||||
gttic_(asiaCPTs);
|
|
||||||
ADT pA = create(A % "99/1");
|
|
||||||
ADT pS = create(S % "50/50");
|
|
||||||
ADT pT = create(T | A = "99/1 95/5");
|
|
||||||
ADT pL = create(L | S = "99/1 90/10");
|
|
||||||
ADT pB = create(B | S = "70/30 40/60");
|
|
||||||
ADT pE = create((E | T, L) = "F T T T");
|
|
||||||
ADT pX = create(X | E = "95/5 2/98");
|
|
||||||
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
|
|
||||||
gttoc_(asiaCPTs);
|
|
||||||
tictoc_getNode(asiaCPTsNode, asiaCPTs);
|
|
||||||
elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Asia CPTs");
|
|
||||||
|
|
||||||
// Create joint
|
// Create joint
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(asiaJoint);
|
|
||||||
ADT joint = pA;
|
ADT joint = pA;
|
||||||
dot(joint, "Asia-A");
|
dot(joint, "Asia-A");
|
||||||
joint = apply(joint, pS, &mul);
|
joint = apply(joint, pS, &mul);
|
||||||
|
|
@ -183,11 +176,12 @@ TEST(ADT, joint) {
|
||||||
#else
|
#else
|
||||||
EXPECT_LONGS_EQUAL(508, muls);
|
EXPECT_LONGS_EQUAL(508, muls);
|
||||||
#endif
|
#endif
|
||||||
gttoc_(asiaJoint);
|
|
||||||
tictoc_getNode(asiaJointNode, asiaJoint);
|
|
||||||
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Asia joint");
|
printCounts("Asia joint");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
TEST(ADT, combine) {
|
||||||
|
using namespace asiaCPTs;
|
||||||
|
|
||||||
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
|
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
|
||||||
ADT pASTL = pA;
|
ADT pASTL = pA;
|
||||||
|
|
@ -203,13 +197,11 @@ TEST(ADT, joint) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
// test Inference with joint
|
// test Inference with joint, created using different ordering
|
||||||
TEST(ADT, inference) {
|
TEST(ADT, inference) {
|
||||||
DiscreteKey A(0, 2), D(1, 2), //
|
DiscreteKey A(0, 2), D(1, 2), //
|
||||||
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
|
||||||
|
|
||||||
resetCounts();
|
|
||||||
gttic_(infCPTs);
|
|
||||||
ADT pA = create(A % "99/1");
|
ADT pA = create(A % "99/1");
|
||||||
ADT pS = create(S % "50/50");
|
ADT pS = create(S % "50/50");
|
||||||
ADT pT = create(T | A = "99/1 95/5");
|
ADT pT = create(T | A = "99/1 95/5");
|
||||||
|
|
@ -218,15 +210,9 @@ TEST(ADT, inference) {
|
||||||
ADT pE = create((E | T, L) = "F T T T");
|
ADT pE = create((E | T, L) = "F T T T");
|
||||||
ADT pX = create(X | E = "95/5 2/98");
|
ADT pX = create(X | E = "95/5 2/98");
|
||||||
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
|
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||||
gttoc_(infCPTs);
|
|
||||||
tictoc_getNode(infCPTsNode, infCPTs);
|
|
||||||
elapsed = infCPTsNode->secs() + infCPTsNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
// printCounts("Inference CPTs");
|
|
||||||
|
|
||||||
// Create joint
|
// Create joint, note different ordering than above: different tree!
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(asiaProd);
|
|
||||||
ADT joint = pA;
|
ADT joint = pA;
|
||||||
dot(joint, "Joint-Product-A");
|
dot(joint, "Joint-Product-A");
|
||||||
joint = apply(joint, pS, &mul);
|
joint = apply(joint, pS, &mul);
|
||||||
|
|
@ -248,14 +234,9 @@ TEST(ADT, inference) {
|
||||||
#else
|
#else
|
||||||
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
|
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
|
||||||
#endif
|
#endif
|
||||||
gttoc_(asiaProd);
|
|
||||||
tictoc_getNode(asiaProdNode, asiaProd);
|
|
||||||
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Asia product");
|
printCounts("Asia product");
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(asiaSum);
|
|
||||||
ADT marginal = joint;
|
ADT marginal = joint;
|
||||||
marginal = marginal.combine(X, &add_);
|
marginal = marginal.combine(X, &add_);
|
||||||
dot(marginal, "Joint-Sum-ADBLEST");
|
dot(marginal, "Joint-Sum-ADBLEST");
|
||||||
|
|
@ -270,10 +251,6 @@ TEST(ADT, inference) {
|
||||||
#else
|
#else
|
||||||
EXPECT_LONGS_EQUAL(240, (long)adds);
|
EXPECT_LONGS_EQUAL(240, (long)adds);
|
||||||
#endif
|
#endif
|
||||||
gttoc_(asiaSum);
|
|
||||||
tictoc_getNode(asiaSumNode, asiaSum);
|
|
||||||
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Asia sum");
|
printCounts("Asia sum");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -281,8 +258,6 @@ TEST(ADT, inference) {
|
||||||
TEST(ADT, factor_graph) {
|
TEST(ADT, factor_graph) {
|
||||||
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
|
||||||
|
|
||||||
resetCounts();
|
|
||||||
gttic_(createCPTs);
|
|
||||||
ADT pS = create(S % "50/50");
|
ADT pS = create(S % "50/50");
|
||||||
ADT pT = create(T % "95/5");
|
ADT pT = create(T % "95/5");
|
||||||
ADT pL = create(L | S = "99/1 90/10");
|
ADT pL = create(L | S = "99/1 90/10");
|
||||||
|
|
@ -290,15 +265,9 @@ TEST(ADT, factor_graph) {
|
||||||
ADT pX = create(X | E = "95/5 2/98");
|
ADT pX = create(X | E = "95/5 2/98");
|
||||||
ADT pD = create(B | E = "1/8 7/9");
|
ADT pD = create(B | E = "1/8 7/9");
|
||||||
ADT pB = create(B | S = "70/30 40/60");
|
ADT pB = create(B | S = "70/30 40/60");
|
||||||
gttoc_(createCPTs);
|
|
||||||
tictoc_getNode(createCPTsNode, createCPTs);
|
|
||||||
elapsed = createCPTsNode->secs() + createCPTsNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
// printCounts("Create CPTs");
|
|
||||||
|
|
||||||
// Create joint
|
// Create joint
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(asiaFG);
|
|
||||||
ADT fg = pS;
|
ADT fg = pS;
|
||||||
fg = apply(fg, pT, &mul);
|
fg = apply(fg, pT, &mul);
|
||||||
fg = apply(fg, pL, &mul);
|
fg = apply(fg, pL, &mul);
|
||||||
|
|
@ -312,14 +281,9 @@ TEST(ADT, factor_graph) {
|
||||||
#else
|
#else
|
||||||
EXPECT_LONGS_EQUAL(188, (long)muls);
|
EXPECT_LONGS_EQUAL(188, (long)muls);
|
||||||
#endif
|
#endif
|
||||||
gttoc_(asiaFG);
|
|
||||||
tictoc_getNode(asiaFGNode, asiaFG);
|
|
||||||
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Asia FG");
|
printCounts("Asia FG");
|
||||||
|
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(marg);
|
|
||||||
fg = fg.combine(X, &add_);
|
fg = fg.combine(X, &add_);
|
||||||
dot(fg, "Marginalized-6X");
|
dot(fg, "Marginalized-6X");
|
||||||
fg = fg.combine(T, &add_);
|
fg = fg.combine(T, &add_);
|
||||||
|
|
@ -335,83 +299,54 @@ TEST(ADT, factor_graph) {
|
||||||
#else
|
#else
|
||||||
LONGS_EQUAL(62, adds);
|
LONGS_EQUAL(62, adds);
|
||||||
#endif
|
#endif
|
||||||
gttoc_(marg);
|
|
||||||
tictoc_getNode(margNode, marg);
|
|
||||||
elapsed = margNode->secs() + margNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("marginalize");
|
printCounts("marginalize");
|
||||||
|
|
||||||
// BLESTX
|
// BLESTX
|
||||||
|
|
||||||
// Eliminate X
|
// Eliminate X
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(elimX);
|
|
||||||
ADT fE = pX;
|
ADT fE = pX;
|
||||||
dot(fE, "Eliminate-01-fEX");
|
dot(fE, "Eliminate-01-fEX");
|
||||||
fE = fE.combine(X, &add_);
|
fE = fE.combine(X, &add_);
|
||||||
dot(fE, "Eliminate-02-fE");
|
dot(fE, "Eliminate-02-fE");
|
||||||
gttoc_(elimX);
|
|
||||||
tictoc_getNode(elimXNode, elimX);
|
|
||||||
elapsed = elimXNode->secs() + elimXNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Eliminate X");
|
printCounts("Eliminate X");
|
||||||
|
|
||||||
// Eliminate T
|
// Eliminate T
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(elimT);
|
|
||||||
ADT fLE = pT;
|
ADT fLE = pT;
|
||||||
fLE = apply(fLE, pE, &mul);
|
fLE = apply(fLE, pE, &mul);
|
||||||
dot(fLE, "Eliminate-03-fLET");
|
dot(fLE, "Eliminate-03-fLET");
|
||||||
fLE = fLE.combine(T, &add_);
|
fLE = fLE.combine(T, &add_);
|
||||||
dot(fLE, "Eliminate-04-fLE");
|
dot(fLE, "Eliminate-04-fLE");
|
||||||
gttoc_(elimT);
|
|
||||||
tictoc_getNode(elimTNode, elimT);
|
|
||||||
elapsed = elimTNode->secs() + elimTNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Eliminate T");
|
printCounts("Eliminate T");
|
||||||
|
|
||||||
// Eliminate S
|
// Eliminate S
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(elimS);
|
|
||||||
ADT fBL = pS;
|
ADT fBL = pS;
|
||||||
fBL = apply(fBL, pL, &mul);
|
fBL = apply(fBL, pL, &mul);
|
||||||
fBL = apply(fBL, pB, &mul);
|
fBL = apply(fBL, pB, &mul);
|
||||||
dot(fBL, "Eliminate-05-fBLS");
|
dot(fBL, "Eliminate-05-fBLS");
|
||||||
fBL = fBL.combine(S, &add_);
|
fBL = fBL.combine(S, &add_);
|
||||||
dot(fBL, "Eliminate-06-fBL");
|
dot(fBL, "Eliminate-06-fBL");
|
||||||
gttoc_(elimS);
|
|
||||||
tictoc_getNode(elimSNode, elimS);
|
|
||||||
elapsed = elimSNode->secs() + elimSNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Eliminate S");
|
printCounts("Eliminate S");
|
||||||
|
|
||||||
// Eliminate E
|
// Eliminate E
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(elimE);
|
|
||||||
ADT fBL2 = fE;
|
ADT fBL2 = fE;
|
||||||
fBL2 = apply(fBL2, fLE, &mul);
|
fBL2 = apply(fBL2, fLE, &mul);
|
||||||
fBL2 = apply(fBL2, pD, &mul);
|
fBL2 = apply(fBL2, pD, &mul);
|
||||||
dot(fBL2, "Eliminate-07-fBLE");
|
dot(fBL2, "Eliminate-07-fBLE");
|
||||||
fBL2 = fBL2.combine(E, &add_);
|
fBL2 = fBL2.combine(E, &add_);
|
||||||
dot(fBL2, "Eliminate-08-fBL2");
|
dot(fBL2, "Eliminate-08-fBL2");
|
||||||
gttoc_(elimE);
|
|
||||||
tictoc_getNode(elimENode, elimE);
|
|
||||||
elapsed = elimENode->secs() + elimENode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Eliminate E");
|
printCounts("Eliminate E");
|
||||||
|
|
||||||
// Eliminate L
|
// Eliminate L
|
||||||
resetCounts();
|
resetCounts();
|
||||||
gttic_(elimL);
|
|
||||||
ADT fB = fBL;
|
ADT fB = fBL;
|
||||||
fB = apply(fB, fBL2, &mul);
|
fB = apply(fB, fBL2, &mul);
|
||||||
dot(fB, "Eliminate-09-fBL");
|
dot(fB, "Eliminate-09-fBL");
|
||||||
fB = fB.combine(L, &add_);
|
fB = fB.combine(L, &add_);
|
||||||
dot(fB, "Eliminate-10-fB");
|
dot(fB, "Eliminate-10-fB");
|
||||||
gttoc_(elimL);
|
|
||||||
tictoc_getNode(elimLNode, elimL);
|
|
||||||
elapsed = elimLNode->secs() + elimLNode->wall();
|
|
||||||
tictoc_reset_();
|
|
||||||
printCounts("Eliminate L");
|
printCounts("Eliminate L");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,10 @@
|
||||||
#include <gtsam/base/serializationTestHelpers.h>
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteDistribution.h>
|
#include <gtsam/discrete/DiscreteDistribution.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactor.h>
|
||||||
#include <gtsam/discrete/Signature.h>
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
#include <gtsam/inference/Key.h>
|
||||||
|
#include <gtsam/inference/Ordering.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
@ -33,25 +36,24 @@ TEST(DecisionTreeFactor, ConstructorsMatch) {
|
||||||
DiscreteKey X(0, 2), Y(1, 3);
|
DiscreteKey X(0, 2), Y(1, 3);
|
||||||
|
|
||||||
// Create with vector and with string
|
// Create with vector and with string
|
||||||
const std::vector<double> table {2, 5, 3, 6, 4, 7};
|
const std::vector<double> table{2, 5, 3, 6, 4, 7};
|
||||||
DecisionTreeFactor f1({X, Y}, table);
|
DecisionTreeFactor f1({X, Y}, table);
|
||||||
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
|
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
|
||||||
EXPECT(assert_equal(f1, f2));
|
EXPECT(assert_equal(f1, f2));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DecisionTreeFactor, constructors)
|
TEST(DecisionTreeFactor, constructors) {
|
||||||
{
|
|
||||||
// Declare a bunch of keys
|
// Declare a bunch of keys
|
||||||
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
|
||||||
|
|
||||||
// Create factors
|
// Create factors
|
||||||
DecisionTreeFactor f1(X, {2, 8});
|
DecisionTreeFactor f1(X, {2, 8});
|
||||||
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
|
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
|
||||||
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||||
EXPECT_LONGS_EQUAL(1,f1.size());
|
EXPECT_LONGS_EQUAL(1, f1.size());
|
||||||
EXPECT_LONGS_EQUAL(2,f2.size());
|
EXPECT_LONGS_EQUAL(2, f2.size());
|
||||||
EXPECT_LONGS_EQUAL(3,f3.size());
|
EXPECT_LONGS_EQUAL(3, f3.size());
|
||||||
|
|
||||||
DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
|
DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
|
||||||
EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
|
EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
|
||||||
|
|
@ -70,7 +72,7 @@ TEST( DecisionTreeFactor, constructors)
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DecisionTreeFactor, Error) {
|
TEST(DecisionTreeFactor, Error) {
|
||||||
// Declare a bunch of keys
|
// Declare a bunch of keys
|
||||||
DiscreteKey X(0,2), Y(1,3), Z(2,2);
|
DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
|
||||||
|
|
||||||
// Create factors
|
// Create factors
|
||||||
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
|
||||||
|
|
@ -104,9 +106,8 @@ TEST(DecisionTreeFactor, multiplication) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST( DecisionTreeFactor, sum_max)
|
TEST(DecisionTreeFactor, sum_max) {
|
||||||
{
|
DiscreteKey v0(0, 3), v1(1, 2);
|
||||||
DiscreteKey v0(0,3), v1(1,2);
|
|
||||||
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
|
||||||
|
|
||||||
DecisionTreeFactor expected(v1, "9 12");
|
DecisionTreeFactor expected(v1, "9 12");
|
||||||
|
|
@ -165,22 +166,85 @@ TEST(DecisionTreeFactor, Prune) {
|
||||||
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
|
||||||
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
|
||||||
|
|
||||||
DecisionTreeFactor expected3(
|
DecisionTreeFactor expected3(D & C & B & A,
|
||||||
D & C & B & A,
|
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
|
||||||
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
|
"0.999952870000 1.0 1.0 1.0 1.0");
|
||||||
"0.999952870000 1.0 1.0 1.0 1.0");
|
|
||||||
maxNrAssignments = 5;
|
maxNrAssignments = 5;
|
||||||
auto pruned3 = factor.prune(maxNrAssignments);
|
auto pruned3 = factor.prune(maxNrAssignments);
|
||||||
EXPECT(assert_equal(expected3, pruned3));
|
EXPECT(assert_equal(expected3, pruned3));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
// Asia Bayes Network
|
||||||
|
/* ************************************************************************** */
|
||||||
|
|
||||||
|
#define DISABLE_DOT
|
||||||
|
|
||||||
|
void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
|
||||||
|
#ifndef DISABLE_DOT
|
||||||
|
std::vector<std::string> names = {"A", "S", "T", "L", "B", "E", "X", "D"};
|
||||||
|
auto formatter = [&](Key key) { return names[key]; };
|
||||||
|
f.dot(filename, formatter, true);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Convert Signature into CPT */
|
||||||
|
DecisionTreeFactor create(const Signature& signature) {
|
||||||
|
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
|
||||||
|
return p;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
|
// test Asia Joint
|
||||||
|
TEST(DecisionTreeFactor, joint) {
|
||||||
|
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
|
||||||
|
D(7, 2);
|
||||||
|
|
||||||
|
gttic_(asiaCPTs);
|
||||||
|
DecisionTreeFactor pA = create(A % "99/1");
|
||||||
|
DecisionTreeFactor pS = create(S % "50/50");
|
||||||
|
DecisionTreeFactor pT = create(T | A = "99/1 95/5");
|
||||||
|
DecisionTreeFactor pL = create(L | S = "99/1 90/10");
|
||||||
|
DecisionTreeFactor pB = create(B | S = "70/30 40/60");
|
||||||
|
DecisionTreeFactor pE = create((E | T, L) = "F T T T");
|
||||||
|
DecisionTreeFactor pX = create(X | E = "95/5 2/98");
|
||||||
|
DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
|
||||||
|
|
||||||
|
// Create joint
|
||||||
|
gttic_(asiaJoint);
|
||||||
|
DecisionTreeFactor joint = pA;
|
||||||
|
maybeSaveDotFile(joint, "Asia-A");
|
||||||
|
joint = joint * pS;
|
||||||
|
maybeSaveDotFile(joint, "Asia-AS");
|
||||||
|
joint = joint * pT;
|
||||||
|
maybeSaveDotFile(joint, "Asia-AST");
|
||||||
|
joint = joint * pL;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTL");
|
||||||
|
joint = joint * pB;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTLB");
|
||||||
|
joint = joint * pE;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTLBE");
|
||||||
|
joint = joint * pX;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTLBEX");
|
||||||
|
joint = joint * pD;
|
||||||
|
maybeSaveDotFile(joint, "Asia-ASTLBEXD");
|
||||||
|
|
||||||
|
// Check that discrete keys are as expected
|
||||||
|
EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D}));
|
||||||
|
|
||||||
|
// Check that summing out variables maintains the keys even if merged, as is
|
||||||
|
// the case with S.
|
||||||
|
auto noAB = joint.sum(Ordering{A.first, B.first});
|
||||||
|
EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D}));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
TEST(DecisionTreeFactor, DotWithNames) {
|
TEST(DecisionTreeFactor, DotWithNames) {
|
||||||
DiscreteKey A(12, 3), B(5, 2);
|
DiscreteKey A(12, 3), B(5, 2);
|
||||||
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
|
||||||
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
|
||||||
|
|
||||||
for (bool showZero:{true, false}) {
|
for (bool showZero : {true, false}) {
|
||||||
string actual = f.dot(formatter, showZero);
|
string actual = f.dot(formatter, showZero);
|
||||||
// pretty weak test, as ids are pointers and not stable across platforms.
|
// pretty weak test, as ids are pointers and not stable across platforms.
|
||||||
string expected = "digraph G {";
|
string expected = "digraph G {";
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ namespace gtsam {
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
static void checkKeys(const KeyVector& continuousKeys,
|
static void checkKeys(const KeyVector& continuousKeys,
|
||||||
std::vector<NonlinearFactorValuePair>& pairs) {
|
const std::vector<NonlinearFactorValuePair>& pairs) {
|
||||||
KeySet factor_keys_set;
|
KeySet factor_keys_set;
|
||||||
for (const auto& pair : pairs) {
|
for (const auto& pair : pairs) {
|
||||||
auto f = pair.first;
|
auto f = pair.first;
|
||||||
|
|
@ -55,14 +55,9 @@ HybridNonlinearFactor::HybridNonlinearFactor(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridNonlinearFactor::HybridNonlinearFactor(
|
HybridNonlinearFactor::HybridNonlinearFactor(
|
||||||
const KeyVector& continuousKeys, const DiscreteKey& discreteKey,
|
const KeyVector& continuousKeys, const DiscreteKey& discreteKey,
|
||||||
const std::vector<NonlinearFactorValuePair>& factors)
|
const std::vector<NonlinearFactorValuePair>& pairs)
|
||||||
: Base(continuousKeys, {discreteKey}) {
|
: Base(continuousKeys, {discreteKey}) {
|
||||||
std::vector<NonlinearFactorValuePair> pairs;
|
|
||||||
KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end());
|
KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end());
|
||||||
KeySet factor_keys_set;
|
|
||||||
for (auto&& [f, val] : factors) {
|
|
||||||
pairs.emplace_back(f, val);
|
|
||||||
}
|
|
||||||
checkKeys(continuousKeys, pairs);
|
checkKeys(continuousKeys, pairs);
|
||||||
factors_ = FactorValuePairs({discreteKey}, pairs);
|
factors_ = FactorValuePairs({discreteKey}, pairs);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -106,11 +106,11 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
|
||||||
*
|
*
|
||||||
* @param continuousKeys Vector of keys for continuous factors.
|
* @param continuousKeys Vector of keys for continuous factors.
|
||||||
* @param discreteKey The discrete key for the "mode", indexing components.
|
* @param discreteKey The discrete key for the "mode", indexing components.
|
||||||
* @param factors Vector of gaussian factor-scalar pairs, one per mode.
|
* @param pairs Vector of gaussian factor-scalar pairs, one per mode.
|
||||||
*/
|
*/
|
||||||
HybridNonlinearFactor(const KeyVector& continuousKeys,
|
HybridNonlinearFactor(const KeyVector& continuousKeys,
|
||||||
const DiscreteKey& discreteKey,
|
const DiscreteKey& discreteKey,
|
||||||
const std::vector<NonlinearFactorValuePair>& factors);
|
const std::vector<NonlinearFactorValuePair>& pairs);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Construct a new HybridNonlinearFactor on a several discrete keys M,
|
* @brief Construct a new HybridNonlinearFactor on a several discrete keys M,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue