fix equality of HybridDiscreteFactor and HybridGaussianFactor

release/4.3a0
Varun Agrawal 2023-01-03 16:48:33 -05:00
parent 2653c2f8fb
commit 0ab15cc456
4 changed files with 39 additions and 6 deletions

View File

@ -40,8 +40,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
/* ************************************************************************ */ /* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf); const This *e = dynamic_cast<const This *>(&lf);
// TODO(Varun) How to compare inner_ when they are abstract types? if (e == nullptr) return false;
return e != nullptr && Base::equals(*e, tol); if (!Base::equals(*e, tol)) return false;
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
/// @name Constructors /// @name Constructors
/// @{ /// @{
/// Default constructor - for serialization.
HybridDiscreteFactor() = default;
// Implicit conversion from a shared ptr of DF // Implicit conversion from a shared ptr of DF
HybridDiscreteFactor(DiscreteFactor::shared_ptr other); HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
/// Return the error of the underlying Discrete Factor. /// Return the error of the underlying Discrete Factor.
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(inner_);
}
}; };
// traits // traits

View File

@ -44,15 +44,21 @@ HybridGaussianFactor::HybridGaussianFactor(HessianFactor &&hf)
/* ************************************************************************* */ /* ************************************************************************* */
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const { bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other); const This *e = dynamic_cast<const This *>(&other);
// TODO(Varun) How to compare inner_ when they are abstract types? if (e == nullptr) return false;
return e != nullptr && Base::equals(*e, tol); if (!Base::equals(*e, tol)) return false;
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
} }
/* ************************************************************************* */ /* ************************************************************************* */
void HybridGaussianFactor::print(const std::string &s, void HybridGaussianFactor::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
HybridFactor::print(s, formatter); HybridFactor::print(s, formatter);
if (inner_) {
inner_->print("\n", formatter); inner_->print("\n", formatter);
} else {
std::cout << "\nGaussian: nullptr" << std::endl;
}
}; };
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -43,6 +43,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
using This = HybridGaussianFactor; using This = HybridGaussianFactor;
using shared_ptr = boost::shared_ptr<This>; using shared_ptr = boost::shared_ptr<This>;
/// @name Constructors
/// @{
/// Default constructor - for serialization.
HybridGaussianFactor() = default; HybridGaussianFactor() = default;
/** /**
@ -79,7 +83,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/ */
explicit HybridGaussianFactor(HessianFactor &&hf); explicit HybridGaussianFactor(HessianFactor &&hf);
public: /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -101,6 +105,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// Return the error of the underlying Gaussian factor. /// Return the error of the underlying Gaussian factor.
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
/// @} /// @}
private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(inner_);
}
}; };
// traits // traits