DecisionTree serialization
parent
3771d63835
commit
99a3fbac2c
|
|
@ -64,6 +64,9 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
size_t nrAssignments_;
|
size_t nrAssignments_;
|
||||||
|
|
||||||
|
/// Default constructor for serialization.
|
||||||
|
Leaf() {}
|
||||||
|
|
||||||
/// Constructor from constant
|
/// Constructor from constant
|
||||||
Leaf(const Y& constant, size_t nrAssignments = 1)
|
Leaf(const Y& constant, size_t nrAssignments = 1)
|
||||||
: constant_(constant), nrAssignments_(nrAssignments) {}
|
: constant_(constant), nrAssignments_(nrAssignments) {}
|
||||||
|
|
@ -154,6 +157,18 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isLeaf() const override { return true; }
|
bool isLeaf() const override { return true; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
using Base = DecisionTree<L, Y>::Node;
|
||||||
|
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(constant_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
|
||||||
|
}
|
||||||
}; // Leaf
|
}; // Leaf
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
|
@ -177,6 +192,9 @@ namespace gtsam {
|
||||||
using ChoicePtr = boost::shared_ptr<const Choice>;
|
using ChoicePtr = boost::shared_ptr<const Choice>;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
/// Default constructor for serialization.
|
||||||
|
Choice() {}
|
||||||
|
|
||||||
~Choice() override {
|
~Choice() override {
|
||||||
#ifdef DT_DEBUG_MEMORY
|
#ifdef DT_DEBUG_MEMORY
|
||||||
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
|
||||||
|
|
@ -428,6 +446,19 @@ namespace gtsam {
|
||||||
r->push_back(branch->choose(label, index));
|
r->push_back(branch->choose(label, index));
|
||||||
return Unique(r);
|
return Unique(r);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
using Base = DecisionTree<L, Y>::Node;
|
||||||
|
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(label_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(branches_);
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(allSame_);
|
||||||
|
}
|
||||||
}; // Choice
|
}; // Choice
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,11 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
#include <gtsam/base/types.h>
|
#include <gtsam/base/types.h>
|
||||||
#include <gtsam/discrete/Assignment.h>
|
#include <gtsam/discrete/Assignment.h>
|
||||||
|
|
||||||
|
#include <boost/serialization/nvp.hpp>
|
||||||
#include <boost/shared_ptr.hpp>
|
#include <boost/shared_ptr.hpp>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
@ -113,6 +115,12 @@ namespace gtsam {
|
||||||
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
|
||||||
virtual Ptr choose(const L& label, size_t index) const = 0;
|
virtual Ptr choose(const L& label, size_t index) const = 0;
|
||||||
virtual bool isLeaf() const = 0;
|
virtual bool isLeaf() const = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
|
||||||
};
|
};
|
||||||
/** ------------------------ Node base class --------------------------- */
|
/** ------------------------ Node base class --------------------------- */
|
||||||
|
|
||||||
|
|
@ -364,8 +372,19 @@ namespace gtsam {
|
||||||
compose(Iterator begin, Iterator end, const L& label) const;
|
compose(Iterator begin, Iterator end, const L& label) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_NVP(root_);
|
||||||
|
}
|
||||||
}; // DecisionTree
|
}; // DecisionTree
|
||||||
|
|
||||||
|
template <class L, class Y>
|
||||||
|
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};
|
||||||
|
|
||||||
/** free versions of apply */
|
/** free versions of apply */
|
||||||
|
|
||||||
/// Apply unary operator `op` to DecisionTree `f`.
|
/// Apply unary operator `op` to DecisionTree `f`.
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,11 @@
|
||||||
// #define DT_DEBUG_MEMORY
|
// #define DT_DEBUG_MEMORY
|
||||||
// #define GTSAM_DT_NO_PRUNING
|
// #define GTSAM_DT_NO_PRUNING
|
||||||
#define DISABLE_DOT
|
#define DISABLE_DOT
|
||||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
|
||||||
#include <gtsam/discrete/Signature.h>
|
|
||||||
|
|
||||||
#include <CppUnitLite/TestHarness.h>
|
#include <CppUnitLite/TestHarness.h>
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/serializationTestHelpers.h>
|
||||||
|
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace gtsam;
|
using namespace gtsam;
|
||||||
|
|
@ -529,6 +528,35 @@ TEST(DecisionTree, ApplyWithAssignment) {
|
||||||
EXPECT_LONGS_EQUAL(5, count);
|
EXPECT_LONGS_EQUAL(5, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
using Tree = gtsam::DecisionTree<string, int>;
|
||||||
|
|
||||||
|
BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
|
||||||
|
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTree_Leaf")
|
||||||
|
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTree_Choice")
|
||||||
|
|
||||||
|
// Test HybridBayesNet serialization.
|
||||||
|
TEST(DecisionTree, Serialization) {
|
||||||
|
Tree tree({{"A", 2}}, std::vector<int>{1, 2});
|
||||||
|
|
||||||
|
using namespace serializationTestHelpers;
|
||||||
|
|
||||||
|
// Object roundtrip
|
||||||
|
Tree outputObj = create<Tree>();
|
||||||
|
roundtrip<Tree>(tree, outputObj);
|
||||||
|
EXPECT(tree.equals(outputObj));
|
||||||
|
|
||||||
|
// XML roundtrip
|
||||||
|
Tree outputXml = create<Tree>();
|
||||||
|
roundtripXML<Tree>(tree, outputXml);
|
||||||
|
EXPECT(tree.equals(outputXml));
|
||||||
|
|
||||||
|
// Binary roundtrip
|
||||||
|
Tree outputBinary = create<Tree>();
|
||||||
|
roundtripBinary<Tree>(tree, outputBinary);
|
||||||
|
EXPECT(tree.equals(outputBinary));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue