fix equality of HybridDiscreteFactor and HybridGaussianFactor
parent
2653c2f8fb
commit
0ab15cc456
|
|
@ -40,8 +40,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
|
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
|
||||||
const This *e = dynamic_cast<const This *>(&lf);
|
const This *e = dynamic_cast<const This *>(&lf);
|
||||||
// TODO(Varun) How to compare inner_ when they are abstract types?
|
if (e == nullptr) return false;
|
||||||
return e != nullptr && Base::equals(*e, tol);
|
if (!Base::equals(*e, tol)) return false;
|
||||||
|
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||||
|
: !(e->inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
/// Default constructor - for serialization.
|
||||||
|
HybridDiscreteFactor() = default;
|
||||||
|
|
||||||
// Implicit conversion from a shared ptr of DF
|
// Implicit conversion from a shared ptr of DF
|
||||||
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
|
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);
|
||||||
|
|
||||||
|
|
@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
|
||||||
/// Return the error of the underlying Discrete Factor.
|
/// Return the error of the underlying Discrete Factor.
|
||||||
double error(const HybridValues &values) const override;
|
double error(const HybridValues &values) const override;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(inner_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
|
|
@ -44,15 +44,21 @@ HybridGaussianFactor::HybridGaussianFactor(HessianFactor &&hf)
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
|
bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const {
|
||||||
const This *e = dynamic_cast<const This *>(&other);
|
const This *e = dynamic_cast<const This *>(&other);
|
||||||
// TODO(Varun) How to compare inner_ when they are abstract types?
|
if (e == nullptr) return false;
|
||||||
return e != nullptr && Base::equals(*e, tol);
|
if (!Base::equals(*e, tol)) return false;
|
||||||
|
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
|
||||||
|
: !(e->inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridGaussianFactor::print(const std::string &s,
|
void HybridGaussianFactor::print(const std::string &s,
|
||||||
const KeyFormatter &formatter) const {
|
const KeyFormatter &formatter) const {
|
||||||
HybridFactor::print(s, formatter);
|
HybridFactor::print(s, formatter);
|
||||||
|
if (inner_) {
|
||||||
inner_->print("\n", formatter);
|
inner_->print("\n", formatter);
|
||||||
|
} else {
|
||||||
|
std::cout << "\nGaussian: nullptr" << std::endl;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,10 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
using This = HybridGaussianFactor;
|
using This = HybridGaussianFactor;
|
||||||
using shared_ptr = boost::shared_ptr<This>;
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
|
||||||
|
/// @name Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Default constructor - for serialization.
|
||||||
HybridGaussianFactor() = default;
|
HybridGaussianFactor() = default;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -79,7 +83,7 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
*/
|
*/
|
||||||
explicit HybridGaussianFactor(HessianFactor &&hf);
|
explicit HybridGaussianFactor(HessianFactor &&hf);
|
||||||
|
|
||||||
public:
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
|
@ -101,6 +105,15 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
||||||
/// Return the error of the underlying Gaussian factor.
|
/// Return the error of the underlying Gaussian factor.
|
||||||
double error(const HybridValues &values) const override;
|
double error(const HybridValues &values) const override;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class ARCHIVE>
|
||||||
|
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
|
||||||
|
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
||||||
|
ar &BOOST_SERIALIZATION_NVP(inner_);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue