DecisionTree serialization

release/4.3a0
Varun Agrawal 2023-01-03 16:17:40 -05:00
parent 3771d63835
commit 99a3fbac2c
3 changed files with 83 additions and 5 deletions

View File

@ -64,6 +64,9 @@ namespace gtsam {
*/
size_t nrAssignments_;
/// Default constructor for serialization.
Leaf() {}
/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
@ -154,6 +157,18 @@ namespace gtsam {
}
bool isLeaf() const override { return true; }
private:
using Base = DecisionTree<L, Y>::Node;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
}
}; // Leaf
/****************************************************************************/
@ -177,6 +192,9 @@ namespace gtsam {
using ChoicePtr = boost::shared_ptr<const Choice>;
public:
/// Default constructor for serialization.
Choice() {}
~Choice() override {
#ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
@ -428,6 +446,19 @@ namespace gtsam {
r->push_back(branch->choose(label, index));
return Unique(r);
}
private:
using Base = DecisionTree<L, Y>::Node;
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(label_);
ar& BOOST_SERIALIZATION_NVP(branches_);
ar& BOOST_SERIALIZATION_NVP(allSame_);
}
}; // Choice
/****************************************************************************/

View File

@ -19,9 +19,11 @@
#pragma once
#include <gtsam/base/Testable.h>
#include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h>
#include <boost/serialization/nvp.hpp>
#include <boost/shared_ptr.hpp>
#include <functional>
#include <iostream>
@ -113,6 +115,12 @@ namespace gtsam {
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
virtual Ptr choose(const L& label, size_t index) const = 0;
virtual bool isLeaf() const = 0;
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
};
/** ------------------------ Node base class --------------------------- */
@ -364,8 +372,19 @@ namespace gtsam {
compose(Iterator begin, Iterator end, const L& label) const;
/// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_NVP(root_);
}
}; // DecisionTree
template <class L, class Y>
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
/** free versions of apply */
/// Apply unary operator `op` to DecisionTree `f`.

View File

@ -20,12 +20,11 @@
// #define DT_DEBUG_MEMORY
// #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/Signature.h>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h>
using namespace std;
using namespace gtsam;
@ -529,6 +528,35 @@ TEST(DecisionTree, ApplyWithAssignment) {
EXPECT_LONGS_EQUAL(5, count);
}
/* ****************************************************************************/
using Tree = gtsam::DecisionTree<string, int>;
BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTree_Leaf")
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTree_Choice")
// Test HybridBayesNet serialization.
TEST(DecisionTree, Serialization) {
Tree tree({{"A", 2}}, std::vector<int>{1, 2});
using namespace serializationTestHelpers;
// Object roundtrip
Tree outputObj = create<Tree>();
roundtrip<Tree>(tree, outputObj);
EXPECT(tree.equals(outputObj));
// XML roundtrip
Tree outputXml = create<Tree>();
roundtripXML<Tree>(tree, outputXml);
EXPECT(tree.equals(outputXml));
// Binary roundtrip
Tree outputBinary = create<Tree>();
roundtripBinary<Tree>(tree, outputBinary);
EXPECT(tree.equals(outputBinary));
}
/* ************************************************************************* */
int main() {
TestResult tr;