diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index f898178c2..2238b08ce 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -15,8 +15,9 @@ * @date January 2022 */ +#include +#include #include -#include #include namespace gtsam { @@ -139,8 +140,19 @@ GaussianBayesNet HybridBayesNet::choose( /* *******************************************************************************/ HybridValues HybridBayesNet::optimize() const { - auto dag = HybridLookupDAG::FromBayesNet(*this); - return dag.argmax(); + // Solve for the MPE + DiscreteBayesNet discrete_bn; + for (auto &conditional : factors_) { + if (conditional->isDiscrete()) { + discrete_bn.push_back(conditional->asDiscreteConditional()); + } + } + + DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize(); + + // Given the MPE, compute the optimal continuous values. + GaussianBayesNet gbn = this->choose(mpe); + return HybridValues(mpe, gbn.optimize()); } /* *******************************************************************************/ diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index 4602e8bac..bcaf7e599 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -136,24 +136,19 @@ TEST(HybridBayesNet, Optimize) { HybridValues delta = hybridBayesNet->optimize(); - delta.print(); - VectorValues correct; - correct.insert(X(1), 0 * Vector1::Ones()); - correct.insert(X(2), 1 * Vector1::Ones()); - correct.insert(X(3), 2 * Vector1::Ones()); - correct.insert(X(4), 3 * Vector1::Ones()); + DiscreteValues expectedAssignment; + expectedAssignment[M(1)] = 1; + expectedAssignment[M(2)] = 0; + expectedAssignment[M(3)] = 1; + EXPECT(assert_equal(expectedAssignment, delta.discrete())); - DiscreteValues assignment111; - assignment111[M(1)] = 1; - assignment111[M(2)] = 1; - assignment111[M(3)] = 1; - std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl; + VectorValues expectedValues; + expectedValues.insert(X(1), -0.999904 * Vector1::Ones()); + expectedValues.insert(X(2), -0.99029 * Vector1::Ones()); + expectedValues.insert(X(3), -1.00971 * Vector1::Ones()); + expectedValues.insert(X(4), -1.0001 * Vector1::Ones()); - DiscreteValues assignment101; - assignment101[M(1)] = 1; - assignment101[M(2)] = 0; - assignment101[M(3)] = 1; - std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl; + EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } /* ************************************************************************* */