new HybridBayesNet optimize implementation

release/4.3a0
Varun Agrawal 2022-08-26 19:36:11 -04:00
parent 9c7bf36db6
commit 0edcfd4ff8
2 changed files with 26 additions and 19 deletions

View File

@ -15,8 +15,9 @@
* @date January 2022 * @date January 2022
*/ */
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
namespace gtsam { namespace gtsam {
@ -139,8 +140,19 @@ GaussianBayesNet HybridBayesNet::choose(
/* *******************************************************************************/ /* *******************************************************************************/
HybridValues HybridBayesNet::optimize() const { HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this); // Solve for the MPE
return dag.argmax(); DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
}
}
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe);
return HybridValues(mpe, gbn.optimize());
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -136,24 +136,19 @@ TEST(HybridBayesNet, Optimize) {
HybridValues delta = hybridBayesNet->optimize(); HybridValues delta = hybridBayesNet->optimize();
delta.print(); DiscreteValues expectedAssignment;
VectorValues correct; expectedAssignment[M(1)] = 1;
correct.insert(X(1), 0 * Vector1::Ones()); expectedAssignment[M(2)] = 0;
correct.insert(X(2), 1 * Vector1::Ones()); expectedAssignment[M(3)] = 1;
correct.insert(X(3), 2 * Vector1::Ones()); EXPECT(assert_equal(expectedAssignment, delta.discrete()));
correct.insert(X(4), 3 * Vector1::Ones());
DiscreteValues assignment111; VectorValues expectedValues;
assignment111[M(1)] = 1; expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
assignment111[M(2)] = 1; expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
assignment111[M(3)] = 1; expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl; expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
DiscreteValues assignment101; EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
assignment101[M(1)] = 1;
assignment101[M(2)] = 0;
assignment101[M(3)] = 1;
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
} }
/* ************************************************************************* */ /* ************************************************************************* */