Make HybridBayesNet testable and add serialization
parent
eb5092897b
commit
8692ae63ea
|
@ -18,6 +18,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/global_includes.h>
|
||||
#include <gtsam/hybrid/HybridConditional.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
|
@ -37,12 +38,31 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
using shared_ptr = boost::shared_ptr<HybridBayesNet>;
|
||||
using sharedConditional = boost::shared_ptr<ConditionalType>;
|
||||
|
||||
/// @name Standard Constructors
|
||||
/// @{
|
||||
|
||||
/** Construct empty bayes net */
|
||||
HybridBayesNet() = default;
|
||||
|
||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
||||
HybridBayesNet prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/** Check equality */
|
||||
bool equals(const This &bn, double tol = 1e-9) const {
|
||||
return Base::equals(bn, tol);
|
||||
}
|
||||
|
||||
/// print graph
|
||||
void print(
|
||||
const std::string &s = "",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override {
|
||||
Base::print(s, formatter);
|
||||
}
|
||||
|
||||
/// @}
|
||||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
/// Add HybridConditional to Bayes Net
|
||||
using Base::add;
|
||||
|
@ -71,9 +91,13 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
*/
|
||||
GaussianBayesNet choose(const DiscreteValues &assignment) const;
|
||||
|
||||
/// Solve the HybridBayesNet by back-substitution.
|
||||
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
|
||||
/// put this method there?
|
||||
/**
|
||||
* @brief Solve the HybridBayesNet by first computing the MPE of all the
|
||||
* discrete variables and then optimizing the continuous variables based on
|
||||
* the MPE assignment.
|
||||
*
|
||||
* @return HybridValues
|
||||
*/
|
||||
HybridValues optimize() const;
|
||||
|
||||
/**
|
||||
|
@ -84,6 +108,24 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @return Values
|
||||
*/
|
||||
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||
|
||||
/// Prune the Hybrid Bayes Net given the discrete decision tree.
|
||||
HybridBayesNet prune(
|
||||
const DecisionTreeFactor::shared_ptr &discreteFactor) const;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
/** Serialization function */
|
||||
friend class boost::serialization::access;
|
||||
template <class ARCHIVE>
|
||||
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||
}
|
||||
};
|
||||
|
||||
/// traits
|
||||
template <>
|
||||
struct traits<HybridBayesNet> : public Testable<HybridBayesNet> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
* @date December 2021
|
||||
*/
|
||||
|
||||
#include <gtsam/base/serializationTestHelpers.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
|
||||
|
||||
|
@ -28,6 +29,8 @@
|
|||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using namespace gtsam::serializationTestHelpers;
|
||||
|
||||
using noiseModel::Isotropic;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
|
@ -146,6 +149,18 @@ TEST(HybridBayesNet, Optimize) {
|
|||
EXPECT(assert_equal(expectedValues, delta.continuous(), 1e-5));
|
||||
}
|
||||
|
||||
/* ****************************************************************************/
|
||||
// Test HybridBayesNet serialization.
|
||||
TEST(HybridBayesNet, Serialization) {
|
||||
Switching s(4);
|
||||
Ordering ordering = s.linearizedFactorGraph.getHybridOrdering();
|
||||
HybridBayesNet hbn = *(s.linearizedFactorGraph.eliminateSequential(ordering));
|
||||
|
||||
EXPECT(equalsObj<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsXML<HybridBayesNet>(hbn));
|
||||
EXPECT(equalsBinary<HybridBayesNet>(hbn));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
|
|
Loading…
Reference in New Issue