MPE now works
parent
e22f8f04bc
commit
2f49612b8c
|
|
@ -27,10 +27,6 @@ using std::vector;
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
// Instantiate base class
|
||||
template class GTSAM_EXPORT
|
||||
Conditional<DecisionTreeFactor, DiscreteLookupTable>;
|
||||
|
||||
/* ************************************************************************** */
|
||||
// TODO(dellaert): copy/paste from DiscreteConditional.cpp :-(
|
||||
void DiscreteLookupTable::print(const std::string& s,
|
||||
|
|
|
|||
|
|
@ -22,17 +22,19 @@
|
|||
#include <gtsam/inference/FactorGraph.h>
|
||||
|
||||
#include <boost/shared_ptr.hpp>
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* @brief DiscreteLookupTable table for max-product
|
||||
*
|
||||
* Inherits from discrete conditional for convenience, but is not normalized.
|
||||
* Is used in pax-product algorithm.
|
||||
*/
|
||||
class DiscreteLookupTable
|
||||
: public DecisionTreeFactor,
|
||||
public Conditional<DecisionTreeFactor, DiscreteLookupTable> {
|
||||
class DiscreteLookupTable : public DiscreteConditional {
|
||||
public:
|
||||
using This = DiscreteLookupTable;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
|
@ -47,7 +49,7 @@ class DiscreteLookupTable
|
|||
*/
|
||||
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
|
||||
const ADT& potentials)
|
||||
: DecisionTreeFactor(keys, potentials), BaseConditional(nFrontals) {}
|
||||
: DiscreteConditional(nFrontals, keys, potentials) {}
|
||||
|
||||
/// GTSAM-style print
|
||||
void print(
|
||||
|
|
@ -100,6 +102,12 @@ class GTSAM_EXPORT DiscreteLookupDAG : public BayesNet<DiscreteLookupTable> {
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/** Add a DiscreteLookupTable */
|
||||
template <typename... Args>
|
||||
void add(Args&&... args) {
|
||||
emplace_shared<DiscreteLookupTable>(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief argmax by back-substitution.
|
||||
*
|
||||
|
|
|
|||
|
|
@ -239,7 +239,7 @@ TEST(DiscreteFactorGraph, testMPE_Darwiche09book_p244) {
|
|||
Ordering ordering;
|
||||
ordering += Key(0), Key(1), Key(2), Key(3), Key(4);
|
||||
auto chordal = graph.eliminateSequential(ordering);
|
||||
EXPECT_LONGS_EQUAL(2, chordal->size());
|
||||
EXPECT_LONGS_EQUAL(5, chordal->size());
|
||||
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
|
||||
auto notOptimal = chordal->optimize(); // not MPE !
|
||||
EXPECT(graph(notOptimal) < graph(mpe));
|
||||
|
|
|
|||
|
|
@ -0,0 +1,58 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/*
|
||||
* testDiscreteLookupDAG.cpp
|
||||
*
|
||||
* @date January, 2022
|
||||
* @author Frank Dellaert
|
||||
*/
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/DiscreteLookupDAG.h>
|
||||
|
||||
#include <boost/assign/list_inserter.hpp>
|
||||
#include <boost/assign/std/map.hpp>
|
||||
|
||||
using namespace gtsam;
|
||||
using namespace boost::assign;
|
||||
|
||||
/* ************************************************************************* */
|
||||
TEST(DiscreteLookupDAG, argmax) {
|
||||
using ADT = AlgebraicDecisionTree<Key>;
|
||||
|
||||
// Declare 2 keys
|
||||
DiscreteKey A(0, 2), B(1, 2);
|
||||
|
||||
// Create lookup table corresponding to "marginalIsNotMPE" in testDFG.
|
||||
DiscreteLookupDAG dag;
|
||||
|
||||
ADT adtB(DiscreteKeys{B, A}, std::vector<double>{0.5, 1. / 3, 0.5, 2. / 3});
|
||||
dag.add(1, DiscreteKeys{B, A}, adtB);
|
||||
|
||||
ADT adtA(A, 0.5 * 10 / 19, (2. / 3) * (9. / 19));
|
||||
dag.add(1, DiscreteKeys{A}, adtA);
|
||||
|
||||
// The expected MPE is A=1, B=1
|
||||
DiscreteValues mpe;
|
||||
insert(mpe)(0, 1)(1, 1);
|
||||
|
||||
// check:
|
||||
auto actualMPE = dag.argmax();
|
||||
EXPECT(assert_equal(mpe, actualMPE));
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
Loading…
Reference in New Issue