serialize DiscreteConditional

release/4.3a0
Varun Agrawal 2023-01-03 18:15:59 -05:00
parent 2bb4fd6530
commit 6fcc087030
2 changed files with 29 additions and 3 deletions

View File

@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;
private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}
};
// DiscreteConditional

View File

@ -17,13 +17,14 @@
* @date Feb 14, 2011
*/
#include <boost/make_shared.hpp>
#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h>
#include <boost/make_shared.hpp>
using namespace std;
using namespace gtsam;
@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) {
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
@ -368,6 +368,23 @@ TEST(DiscreteConditional, html) {
EXPECT(actual == expected);
}
/* ************************************************************************* */
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_ADT_Leaf")
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_ADT_Choice")
// Check serialization for DiscreteConditional
TEST(DiscreteConditional, Serialization) {
using namespace serializationTestHelpers;
DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");
EXPECT(equalsObj<DiscreteConditional>(conditional));
EXPECT(equalsXML<DiscreteConditional>(conditional));
EXPECT(equalsBinary<DiscreteConditional>(conditional));
}
/* ************************************************************************* */
int main() {
TestResult tr;