serialization tests for HybridBayesNet and HybridBayesTree

release/4.3a0
Varun Agrawal 2023-01-03 18:21:47 -05:00
parent 6fcc087030
commit 230f65dd44
4 changed files with 49 additions and 34 deletions

View File

@ -188,6 +188,19 @@ class GTSAM_EXPORT HybridConditional
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(inner_); ar& BOOST_SERIALIZATION_NVP(inner_);
// register the various casts based on the type of inner_
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
if (isDiscrete()) {
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
} else if (isContinuous()) {
boost::serialization::void_cast_register<GaussianConditional, Factor>(
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
} else {
boost::serialization::void_cast_register<GaussianMixture, Factor>(
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
}
} }
}; // HybridConditional }; // HybridConditional

View File

@ -18,7 +18,6 @@
* @date December 2021 * @date December 2021
*/ */
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h> #include <gtsam/nonlinear/NonlinearFactorGraph.h>
@ -31,7 +30,6 @@
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
using namespace gtsam::serializationTestHelpers;
using noiseModel::Isotropic; using noiseModel::Isotropic;
using symbol_shorthand::M; using symbol_shorthand::M;
@ -330,20 +328,6 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
discrete_conditional_tree->apply(checker); discrete_conditional_tree->apply(checker);
} }
/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridBayesNet, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
// TODO(Varun) Serialization of inner factor doesn't work. Requires
// serialization support for all hybrid factors.
// EXPECT(equalsObj<HybridBayesNet>(hbn));
// EXPECT(equalsXML<HybridBayesNet>(hbn));
// EXPECT(equalsBinary<HybridBayesNet>(hbn));
}
/* ****************************************************************************/ /* ****************************************************************************/
// Test HybridBayesNet sampling. // Test HybridBayesNet sampling.
TEST(HybridBayesNet, Sampling) { TEST(HybridBayesNet, Sampling) {

View File

@ -220,22 +220,6 @@ TEST(HybridBayesTree, Choose) {
EXPECT(assert_equal(expected_gbt, gbt)); EXPECT(assert_equal(expected_gbt, gbt));
} }
/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) {
Switching s(4);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
using namespace gtsam::serializationTestHelpers;
// TODO(Varun) Serialization of inner factor doesn't work. Requires
// serialization support for all hybrid factors.
// EXPECT(equalsObj<HybridBayesTree>(hbt));
// EXPECT(equalsXML<HybridBayesTree>(hbt));
// EXPECT(equalsBinary<HybridBayesTree>(hbt));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;

View File

@ -17,14 +17,19 @@
*/ */
#include <gtsam/base/serializationTestHelpers.h> #include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/GaussianMixture.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h> #include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridConditional.h> #include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridDiscreteFactor.h> #include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Symbol.h> #include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
#include "Switching.h"
// Include for test suite // Include for test suite
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
@ -36,15 +41,17 @@ using symbol_shorthand::Z;
using namespace serializationTestHelpers; using namespace serializationTestHelpers;
BOOST_CLASS_EXPORT_GUID(Factor, "gtsam_Factor");
BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor"); BOOST_CLASS_EXPORT_GUID(HybridFactor, "gtsam_HybridFactor");
BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor"); BOOST_CLASS_EXPORT_GUID(JacobianFactor, "gtsam_JacobianFactor");
BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional"); BOOST_CLASS_EXPORT_GUID(GaussianConditional, "gtsam_GaussianConditional");
BOOST_CLASS_EXPORT_GUID(DiscreteConditional, "gtsam_DiscreteConditional");
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor"); BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>; using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree"); BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_DecisionTree_Leaf"); BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_AlgebraicDecisionTree_Leaf");
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_DecisionTree_Choice"); BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_AlgebraicDecisionTree_Choice")
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor"); BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor, "gtsam_GaussianMixtureFactor");
BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors, BOOST_CLASS_EXPORT_GUID(GaussianMixtureFactor::Factors,
@ -64,6 +71,8 @@ BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice,
// Needed since GaussianConditional::FromMeanAndStddev uses it // Needed since GaussianConditional::FromMeanAndStddev uses it
BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic");
BOOST_CLASS_EXPORT_GUID(HybridBayesNet, "gtsam_HybridBayesNet");
/* ****************************************************************************/ /* ****************************************************************************/
// Test HybridGaussianFactor serialization. // Test HybridGaussianFactor serialization.
TEST(HybridSerialization, HybridGaussianFactor) { TEST(HybridSerialization, HybridGaussianFactor) {
@ -137,6 +146,31 @@ TEST(HybridSerialization, GaussianMixture) {
EXPECT(equalsBinary<GaussianMixture>(gm)); EXPECT(equalsBinary<GaussianMixture>(gm));
} }
/* ****************************************************************************/
// Test HybridBayesNet serialization.
TEST(HybridSerialization, HybridBayesNet) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
EXPECT(equalsObj<HybridBayesNet>(hbn));
EXPECT(equalsXML<HybridBayesNet>(hbn));
EXPECT(equalsBinary<HybridBayesNet>(hbn));
}
/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridSerialization, HybridBayesTree) {
Switching s(2);
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesTree hbt =
*(s.linearizedFactorGraph.eliminateMultifrontal(ordering));
EXPECT(equalsObj<HybridBayesTree>(hbt));
EXPECT(equalsXML<HybridBayesTree>(hbt));
EXPECT(equalsBinary<HybridBayesTree>(hbt));
}
/* ************************************************************************* */ /* ************************************************************************* */
int main() { int main() {
TestResult tr; TestResult tr;