diff --git a/gtsam/discrete/DiscreteLookupDAG.cpp b/gtsam/discrete/DiscreteLookupDAG.cpp index 37e45de80..4fe3a53a4 100644 --- a/gtsam/discrete/DiscreteLookupDAG.cpp +++ b/gtsam/discrete/DiscreteLookupDAG.cpp @@ -27,10 +27,6 @@ using std::vector; namespace gtsam { -// Instantiate base class -template class GTSAM_EXPORT - Conditional; - /* ************************************************************************** */ // TODO(dellaert): copy/paste from DiscreteConditional.cpp :-( void DiscreteLookupTable::print(const std::string& s, diff --git a/gtsam/discrete/DiscreteLookupDAG.h b/gtsam/discrete/DiscreteLookupDAG.h index a69b0b1ee..31cb3dfbf 100644 --- a/gtsam/discrete/DiscreteLookupDAG.h +++ b/gtsam/discrete/DiscreteLookupDAG.h @@ -22,17 +22,19 @@ #include #include - +#include +#include #include namespace gtsam { /** * @brief DiscreteLookupTable table for max-product + * + * Inherits from discrete conditional for convenience, but is not normalized. + * Is used in pax-product algorithm. */ -class DiscreteLookupTable - : public DecisionTreeFactor, - public Conditional { +class DiscreteLookupTable : public DiscreteConditional { public: using This = DiscreteLookupTable; using shared_ptr = boost::shared_ptr; @@ -47,7 +49,7 @@ class DiscreteLookupTable */ DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys, const ADT& potentials) - : DecisionTreeFactor(keys, potentials), BaseConditional(nFrontals) {} + : DiscreteConditional(nFrontals, keys, potentials) {} /// GTSAM-style print void print( @@ -100,6 +102,12 @@ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet { /// @name Standard Interface /// @{ + /** Add a DiscreteLookupTable */ + template + void add(Args&&... args) { + emplace_shared(std::forward(args)...); + } + /** * @brief argmax by back-substitution. * diff --git a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp index e63cc26b8..f4819dab5 100644 --- a/gtsam/discrete/tests/testDiscreteFactorGraph.cpp +++ b/gtsam/discrete/tests/testDiscreteFactorGraph.cpp @@ -239,7 +239,7 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) { Ordering ordering; ordering += Key(0), Key(1), Key(2), Key(3), Key(4); auto chordal = graph.eliminateSequential(ordering); - EXPECT_LONGS_EQUAL(2, chordal->size()); + EXPECT_LONGS_EQUAL(5, chordal->size()); #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 auto notOptimal = chordal->optimize(); // not MPE ! EXPECT(graph(notOptimal) < graph(mpe)); diff --git a/gtsam/discrete/tests/testDiscreteLookupDAG.cpp b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp new file mode 100644 index 000000000..04b859780 --- /dev/null +++ b/gtsam/discrete/tests/testDiscreteLookupDAG.cpp @@ -0,0 +1,58 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * testDiscreteLookupDAG.cpp + * + * @date January, 2022 + * @author Frank Dellaert + */ + +#include +#include +#include + +#include +#include + +using namespace gtsam; +using namespace boost::assign; + +/* ************************************************************************* */ +TEST(DiscreteLookupDAG, argmax) { + using ADT = AlgebraicDecisionTree; + + // Declare 2 keys + DiscreteKey A(0, 2), B(1, 2); + + // Create lookup table corresponding to "marginalIsNotMPE" in testDFG. + DiscreteLookupDAG dag; + + ADT adtB(DiscreteKeys{B, A}, std::vector{0.5, 1. / 3, 0.5, 2. / 3}); + dag.add(1, DiscreteKeys{B, A}, adtB); + + ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19)); + dag.add(1, DiscreteKeys{A}, adtA); + + // The expected MPE is A=1, B=1 + DiscreteValues mpe; + insert(mpe)(0, 1)(1, 1); + + // check: + auto actualMPE = dag.argmax(); + EXPECT(assert_equal(mpe, actualMPE)); +} +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */