new HybridBayesNet optimize implementation
parent
9c7bf36db6
commit
0edcfd4ff8
|
@ -15,8 +15,9 @@
|
|||
* @date January 2022
|
||||
*/
|
||||
|
||||
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
@ -139,8 +140,19 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
|
||||
/* *******************************************************************************/
|
||||
HybridValues HybridBayesNet::optimize() const {
|
||||
auto dag = HybridLookupDAG::FromBayesNet(*this);
|
||||
return dag.argmax();
|
||||
// Solve for the MPE
|
||||
DiscreteBayesNet discrete_bn;
|
||||
for (auto &conditional : factors_) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_bn.push_back(conditional->asDiscreteConditional());
|
||||
}
|
||||
}
|
||||
|
||||
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
|
||||
|
||||
// Given the MPE, compute the optimal continuous values.
|
||||
GaussianBayesNet gbn = this->choose(mpe);
|
||||
return HybridValues(mpe, gbn.optimize());
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
@ -136,24 +136,19 @@ TEST(HybridBayesNet, Optimize) {
|
|||
|
||||
HybridValues delta = hybridBayesNet->optimize();
|
||||
|
||||
delta.print();
|
||||
VectorValues correct;
|
||||
correct.insert(X(1), 0 * Vector1::Ones());
|
||||
correct.insert(X(2), 1 * Vector1::Ones());
|
||||
correct.insert(X(3), 2 * Vector1::Ones());
|
||||
correct.insert(X(4), 3 * Vector1::Ones());
|
||||
DiscreteValues expectedAssignment;
|
||||
expectedAssignment[M(1)] = 1;
|
||||
expectedAssignment[M(2)] = 0;
|
||||
expectedAssignment[M(3)] = 1;
|
||||
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
||||
|
||||
DiscreteValues assignment111;
|
||||
assignment111[M(1)] = 1;
|
||||
assignment111[M(2)] = 1;
|
||||
assignment111[M(3)] = 1;
|
||||
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl;
|
||||
VectorValues expectedValues;
|
||||
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
|
||||
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
|
||||
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
|
||||
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
|
||||
|
||||
DiscreteValues assignment101;
|
||||
assignment101[M(1)] = 1;
|
||||
assignment101[M(2)] = 0;
|
||||
assignment101[M(3)] = 1;
|
||||
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
|
||||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue