Make HybridBayesNet testable and add serialization
parent
eb5092897b
commit
8692ae63ea
|
@ -18,6 +18,7 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
|
#include <gtsam/global_includes.h>
|
||||||
#include <gtsam/hybrid/HybridConditional.h>
|
#include <gtsam/hybrid/HybridConditional.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/BayesNet.h>
|
#include <gtsam/inference/BayesNet.h>
|
||||||
|
@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
|
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
|
||||||
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
/** Construct empty bayes net */
|
/** Construct empty bayes net */
|
||||||
HybridBayesNet() = default;
|
HybridBayesNet() = default;
|
||||||
|
|
||||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
/// @}
|
||||||
HybridBayesNet prune(
|
/// @name Testable
|
||||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
/// @{
|
||||||
|
|
||||||
|
/** Check equality */
|
||||||
|
bool equals(const This &bn, double tol = 1e-9) const {
|
||||||
|
return Base::equals(bn, tol);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// print graph
|
||||||
|
void print(
|
||||||
|
const std::string &s = "",
|
||||||
|
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
|
||||||
|
Base::print(s, formatter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// Add HybridConditional to Bayes Net
|
/// Add HybridConditional to Bayes Net
|
||||||
using Base::add;
|
using Base::add;
|
||||||
|
@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
*/
|
*/
|
||||||
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
/// Solve the HybridBayesNet by back-substitution.
|
/**
|
||||||
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
|
* @brief Solve the HybridBayesNet by first computing the MPE of all the
|
||||||
/// put this method there?
|
* discrete variables and then optimizing the continuous variables based on
|
||||||
|
* the MPE assignment.
|
||||||
|
*
|
||||||
|
* @return HybridValues
|
||||||
|
*/
|
||||||
HybridValues optimize() const;
|
HybridValues optimize() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
* @return Values
|
* @return Values
|
||||||
*/
|
*/
|
||||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||||
|
|
||||||
|
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
||||||
|
HybridBayesNet prune(
|
||||||
|
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// traits
|
||||||
|
template <>
|
||||||
|
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
* @date December 2021
|
* @date December 2021
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||||
|
|
||||||
|
@ -28,6 +29,8 @@
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
using namespace gtsam::serializationTestHelpers;
|
||||||
|
|
||||||
using noiseModel::Isotropic;
|
using noiseModel::Isotropic;
|
||||||
using symbol_shorthand::M;
|
using symbol_shorthand::M;
|
||||||
using symbol_shorthand::X;
|
using symbol_shorthand::X;
|
||||||
|
@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) {
|
||||||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test HybridBayesNet serialization.
|
||||||
|
TEST(HybridBayesNet, Serialization) {
|
||||||
|
Switching s(4);
|
||||||
|
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||||
|
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||||
|
|
||||||
|
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||||
|
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||||
|
EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
Loading…
Reference in New Issue