DecisionTree serialization
parent
3771d63835
commit
99a3fbac2c
|
@ -64,6 +64,9 @@ namespace gtsam {
|
|||
*/
|
||||
size_t nrAssignments_;
|
||||
|
||||
/// Default constructor for serialization.
|
||||
Leaf() {}
|
||||
|
||||
/// Constructor from constant
|
||||
Leaf(const Y& constant, size_t nrAssignments = 1)
|
||||
: constant_(constant), nrAssignments_(nrAssignments) {}
|
||||
|
@ -154,6 +157,18 @@ namespace gtsam {
|
|||
}
|
||||
|
||||
bool isLeaf() const override { return true; }
|
||||
|
||||
private:
|
||||
using Base = DecisionTree<L, Y>::Node;
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_NVP(constant_);
|
||||
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
|
||||
}
|
||||
}; // Leaf
|
||||
|
||||
/****************************************************************************/
|
||||
|
@ -177,6 +192,9 @@ namespace gtsam {
|
|||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||
|
||||
public:
|
||||
/// Default constructor for serialization.
|
||||
Choice() {}
|
||||
|
||||
~Choice() override {
|
||||
#ifdef DT_DEBUG_MEMORY
|
||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||
|
@ -428,6 +446,19 @@ namespace gtsam {
|
|||
r->push_back(branch->choose(label, index));
|
||||
return Unique(r);
|
||||
}
|
||||
|
||||
private:
|
||||
using Base = DecisionTree<L, Y>::Node;
|
||||
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
ar& BOOST_SERIALIZATION_NVP(label_);
|
||||
ar& BOOST_SERIALIZATION_NVP(branches_);
|
||||
ar& BOOST_SERIALIZATION_NVP(allSame_);
|
||||
}
|
||||
}; // Choice
|
||||
|
||||
/****************************************************************************/
|
||||
|
|
|
@ -19,9 +19,11 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/types.h>
|
||||
#include <gtsam/discrete/Assignment.h>
|
||||
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/shared_ptr.hpp>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
|
@ -113,6 +115,12 @@ namespace gtsam {
|
|||
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
||||
virtual Ptr choose(const L& label, size_t index) const = 0;
|
||||
virtual bool isLeaf() const = 0;
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
|
||||
};
|
||||
/** ------------------------ Node base class --------------------------- */
|
||||
|
||||
|
@ -364,8 +372,19 @@ namespace gtsam {
|
|||
compose(Iterator begin, Iterator end, const L& label) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||
ar& BOOST_SERIALIZATION_NVP(root_);
|
||||
}
|
||||
}; // DecisionTree
|
||||
|
||||
template <class L, class Y>
|
||||
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
|
||||
|
||||
/** free versions of apply */
|
||||
|
||||
/// Apply unary operator `op` to DecisionTree `f`.
|
||||
|
|
|
@ -20,12 +20,11 @@
|
|||
// #define DT_DEBUG_MEMORY
|
||||
// #define GTSAM_DT_NO_PRUNING
|
||||
#define DISABLE_DOT
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/Signature.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
|
@ -529,6 +528,35 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
|||
EXPECT_LONGS_EQUAL(5, count);
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
using Tree = gtsam::DecisionTree<string, int>;
|
||||
|
||||
BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
|
||||
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTree_Leaf")
|
||||
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTree_Choice")
|
||||
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(DecisionTree, Serialization) {
|
||||
Tree tree({{"A", 2}}, std::vector<int>{1, 2});
|
||||
|
||||
using namespace serializationTestHelpers;
|
||||
|
||||
// Object roundtrip
|
||||
Tree outputObj = create<Tree>();
|
||||
roundtrip<Tree>(tree, outputObj);
|
||||
EXPECT(tree.equals(outputObj));
|
||||
|
||||
// XML roundtrip
|
||||
Tree outputXml = create<Tree>();
|
||||
roundtripXML<Tree>(tree, outputXml);
|
||||
EXPECT(tree.equals(outputXml));
|
||||
|
||||
// Binary roundtrip
|
||||
Tree outputBinary = create<Tree>();
|
||||
roundtripBinary<Tree>(tree, outputBinary);
|
||||
EXPECT(tree.equals(outputBinary));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue