new HybridBayesNet optimize implementation
parent
9c7bf36db6
commit
0edcfd4ff8
|
@ -15,8 +15,9 @@
|
||||||
* @date January 2022
|
* @date January 2022
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteBayesNet.h>
|
||||||
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
@ -139,8 +140,19 @@ GaussianBayesNet HybridBayesNet::choose(
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
HybridValues HybridBayesNet::optimize() const {
|
HybridValues HybridBayesNet::optimize() const {
|
||||||
auto dag = HybridLookupDAG::FromBayesNet(*this);
|
// Solve for the MPE
|
||||||
return dag.argmax();
|
DiscreteBayesNet discrete_bn;
|
||||||
|
for (auto &conditional : factors_) {
|
||||||
|
if (conditional->isDiscrete()) {
|
||||||
|
discrete_bn.push_back(conditional->asDiscreteConditional());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();
|
||||||
|
|
||||||
|
// Given the MPE, compute the optimal continuous values.
|
||||||
|
GaussianBayesNet gbn = this->choose(mpe);
|
||||||
|
return HybridValues(mpe, gbn.optimize());
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
@ -136,24 +136,19 @@ TEST(HybridBayesNet, Optimize) {
|
||||||
|
|
||||||
HybridValues delta = hybridBayesNet->optimize();
|
HybridValues delta = hybridBayesNet->optimize();
|
||||||
|
|
||||||
delta.print();
|
DiscreteValues expectedAssignment;
|
||||||
VectorValues correct;
|
expectedAssignment[M(1)] = 1;
|
||||||
correct.insert(X(1), 0 * Vector1::Ones());
|
expectedAssignment[M(2)] = 0;
|
||||||
correct.insert(X(2), 1 * Vector1::Ones());
|
expectedAssignment[M(3)] = 1;
|
||||||
correct.insert(X(3), 2 * Vector1::Ones());
|
EXPECT(assert_equal(expectedAssignment, delta.discrete()));
|
||||||
correct.insert(X(4), 3 * Vector1::Ones());
|
|
||||||
|
|
||||||
DiscreteValues assignment111;
|
VectorValues expectedValues;
|
||||||
assignment111[M(1)] = 1;
|
expectedValues.insert(X(1), -0.999904 * Vector1::Ones());
|
||||||
assignment111[M(2)] = 1;
|
expectedValues.insert(X(2), -0.99029 * Vector1::Ones());
|
||||||
assignment111[M(3)] = 1;
|
expectedValues.insert(X(3), -1.00971 * Vector1::Ones());
|
||||||
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl;
|
expectedValues.insert(X(4), -1.0001 * Vector1::Ones());
|
||||||
|
|
||||||
DiscreteValues assignment101;
|
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||||
assignment101[M(1)] = 1;
|
|
||||||
assignment101[M(2)] = 0;
|
|
||||||
assignment101[M(3)] = 1;
|
|
||||||
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue