new HybridBayesNet optimize implementation

release/4.3a0
Varun Agrawal 2022-08-26 19:36:11 -04:00
parent 9c7bf36db6
commit 0edcfd4ff8
2 changed files with 26 additions and 19 deletions

View File

@ -15,8 +15,9 @@
* @date January 2022
*/
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
namespace gtsam {
@ -139,8 +140,19 @@ GaussianBayesNet HybridBayesNet::choose(
/* *******************************************************************************/
HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this);
return dag.argmax();
// Solve for the MPE
DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
}
}
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe);
return HybridValues(mpe, gbn.optimize());
}
/* *******************************************************************************/

View File

@ -136,24 +136,19 @@ TEST(HybridBayesNet, Optimize) {
HybridValues delta = hybridBayesNet->optimize();
delta.print();
VectorValues correct;
correct.insert(X(1), 0 * Vector1::Ones());
correct.insert(X(2), 1 * Vector1::Ones());
correct.insert(X(3), 2 * Vector1::Ones());
correct.insert(X(4), 3 * Vector1::Ones());
DiscreteValues expectedAssignment;
expectedAssignment[M(1)] = 1;
expectedAssignment[M(2)] = 0;
expectedAssignment[M(3)] = 1;
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
DiscreteValues assignment111;
assignment111[M(1)] = 1;
assignment111[M(2)] = 1;
assignment111[M(3)] = 1;
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl;
VectorValues expectedValues;
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
DiscreteValues assignment101;
assignment101[M(1)] = 1;
assignment101[M(2)] = 0;
assignment101[M(3)] = 1;
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
}
/* ************************************************************************* */