Merge pull request #1844 from borglab/feature/timeHybrid

release/4.3a0
Varun Agrawal 2024-09-24 15:32:03 -04:00 committed by GitHub
commit e4ec8d3b9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 129 additions and 136 deletions

View File

@ -91,7 +91,7 @@ namespace gtsam {
void dot(std::ostream& os, const LabelFormatter& labelFormatter, void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, const ValueFormatter& valueFormatter,
bool showZero) const override { bool showZero) const override {
std::string value = valueFormatter(constant_); const std::string value = valueFormatter(constant_);
if (showZero || value.compare("0")) if (showZero || value.compare("0"))
os << "\"" << this->id() << "\" [label=\"" << value os << "\"" << this->id() << "\" [label=\"" << value
<< "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
@ -306,7 +306,8 @@ namespace gtsam {
void dot(std::ostream& os, const LabelFormatter& labelFormatter, void dot(std::ostream& os, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter, const ValueFormatter& valueFormatter,
bool showZero) const override { bool showZero) const override {
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ const std::string label = labelFormatter(label_);
os << "\"" << this->id() << "\" [shape=circle, label=\"" << label
<< "\"]\n"; << "\"]\n";
size_t B = branches_.size(); size_t B = branches_.size();
for (size_t i = 0; i < B; i++) { for (size_t i = 0; i < B; i++) {

View File

@ -147,14 +147,14 @@ namespace gtsam {
size_t i; size_t i;
ADT result(*this); ADT result(*this);
for (i = 0; i < nrFrontals; i++) { for (i = 0; i < nrFrontals; i++) {
Key j = keys()[i]; Key j = keys_[i];
result = result.combine(j, cardinality(j), op); result = result.combine(j, cardinality(j), op);
} }
// create new factor, note we start keys after nrFrontals // Create new factor, note we start with keys after nrFrontals:
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (; i < keys().size(); i++) { for (; i < keys_.size(); i++) {
Key j = keys()[i]; Key j = keys_[i];
dkeys.push_back(DiscreteKey(j, cardinality(j))); dkeys.push_back(DiscreteKey(j, cardinality(j)));
} }
return std::make_shared<DecisionTreeFactor>(dkeys, result); return std::make_shared<DecisionTreeFactor>(dkeys, result);
@ -179,24 +179,22 @@ namespace gtsam {
result = result.combine(j, cardinality(j), op); result = result.combine(j, cardinality(j), op);
} }
// create new factor, note we collect keys that are not in frontalKeys
/* /*
Due to branch merging, the labels in `result` may be missing some keys Create new factor, note we collect keys that are not in frontalKeys.
Due to branch merging, the labels in `result` may be missing some keys.
E.g. After branch merging, we may get a ADT like: E.g. After branch merging, we may get a ADT like:
Leaf [2] 1.0204082 Leaf [2] 1.0204082
This is missing the key values used for branching. Hence, code below traverses the original keys and omits those in
frontalKeys. We loop over cardinalities, which is O(n) even for a map, and
then "contains" is a binary search on a small vector.
*/ */
KeyVector difference, frontalKeys_(frontalKeys), keys_(keys());
// Get the difference of the frontalKeys and the factor keys using set_difference
std::sort(keys_.begin(), keys_.end());
std::sort(frontalKeys_.begin(), frontalKeys_.end());
std::set_difference(keys_.begin(), keys_.end(), frontalKeys_.begin(),
frontalKeys_.end(), back_inserter(difference));
DiscreteKeys dkeys; DiscreteKeys dkeys;
for (Key key : difference) { for (auto&& [key, cardinality] : cardinalities_) {
dkeys.push_back(DiscreteKey(key, cardinality(key))); if (!frontalKeys.contains(key)) {
dkeys.push_back(DiscreteKey(key, cardinality));
}
} }
return std::make_shared<DecisionTreeFactor>(dkeys, result); return std::make_shared<DecisionTreeFactor>(dkeys, result);
} }

View File

@ -20,12 +20,9 @@
#include <gtsam/discrete/DiscreteKey.h> // make sure we have traits #include <gtsam/discrete/DiscreteKey.h> // make sure we have traits
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
// headers first to make sure no missing headers // headers first to make sure no missing headers
#include <CppUnitLite/TestHarness.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h> #include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DecisionTree-inl.h> // for convert only #include <gtsam/discrete/DecisionTree-inl.h> // for convert only
#define DISABLE_TIMING
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/timing.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
using namespace std; using namespace std;
@ -71,16 +68,14 @@ void dot(const T& f, const string& filename) {
// instrumented operators // instrumented operators
/* ************************************************************************** */ /* ************************************************************************** */
size_t muls = 0, adds = 0; size_t muls = 0, adds = 0;
double elapsed;
void resetCounts() { void resetCounts() {
muls = 0; muls = 0;
adds = 0; adds = 0;
} }
void printCounts(const string& s) { void printCounts(const string& s) {
#ifndef DISABLE_TIMING #ifndef DISABLE_TIMING
cout << s << ": " << std::setw(3) << muls << " muls, " << cout << s << ": " << std::setw(3) << muls << " muls, " << std::setw(3) << adds
std::setw(3) << adds << " adds, " << 1000 * elapsed << " ms." << " adds" << endl;
<< endl;
#endif #endif
resetCounts(); resetCounts();
} }
@ -131,37 +126,35 @@ ADT create(const Signature& signature) {
static size_t count = 0; static size_t count = 0;
const DiscreteKey& key = signature.key(); const DiscreteKey& key = signature.key();
std::stringstream ss; std::stringstream ss;
ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-" << key.first; ss << "CPT-" << std::setw(3) << std::setfill('0') << ++count << "-"
<< key.first;
string DOTfile = ss.str(); string DOTfile = ss.str();
dot(p, DOTfile); dot(p, DOTfile);
return p; return p;
} }
/* ************************************************************************* */
namespace asiaCPTs {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pB = create(B | S = "70/30 40/60");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
} // namespace asiaCPTs
/* ************************************************************************* */ /* ************************************************************************* */
// test Asia Joint // test Asia Joint
TEST(ADT, joint) { TEST(ADT, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), using namespace asiaCPTs;
D(7, 2);
resetCounts();
gttic_(asiaCPTs);
ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5");
ADT pL = create(L | S = "99/1 90/10");
ADT pB = create(B | S = "70/30 40/60");
ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(asiaCPTs);
tictoc_getNode(asiaCPTsNode, asiaCPTs);
elapsed = asiaCPTsNode->secs() + asiaCPTsNode->wall();
tictoc_reset_();
printCounts("Asia CPTs");
// Create joint // Create joint
resetCounts(); resetCounts();
gttic_(asiaJoint);
ADT joint = pA; ADT joint = pA;
dot(joint, "Asia-A"); dot(joint, "Asia-A");
joint = apply(joint, pS, &mul); joint = apply(joint, pS, &mul);
@ -183,11 +176,12 @@ TEST(ADT, joint) {
#else #else
EXPECT_LONGS_EQUAL(508, muls); EXPECT_LONGS_EQUAL(508, muls);
#endif #endif
gttoc_(asiaJoint);
tictoc_getNode(asiaJointNode, asiaJoint);
elapsed = asiaJointNode->secs() + asiaJointNode->wall();
tictoc_reset_();
printCounts("Asia joint"); printCounts("Asia joint");
}
/* ************************************************************************* */
TEST(ADT, combine) {
using namespace asiaCPTs;
// Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S) // Form P(A,S,T,L) = P(A) P(S) P(T|A) P(L|S)
ADT pASTL = pA; ADT pASTL = pA;
@ -203,13 +197,11 @@ TEST(ADT, joint) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
// test Inference with joint // test Inference with joint, created using different ordering
TEST(ADT, inference) { TEST(ADT, inference) {
DiscreteKey A(0, 2), D(1, 2), // DiscreteKey A(0, 2), D(1, 2), //
B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2); B(2, 2), L(3, 2), E(4, 2), S(5, 2), T(6, 2), X(7, 2);
resetCounts();
gttic_(infCPTs);
ADT pA = create(A % "99/1"); ADT pA = create(A % "99/1");
ADT pS = create(S % "50/50"); ADT pS = create(S % "50/50");
ADT pT = create(T | A = "99/1 95/5"); ADT pT = create(T | A = "99/1 95/5");
@ -218,15 +210,9 @@ TEST(ADT, inference) {
ADT pE = create((E | T, L) = "F T T T"); ADT pE = create((E | T, L) = "F T T T");
ADT pX = create(X | E = "95/5 2/98"); ADT pX = create(X | E = "95/5 2/98");
ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
gttoc_(infCPTs);
tictoc_getNode(infCPTsNode, infCPTs);
elapsed = infCPTsNode->secs() + infCPTsNode->wall();
tictoc_reset_();
// printCounts("Inference CPTs");
// Create joint // Create joint, note different ordering than above: different tree!
resetCounts(); resetCounts();
gttic_(asiaProd);
ADT joint = pA; ADT joint = pA;
dot(joint, "Joint-Product-A"); dot(joint, "Joint-Product-A");
joint = apply(joint, pS, &mul); joint = apply(joint, pS, &mul);
@ -248,14 +234,9 @@ TEST(ADT, inference) {
#else #else
EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering EXPECT_LONGS_EQUAL(508, (long)muls); // different ordering
#endif #endif
gttoc_(asiaProd);
tictoc_getNode(asiaProdNode, asiaProd);
elapsed = asiaProdNode->secs() + asiaProdNode->wall();
tictoc_reset_();
printCounts("Asia product"); printCounts("Asia product");
resetCounts(); resetCounts();
gttic_(asiaSum);
ADT marginal = joint; ADT marginal = joint;
marginal = marginal.combine(X, &add_); marginal = marginal.combine(X, &add_);
dot(marginal, "Joint-Sum-ADBLEST"); dot(marginal, "Joint-Sum-ADBLEST");
@ -270,10 +251,6 @@ TEST(ADT, inference) {
#else #else
EXPECT_LONGS_EQUAL(240, (long)adds); EXPECT_LONGS_EQUAL(240, (long)adds);
#endif #endif
gttoc_(asiaSum);
tictoc_getNode(asiaSumNode, asiaSum);
elapsed = asiaSumNode->secs() + asiaSumNode->wall();
tictoc_reset_();
printCounts("Asia sum"); printCounts("Asia sum");
} }
@ -281,8 +258,6 @@ TEST(ADT, inference) {
TEST(ADT, factor_graph) { TEST(ADT, factor_graph) {
DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2); DiscreteKey B(0, 2), L(1, 2), E(2, 2), S(3, 2), T(4, 2), X(5, 2);
resetCounts();
gttic_(createCPTs);
ADT pS = create(S % "50/50"); ADT pS = create(S % "50/50");
ADT pT = create(T % "95/5"); ADT pT = create(T % "95/5");
ADT pL = create(L | S = "99/1 90/10"); ADT pL = create(L | S = "99/1 90/10");
@ -290,15 +265,9 @@ TEST(ADT, factor_graph) {
ADT pX = create(X | E = "95/5 2/98"); ADT pX = create(X | E = "95/5 2/98");
ADT pD = create(B | E = "1/8 7/9"); ADT pD = create(B | E = "1/8 7/9");
ADT pB = create(B | S = "70/30 40/60"); ADT pB = create(B | S = "70/30 40/60");
gttoc_(createCPTs);
tictoc_getNode(createCPTsNode, createCPTs);
elapsed = createCPTsNode->secs() + createCPTsNode->wall();
tictoc_reset_();
// printCounts("Create CPTs");
// Create joint // Create joint
resetCounts(); resetCounts();
gttic_(asiaFG);
ADT fg = pS; ADT fg = pS;
fg = apply(fg, pT, &mul); fg = apply(fg, pT, &mul);
fg = apply(fg, pL, &mul); fg = apply(fg, pL, &mul);
@ -312,14 +281,9 @@ TEST(ADT, factor_graph) {
#else #else
EXPECT_LONGS_EQUAL(188, (long)muls); EXPECT_LONGS_EQUAL(188, (long)muls);
#endif #endif
gttoc_(asiaFG);
tictoc_getNode(asiaFGNode, asiaFG);
elapsed = asiaFGNode->secs() + asiaFGNode->wall();
tictoc_reset_();
printCounts("Asia FG"); printCounts("Asia FG");
resetCounts(); resetCounts();
gttic_(marg);
fg = fg.combine(X, &add_); fg = fg.combine(X, &add_);
dot(fg, "Marginalized-6X"); dot(fg, "Marginalized-6X");
fg = fg.combine(T, &add_); fg = fg.combine(T, &add_);
@ -335,83 +299,54 @@ TEST(ADT, factor_graph) {
#else #else
LONGS_EQUAL(62, adds); LONGS_EQUAL(62, adds);
#endif #endif
gttoc_(marg);
tictoc_getNode(margNode, marg);
elapsed = margNode->secs() + margNode->wall();
tictoc_reset_();
printCounts("marginalize"); printCounts("marginalize");
// BLESTX // BLESTX
// Eliminate X // Eliminate X
resetCounts(); resetCounts();
gttic_(elimX);
ADT fE = pX; ADT fE = pX;
dot(fE, "Eliminate-01-fEX"); dot(fE, "Eliminate-01-fEX");
fE = fE.combine(X, &add_); fE = fE.combine(X, &add_);
dot(fE, "Eliminate-02-fE"); dot(fE, "Eliminate-02-fE");
gttoc_(elimX);
tictoc_getNode(elimXNode, elimX);
elapsed = elimXNode->secs() + elimXNode->wall();
tictoc_reset_();
printCounts("Eliminate X"); printCounts("Eliminate X");
// Eliminate T // Eliminate T
resetCounts(); resetCounts();
gttic_(elimT);
ADT fLE = pT; ADT fLE = pT;
fLE = apply(fLE, pE, &mul); fLE = apply(fLE, pE, &mul);
dot(fLE, "Eliminate-03-fLET"); dot(fLE, "Eliminate-03-fLET");
fLE = fLE.combine(T, &add_); fLE = fLE.combine(T, &add_);
dot(fLE, "Eliminate-04-fLE"); dot(fLE, "Eliminate-04-fLE");
gttoc_(elimT);
tictoc_getNode(elimTNode, elimT);
elapsed = elimTNode->secs() + elimTNode->wall();
tictoc_reset_();
printCounts("Eliminate T"); printCounts("Eliminate T");
// Eliminate S // Eliminate S
resetCounts(); resetCounts();
gttic_(elimS);
ADT fBL = pS; ADT fBL = pS;
fBL = apply(fBL, pL, &mul); fBL = apply(fBL, pL, &mul);
fBL = apply(fBL, pB, &mul); fBL = apply(fBL, pB, &mul);
dot(fBL, "Eliminate-05-fBLS"); dot(fBL, "Eliminate-05-fBLS");
fBL = fBL.combine(S, &add_); fBL = fBL.combine(S, &add_);
dot(fBL, "Eliminate-06-fBL"); dot(fBL, "Eliminate-06-fBL");
gttoc_(elimS);
tictoc_getNode(elimSNode, elimS);
elapsed = elimSNode->secs() + elimSNode->wall();
tictoc_reset_();
printCounts("Eliminate S"); printCounts("Eliminate S");
// Eliminate E // Eliminate E
resetCounts(); resetCounts();
gttic_(elimE);
ADT fBL2 = fE; ADT fBL2 = fE;
fBL2 = apply(fBL2, fLE, &mul); fBL2 = apply(fBL2, fLE, &mul);
fBL2 = apply(fBL2, pD, &mul); fBL2 = apply(fBL2, pD, &mul);
dot(fBL2, "Eliminate-07-fBLE"); dot(fBL2, "Eliminate-07-fBLE");
fBL2 = fBL2.combine(E, &add_); fBL2 = fBL2.combine(E, &add_);
dot(fBL2, "Eliminate-08-fBL2"); dot(fBL2, "Eliminate-08-fBL2");
gttoc_(elimE);
tictoc_getNode(elimENode, elimE);
elapsed = elimENode->secs() + elimENode->wall();
tictoc_reset_();
printCounts("Eliminate E"); printCounts("Eliminate E");
// Eliminate L // Eliminate L
resetCounts(); resetCounts();
gttic_(elimL);
ADT fB = fBL; ADT fB = fBL;
fB = apply(fB, fBL2, &mul); fB = apply(fB, fBL2, &mul);
dot(fB, "Eliminate-09-fBL"); dot(fB, "Eliminate-09-fBL");
fB = fB.combine(L, &add_); fB = fB.combine(L, &add_);
dot(fB, "Eliminate-10-fB"); dot(fB, "Eliminate-10-fB");
gttoc_(elimL);
tictoc_getNode(elimLNode, elimL);
elapsed = elimLNode->secs() + elimLNode->wall();
tictoc_reset_();
printCounts("Eliminate L"); printCounts("Eliminate L");
} }

View File

@ -22,7 +22,10 @@
#include <gtsam/base/serializationTestHelpers.h> #include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h> #include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/Signature.h> #include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Key.h>
#include <gtsam/inference/Ordering.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -33,25 +36,24 @@ TEST(DecisionTreeFactor, ConstructorsMatch) {
DiscreteKey X(0, 2), Y(1, 3); DiscreteKey X(0, 2), Y(1, 3);
// Create with vector and with string // Create with vector and with string
const std::vector<double> table {2, 5, 3, 6, 4, 7}; const std::vector<double> table{2, 5, 3, 6, 4, 7};
DecisionTreeFactor f1({X, Y}, table); DecisionTreeFactor f1({X, Y}, table);
DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7"); DecisionTreeFactor f2({X, Y}, "2 5 3 6 4 7");
EXPECT(assert_equal(f1, f2)); EXPECT(assert_equal(f1, f2));
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DecisionTreeFactor, constructors) TEST(DecisionTreeFactor, constructors) {
{
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2); DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
// Create factors // Create factors
DecisionTreeFactor f1(X, {2, 8}); DecisionTreeFactor f1(X, {2, 8});
DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7"); DecisionTreeFactor f2(X & Y, "2 5 3 6 4 7");
DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); DecisionTreeFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
EXPECT_LONGS_EQUAL(1,f1.size()); EXPECT_LONGS_EQUAL(1, f1.size());
EXPECT_LONGS_EQUAL(2,f2.size()); EXPECT_LONGS_EQUAL(2, f2.size());
EXPECT_LONGS_EQUAL(3,f3.size()); EXPECT_LONGS_EQUAL(3, f3.size());
DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}}; DiscreteValues x121{{0, 1}, {1, 2}, {2, 1}};
EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9); EXPECT_DOUBLES_EQUAL(8, f1(x121), 1e-9);
@ -70,7 +72,7 @@ TEST( DecisionTreeFactor, constructors)
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DecisionTreeFactor, Error) { TEST(DecisionTreeFactor, Error) {
// Declare a bunch of keys // Declare a bunch of keys
DiscreteKey X(0,2), Y(1,3), Z(2,2); DiscreteKey X(0, 2), Y(1, 3), Z(2, 2);
// Create factors // Create factors
DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); DecisionTreeFactor f(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
@ -104,9 +106,8 @@ TEST(DecisionTreeFactor, multiplication) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST( DecisionTreeFactor, sum_max) TEST(DecisionTreeFactor, sum_max) {
{ DiscreteKey v0(0, 3), v1(1, 2);
DiscreteKey v0(0,3), v1(1,2);
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");
DecisionTreeFactor expected(v1, "9 12"); DecisionTreeFactor expected(v1, "9 12");
@ -165,22 +166,85 @@ TEST(DecisionTreeFactor, Prune) {
"0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
"0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
DecisionTreeFactor expected3( DecisionTreeFactor expected3(D & C & B & A,
D & C & B & A, "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
"0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " "0.999952870000 1.0 1.0 1.0 1.0");
"0.999952870000 1.0 1.0 1.0 1.0");
maxNrAssignments = 5; maxNrAssignments = 5;
auto pruned3 = factor.prune(maxNrAssignments); auto pruned3 = factor.prune(maxNrAssignments);
EXPECT(assert_equal(expected3, pruned3)); EXPECT(assert_equal(expected3, pruned3));
} }
/* ************************************************************************** */
// Asia Bayes Network
/* ************************************************************************** */
#define DISABLE_DOT
void maybeSaveDotFile(const DecisionTreeFactor& f, const string& filename) {
#ifndef DISABLE_DOT
std::vector<std::string> names = {"A", "S", "T", "L", "B", "E", "X", "D"};
auto formatter = [&](Key key) { return names[key]; };
f.dot(filename, formatter, true);
#endif
}
/** Convert Signature into CPT */
DecisionTreeFactor create(const Signature& signature) {
DecisionTreeFactor p(signature.discreteKeys(), signature.cpt());
return p;
}
/* ************************************************************************* */
// test Asia Joint
TEST(DecisionTreeFactor, joint) {
DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2),
D(7, 2);
gttic_(asiaCPTs);
DecisionTreeFactor pA = create(A % "99/1");
DecisionTreeFactor pS = create(S % "50/50");
DecisionTreeFactor pT = create(T | A = "99/1 95/5");
DecisionTreeFactor pL = create(L | S = "99/1 90/10");
DecisionTreeFactor pB = create(B | S = "70/30 40/60");
DecisionTreeFactor pE = create((E | T, L) = "F T T T");
DecisionTreeFactor pX = create(X | E = "95/5 2/98");
DecisionTreeFactor pD = create((D | E, B) = "9/1 2/8 3/7 1/9");
// Create joint
gttic_(asiaJoint);
DecisionTreeFactor joint = pA;
maybeSaveDotFile(joint, "Asia-A");
joint = joint * pS;
maybeSaveDotFile(joint, "Asia-AS");
joint = joint * pT;
maybeSaveDotFile(joint, "Asia-AST");
joint = joint * pL;
maybeSaveDotFile(joint, "Asia-ASTL");
joint = joint * pB;
maybeSaveDotFile(joint, "Asia-ASTLB");
joint = joint * pE;
maybeSaveDotFile(joint, "Asia-ASTLBE");
joint = joint * pX;
maybeSaveDotFile(joint, "Asia-ASTLBEX");
joint = joint * pD;
maybeSaveDotFile(joint, "Asia-ASTLBEXD");
// Check that discrete keys are as expected
EXPECT(assert_equal(joint.discreteKeys(), {A, S, T, L, B, E, X, D}));
// Check that summing out variables maintains the keys even if merged, as is
// the case with S.
auto noAB = joint.sum(Ordering{A.first, B.first});
EXPECT(assert_equal(noAB->discreteKeys(), {S, T, L, E, X, D}));
}
/* ************************************************************************* */ /* ************************************************************************* */
TEST(DecisionTreeFactor, DotWithNames) { TEST(DecisionTreeFactor, DotWithNames) {
DiscreteKey A(12, 3), B(5, 2); DiscreteKey A(12, 3), B(5, 2);
DecisionTreeFactor f(A & B, "1 2 3 4 5 6"); DecisionTreeFactor f(A & B, "1 2 3 4 5 6");
auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
for (bool showZero:{true, false}) { for (bool showZero : {true, false}) {
string actual = f.dot(formatter, showZero); string actual = f.dot(formatter, showZero);
// pretty weak test, as ids are pointers and not stable across platforms. // pretty weak test, as ids are pointers and not stable across platforms.
string expected = "digraph G {"; string expected = "digraph G {";

View File

@ -22,7 +22,7 @@ namespace gtsam {
/* *******************************************************************************/ /* *******************************************************************************/
static void checkKeys(const KeyVector& continuousKeys, static void checkKeys(const KeyVector& continuousKeys,
std::vector<NonlinearFactorValuePair>& pairs) { const std::vector<NonlinearFactorValuePair>& pairs) {
KeySet factor_keys_set; KeySet factor_keys_set;
for (const auto& pair : pairs) { for (const auto& pair : pairs) {
auto f = pair.first; auto f = pair.first;
@ -55,14 +55,9 @@ HybridNonlinearFactor::HybridNonlinearFactor(
/* *******************************************************************************/ /* *******************************************************************************/
HybridNonlinearFactor::HybridNonlinearFactor( HybridNonlinearFactor::HybridNonlinearFactor(
const KeyVector& continuousKeys, const DiscreteKey& discreteKey, const KeyVector& continuousKeys, const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& factors) const std::vector<NonlinearFactorValuePair>& pairs)
: Base(continuousKeys, {discreteKey}) { : Base(continuousKeys, {discreteKey}) {
std::vector<NonlinearFactorValuePair> pairs;
KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end()); KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end());
KeySet factor_keys_set;
for (auto&& [f, val] : factors) {
pairs.emplace_back(f, val);
}
checkKeys(continuousKeys, pairs); checkKeys(continuousKeys, pairs);
factors_ = FactorValuePairs({discreteKey}, pairs); factors_ = FactorValuePairs({discreteKey}, pairs);
} }

View File

@ -106,11 +106,11 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
* *
* @param continuousKeys Vector of keys for continuous factors. * @param continuousKeys Vector of keys for continuous factors.
* @param discreteKey The discrete key for the "mode", indexing components. * @param discreteKey The discrete key for the "mode", indexing components.
* @param factors Vector of gaussian factor-scalar pairs, one per mode. * @param pairs Vector of gaussian factor-scalar pairs, one per mode.
*/ */
HybridNonlinearFactor(const KeyVector& continuousKeys, HybridNonlinearFactor(const KeyVector& continuousKeys,
const DiscreteKey& discreteKey, const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& factors); const std::vector<NonlinearFactorValuePair>& pairs);
/** /**
* @brief Construct a new HybridNonlinearFactor on a several discrete keys M, * @brief Construct a new HybridNonlinearFactor on a several discrete keys M,