diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index f72743206..d01c91840 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -64,6 +64,9 @@ namespace gtsam { */ size_t nrAssignments_; + /// Default constructor for serialization. + Leaf() {} + /// Constructor from constant Leaf(const Y& constant, size_t nrAssignments = 1) : constant_(constant), nrAssignments_(nrAssignments) {} @@ -154,6 +157,18 @@ namespace gtsam { } bool isLeaf() const override { return true; } + + private: + using Base = DecisionTree::Node; + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_NVP(constant_); + ar& BOOST_SERIALIZATION_NVP(nrAssignments_); + } }; // Leaf /****************************************************************************/ @@ -177,6 +192,9 @@ namespace gtsam { using ChoicePtr = boost::shared_ptr; public: + /// Default constructor for serialization. + Choice() {} + ~Choice() override { #ifdef DT_DEBUG_MEMORY std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() @@ -428,6 +446,19 @@ namespace gtsam { r->push_back(branch->choose(label, index)); return Unique(r); } + + private: + using Base = DecisionTree::Node; + + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); + ar& BOOST_SERIALIZATION_NVP(label_); + ar& BOOST_SERIALIZATION_NVP(branches_); + ar& BOOST_SERIALIZATION_NVP(allSame_); + } }; // Choice /****************************************************************************/ diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index 01a57637f..a8764a98f 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -19,9 +19,11 @@ #pragma once +#include #include #include +#include #include #include #include @@ -113,6 +115,12 @@ namespace gtsam { virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; virtual Ptr choose(const L& label, size_t index) const = 0; virtual bool isLeaf() const = 0; + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) {} }; /** ------------------------ Node base class --------------------------- */ @@ -364,8 +372,19 @@ namespace gtsam { compose(Iterator begin, Iterator end, const L& label) const; /// @} + + private: + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(ARCHIVE& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_NVP(root_); + } }; // DecisionTree + template + struct traits> : public Testable> {}; + /** free versions of apply */ /// Apply unary operator `op` to DecisionTree `f`. diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index 63b14b05e..02ef52ab8 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -20,12 +20,11 @@ // #define DT_DEBUG_MEMORY // #define GTSAM_DT_NO_PRUNING #define DISABLE_DOT -#include - -#include -#include - #include +#include +#include +#include +#include using namespace std; using namespace gtsam; @@ -529,6 +528,35 @@ TEST(DecisionTree, ApplyWithAssignment) { EXPECT_LONGS_EQUAL(5, count); } +/* ****************************************************************************/ +using Tree = gtsam::DecisionTree; + +BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt") +BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTree_Leaf") +BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTree_Choice") + +// Test HybridBayesNet serialization. +TEST(DecisionTree, Serialization) { + Tree tree({{"A", 2}}, std::vector{1, 2}); + + using namespace serializationTestHelpers; + + // Object roundtrip + Tree outputObj = create(); + roundtrip(tree, outputObj); + EXPECT(tree.equals(outputObj)); + + // XML roundtrip + Tree outputXml = create(); + roundtripXML(tree, outputXml); + EXPECT(tree.equals(outputXml)); + + // Binary roundtrip + Tree outputBinary = create(); + roundtripBinary(tree, outputBinary); + EXPECT(tree.equals(outputBinary)); +} + /* ************************************************************************* */ int main() { TestResult tr;