diff --git a/gtsam/discrete/DiscreteKey.cpp b/gtsam/discrete/DiscreteKey.cpp index 121d61103..06ed2ca3b 100644 --- a/gtsam/discrete/DiscreteKey.cpp +++ b/gtsam/discrete/DiscreteKey.cpp @@ -48,4 +48,25 @@ namespace gtsam { return keys & key2; } + void DiscreteKeys::print(const std::string& s, + const KeyFormatter& keyFormatter) const { + for (auto&& dkey : *this) { + std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second + << std::endl; + } + } + + bool DiscreteKeys::equals(const DiscreteKeys& other, double tol) const { + if (this->size() != other.size()) { + return false; + } + + for (size_t i = 0; i < this->size(); i++) { + if (this->at(i).first != other.at(i).first || + this->at(i).second != other.at(i).second) { + return false; + } + } + return true; + } } diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h index 297d5570d..8e0802d83 100644 --- a/gtsam/discrete/DiscreteKey.h +++ b/gtsam/discrete/DiscreteKey.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -72,15 +73,27 @@ namespace gtsam { /// Print the keys and cardinalities. void print(const std::string& s = "", - const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { - for (auto&& dkey : *this) { - std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second - << std::endl; - } + const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; + + /// Check equality to another DiscreteKeys object. + bool equals(const DiscreteKeys& other, double tol = 0) const; + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& boost::serialization::make_nvp( + "DiscreteKeys", + boost::serialization::base_object>(*this)); } }; // DiscreteKeys /// Create a list from two keys GTSAM_EXPORT DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); -} + + // traits + template <> + struct traits : public Testable {}; + + } // namespace gtsam diff --git a/gtsam/discrete/tests/testDiscreteFactor.cpp b/gtsam/discrete/tests/testDiscreteFactor.cpp index 8681cf7eb..db0491c9d 100644 --- a/gtsam/discrete/tests/testDiscreteFactor.cpp +++ b/gtsam/discrete/tests/testDiscreteFactor.cpp @@ -16,14 +16,29 @@ * @author Duy-Nguyen Ta */ -#include -#include #include +#include +#include +#include + #include using namespace boost::assign; using namespace std; using namespace gtsam; +using namespace gtsam::serializationTestHelpers; + +/* ************************************************************************* */ +TEST(DisreteKeys, Serialization) { + DiscreteKeys keys; + keys& DiscreteKey(0, 2); + keys& DiscreteKey(1, 3); + keys& DiscreteKey(2, 4); + + EXPECT(equalsObj(keys)); + EXPECT(equalsXML(keys)); + EXPECT(equalsBinary(keys)); +} /* ************************************************************************* */ int main() { @@ -31,4 +46,3 @@ int main() { return TestRegistry::runAllTests(tr); } /* ************************************************************************* */ - diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 616ea0698..e84103a50 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include #include @@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { using shared_ptr = boost::shared_ptr; using sharedConditional = boost::shared_ptr; + /// @name Standard Constructors + /// @{ + /** Construct empty bayes net */ HybridBayesNet() = default; - /// Prune the Hybrid Bayes Net given the discrete decision tree. - HybridBayesNet prune( - const DecisionTreeFactor::shared_ptr &discreteFactor) const; + /// @} + /// @name Testable + /// @{ + + /** Check equality */ + bool equals(const This &bn, double tol = 1e-9) const { + return Base::equals(bn, tol); + } + + /// print graph + void print( + const std::string &s = "", + const KeyFormatter &formatter = DefaultKeyFormatter) const override { + Base::print(s, formatter); + } + + /// @} + /// @name Standard Interface + /// @{ /// Add HybridConditional to Bayes Net using Base::add; @@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ GaussianBayesNet choose(const DiscreteValues &assignment) const; - /// Solve the HybridBayesNet by back-substitution. - /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and - /// put this method there? + /** + * @brief Solve the HybridBayesNet by first computing the MPE of all the + * discrete variables and then optimizing the continuous variables based on + * the MPE assignment. + * + * @return HybridValues + */ HybridValues optimize() const; /** @@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { * @return Values */ VectorValues optimize(const DiscreteValues &assignment) const; + + /// Prune the Hybrid Bayes Net given the discrete decision tree. + HybridBayesNet prune( + const DecisionTreeFactor::shared_ptr &discreteFactor) const; + + /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } }; +/// traits +template <> +struct traits : public Testable {}; + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h index 361fbe86f..3fa344d4d 100644 --- a/gtsam/hybrid/HybridBayesTree.h +++ b/gtsam/hybrid/HybridBayesTree.h @@ -89,8 +89,20 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree { VectorValues optimize(const DiscreteValues& assignment) const; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + } }; +/// traits +template <> +struct traits : public Testable {}; + /** * @brief Class for Hybrid Bayes tree orphan subtrees. * diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 96ea6d969..b43bb9945 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -178,6 +178,15 @@ class GTSAM_EXPORT HybridConditional /// Get the type-erased pointer to the inner type boost::shared_ptr inner() { return inner_; } + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + } + }; // HybridConditional // traits diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 13dc2e6e6..b3cdc231b 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -47,6 +47,7 @@ class GTSAM_EXPORT HybridFactor : public Factor { bool isContinuous_ = false; bool isHybrid_ = false; + // TODO(Varun) remove size_t nrContinuous_ = 0; protected: @@ -129,6 +130,19 @@ class GTSAM_EXPORT HybridFactor : public Factor { const KeyVector &continuousKeys() const { return continuousKeys_; } /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE &ar, const unsigned int /*version*/) { + ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar &BOOST_SERIALIZATION_NVP(isDiscrete_); + ar &BOOST_SERIALIZATION_NVP(isContinuous_); + ar &BOOST_SERIALIZATION_NVP(isHybrid_); + ar &BOOST_SERIALIZATION_NVP(discreteKeys_); + ar &BOOST_SERIALIZATION_NVP(continuousKeys_); + } }; // HybridFactor diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index c7516c0f6..bf9385bc4 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -18,6 +18,7 @@ * @date December 2021 */ +#include #include #include @@ -28,6 +29,8 @@ using namespace std; using namespace gtsam; +using namespace gtsam::serializationTestHelpers; + using noiseModel::Isotropic; using symbol_shorthand::M; using symbol_shorthand::X; @@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) { EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5)); } +/* ****************************************************************************/ +// Test HybridBayesNet serialization. +TEST(HybridBayesNet, Serialization) { + Switching s(4); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering)); + + EXPECT(equalsObj(hbn)); + EXPECT(equalsXML(hbn)); + EXPECT(equalsBinary(hbn)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index d457e6b74..0908b8cb5 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -16,6 +16,7 @@ * @date August 2022 */ +#include #include #include #include @@ -143,6 +144,20 @@ TEST(HybridBayesTree, Optimize) { EXPECT(assert_equal(expectedValues, delta.continuous())); } +/* ****************************************************************************/ +// Test HybridBayesTree serialization. +TEST(HybridBayesTree, Serialization) { + Switching s(4); + Ordering ordering = s.linearizedFactorGraph.getHybridOrdering(); + HybridBayesTree hbt = + *(s.linearizedFactorGraph.eliminateMultifrontal(ordering)); + + using namespace gtsam::serializationTestHelpers; + EXPECT(equalsObj(hbt)); + EXPECT(equalsXML(hbt)); + EXPECT(equalsBinary(hbt)); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/linear/tests/testSerializationLinear.cpp b/gtsam/linear/tests/testSerializationLinear.cpp index 881b2830e..ee21de364 100644 --- a/gtsam/linear/tests/testSerializationLinear.cpp +++ b/gtsam/linear/tests/testSerializationLinear.cpp @@ -198,6 +198,33 @@ TEST (Serialization, gaussian_factor_graph) { EXPECT(equalsBinary(graph)); } +/* ****************************************************************************/ +TEST(Serialization, gaussian_bayes_net) { + // Create an arbitrary Bayes Net + GaussianBayesNet gbn; + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 0, Vector2(1.0, 2.0), (Matrix2() << 3.0, 4.0, 0.0, 6.0).finished(), 3, + (Matrix2() << 7.0, 8.0, 9.0, 10.0).finished(), 4, + (Matrix2() << 11.0, 12.0, 13.0, 14.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 1, Vector2(15.0, 16.0), (Matrix2() << 17.0, 18.0, 0.0, 20.0).finished(), + 2, (Matrix2() << 21.0, 22.0, 23.0, 24.0).finished(), 4, + (Matrix2() << 25.0, 26.0, 27.0, 28.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 2, Vector2(29.0, 30.0), (Matrix2() << 31.0, 32.0, 0.0, 34.0).finished(), + 3, (Matrix2() << 35.0, 36.0, 37.0, 38.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 3, Vector2(39.0, 40.0), (Matrix2() << 41.0, 42.0, 0.0, 44.0).finished(), + 4, (Matrix2() << 45.0, 46.0, 47.0, 48.0).finished())); + gbn += GaussianConditional::shared_ptr(new GaussianConditional( + 4, Vector2(49.0, 50.0), (Matrix2() << 51.0, 52.0, 0.0, 54.0).finished())); + + std::string serialized = serialize(gbn); + GaussianBayesNet actual; + deserialize(serialized, actual); + EXPECT(assert_equal(gbn, actual)); +} + /* ************************************************************************* */ TEST (Serialization, gaussian_bayes_tree) { const Key x1=1, x2=2, x3=3, x4=4;