diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 616ea0698..e84103a50 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include #include @@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using shared_ptr = boost::shared_ptr; using sharedConditional = boost::shared_ptr; + /// @name Standard Constructors + /// @{ + /** Construct empty bayes net */ HybridBayesNet() = default; - /// Prune the Hybrid Bayes Net given the discrete decision tree. - HybridBayesNet prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const; + /// @} + /// @name Testable + /// @{ + + /** Check equality */ + bool equals(const This &bn, double tol = 1e-9) const { + return Base::equals(bn, tol); + } + + /// print graph + void print( + const std::string &s = "", + const KeyFormatter &formatter = DefaultKeyFormatter) const override { + Base::print(s, formatter); + } + + /// @} + /// @name Standard Interface + /// @{ /// Add HybridConditional to Bayes Net using Base::add; @@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ GaussianBayesNet choose(const DiscreteValues &assignment) const; - /// Solve the HybridBayesNet by back-substitution. - /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and - /// put this method there? + /** + * @brief Solve the HybridBayesNet by first computing the MPE of all the + * discrete variables and then optimizing the continuous variables based on + * the MPE assignment. + * + * @return HybridValues + */ HybridValues optimize() const; /** @@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @return Values */ VectorValues optimize(const DiscreteValues &assignment) const; + + /// Prune the Hybrid Bayes Net given the discrete decision tree. + HybridBayesNet prune( + const DecisionTreeFactor::shared_ptr &discreteFactor) const; + + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } }; +/// traits +template <> +struct traits : public Testable {}; + } // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index c7516c0f6..bf9385bc4 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -18,6 +18,7 @@ * @date December 2021 */ +#include #include #include @@ -28,6 +29,8 @@ using namespace std; using namespace gtsam; +using namespace gtsam::serializationTestHelpers; + using noiseModel::Isotropic; using symbol_shorthand::M; using symbol_shorthand::X; @@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } +/* ****************************************************************************/ +// Test HybridBayesNet serialization. +TEST(HybridBayesNet, Serialization) { + Switching s(4); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); + + EXPECT(equalsObj(hbn)); + EXPECT(equalsXML(hbn)); + EXPECT(equalsBinary(hbn)); +} + /* ************************************************************************* */ int main() { TestResult tr;