use enum to categorize HybridFactor
							parent
							
								
									1c74da26f4
								
							
						
					
					
						commit
						3a7a0b84fe
					
				| 
						 | 
				
			
			@ -50,31 +50,37 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
 | 
			
		|||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
HybridFactor::HybridFactor(const KeyVector &keys)
 | 
			
		||||
    : Base(keys), isContinuous_(true), continuousKeys_(keys) {}
 | 
			
		||||
    : Base(keys),
 | 
			
		||||
      category_(HybridCategory::Continuous),
 | 
			
		||||
      continuousKeys_(keys) {}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
HybridFactor::HybridFactor(const KeyVector &continuousKeys,
 | 
			
		||||
                           const DiscreteKeys &discreteKeys)
 | 
			
		||||
    : Base(CollectKeys(continuousKeys, discreteKeys)),
 | 
			
		||||
      isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)),
 | 
			
		||||
      isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)),
 | 
			
		||||
      isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)),
 | 
			
		||||
      discreteKeys_(discreteKeys),
 | 
			
		||||
      continuousKeys_(continuousKeys) {}
 | 
			
		||||
      continuousKeys_(continuousKeys) {
 | 
			
		||||
  if ((continuousKeys.size() == 0) && (discreteKeys.size() != 0)) {
 | 
			
		||||
    category_ = HybridCategory::Discrete;
 | 
			
		||||
  } else if ((continuousKeys.size() != 0) && (discreteKeys.size() == 0)) {
 | 
			
		||||
    category_ = HybridCategory::Continuous;
 | 
			
		||||
  } else {
 | 
			
		||||
    category_ = HybridCategory::Hybrid;
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
 | 
			
		||||
    : Base(CollectKeys({}, discreteKeys)),
 | 
			
		||||
      isDiscrete_(true),
 | 
			
		||||
      category_(HybridCategory::Discrete),
 | 
			
		||||
      discreteKeys_(discreteKeys),
 | 
			
		||||
      continuousKeys_({}) {}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
 | 
			
		||||
  const This *e = dynamic_cast<const This *>(&lf);
 | 
			
		||||
  return e != nullptr && Base::equals(*e, tol) &&
 | 
			
		||||
         isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ &&
 | 
			
		||||
         isHybrid_ == e->isHybrid_ && continuousKeys_ == e->continuousKeys_ &&
 | 
			
		||||
  return e != nullptr && Base::equals(*e, tol) && category_ == e->category_ &&
 | 
			
		||||
         continuousKeys_ == e->continuousKeys_ &&
 | 
			
		||||
         discreteKeys_ == e->discreteKeys_;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -82,9 +88,18 @@ bool HybridFactor::equals(const HybridFactor &lf, double tol) const {
 | 
			
		|||
void HybridFactor::print(const std::string &s,
 | 
			
		||||
                         const KeyFormatter &formatter) const {
 | 
			
		||||
  std::cout << (s.empty() ? "" : s + "\n");
 | 
			
		||||
  if (isContinuous_) std::cout << "Continuous ";
 | 
			
		||||
  if (isDiscrete_) std::cout << "Discrete ";
 | 
			
		||||
  if (isHybrid_) std::cout << "Hybrid ";
 | 
			
		||||
  switch (category_) {
 | 
			
		||||
    case HybridCategory::Continuous:
 | 
			
		||||
      std::cout << "Continuous ";
 | 
			
		||||
      break;
 | 
			
		||||
    case HybridCategory::Discrete:
 | 
			
		||||
      std::cout << "Discrete ";
 | 
			
		||||
      break;
 | 
			
		||||
    case HybridCategory::Hybrid:
 | 
			
		||||
      std::cout << "Hybrid ";
 | 
			
		||||
      break;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  std::cout << "[";
 | 
			
		||||
  for (size_t c = 0; c < continuousKeys_.size(); c++) {
 | 
			
		||||
    std::cout << formatter(continuousKeys_.at(c));
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -41,6 +41,9 @@ KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
 | 
			
		|||
DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
 | 
			
		||||
                                 const DiscreteKeys &key2);
 | 
			
		||||
 | 
			
		||||
/// Enum to help with categorizing hybrid factors.
 | 
			
		||||
enum class HybridCategory { Discrete, Continuous, Hybrid };
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * Base class for *truly* hybrid probabilistic factors
 | 
			
		||||
 *
 | 
			
		||||
| 
						 | 
				
			
			@ -53,9 +56,8 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
 | 
			
		|||
 */
 | 
			
		||||
class GTSAM_EXPORT HybridFactor : public Factor {
 | 
			
		||||
 private:
 | 
			
		||||
  bool isDiscrete_ = false;
 | 
			
		||||
  bool isContinuous_ = false;
 | 
			
		||||
  bool isHybrid_ = false;
 | 
			
		||||
  /// Record what category of HybridFactor this is.
 | 
			
		||||
  HybridCategory category_;
 | 
			
		||||
 | 
			
		||||
 protected:
 | 
			
		||||
  // Set of DiscreteKeys for this factor.
 | 
			
		||||
| 
						 | 
				
			
			@ -116,13 +118,13 @@ class GTSAM_EXPORT HybridFactor : public Factor {
 | 
			
		|||
  /// @{
 | 
			
		||||
 | 
			
		||||
  /// True if this is a factor of discrete variables only.
 | 
			
		||||
  bool isDiscrete() const { return isDiscrete_; }
 | 
			
		||||
  bool isDiscrete() const { return category_ == HybridCategory::Discrete; }
 | 
			
		||||
 | 
			
		||||
  /// True if this is a factor of continuous variables only.
 | 
			
		||||
  bool isContinuous() const { return isContinuous_; }
 | 
			
		||||
  bool isContinuous() const { return category_ == HybridCategory::Continuous; }
 | 
			
		||||
 | 
			
		||||
  /// True is this is a Discrete-Continuous factor.
 | 
			
		||||
  bool isHybrid() const { return isHybrid_; }
 | 
			
		||||
  bool isHybrid() const { return category_ == HybridCategory::Hybrid; }
 | 
			
		||||
 | 
			
		||||
  /// Return the number of continuous variables in this factor.
 | 
			
		||||
  size_t nrContinuous() const { return continuousKeys_.size(); }
 | 
			
		||||
| 
						 | 
				
			
			@ -142,9 +144,7 @@ class GTSAM_EXPORT HybridFactor : public Factor {
 | 
			
		|||
  template <class ARCHIVE>
 | 
			
		||||
  void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
 | 
			
		||||
    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
 | 
			
		||||
    ar &BOOST_SERIALIZATION_NVP(isDiscrete_);
 | 
			
		||||
    ar &BOOST_SERIALIZATION_NVP(isContinuous_);
 | 
			
		||||
    ar &BOOST_SERIALIZATION_NVP(isHybrid_);
 | 
			
		||||
    ar &BOOST_SERIALIZATION_NVP(category_);
 | 
			
		||||
    ar &BOOST_SERIALIZATION_NVP(discreteKeys_);
 | 
			
		||||
    ar &BOOST_SERIALIZATION_NVP(continuousKeys_);
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -387,11 +387,13 @@ TEST(HybridBayesNet, Sampling) {
 | 
			
		|||
      std::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
 | 
			
		||||
  auto one_motion =
 | 
			
		||||
      std::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
 | 
			
		||||
  std::vector<NonlinearFactorValuePair> factors = {{zero_motion, 0.0},
 | 
			
		||||
                                                   {one_motion, 0.0}};
 | 
			
		||||
 | 
			
		||||
  DiscreteKeys discreteKeys{DiscreteKey(M(0), 2)};
 | 
			
		||||
  HybridNonlinearFactor::Factors factors(
 | 
			
		||||
      discreteKeys, {{zero_motion, 0.0}, {one_motion, 0.0}});
 | 
			
		||||
  nfg.emplace_shared<PriorFactor<double>>(X(0), 0.0, noise_model);
 | 
			
		||||
  nfg.emplace_shared<HybridNonlinearFactor>(
 | 
			
		||||
      KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors);
 | 
			
		||||
  nfg.emplace_shared<HybridNonlinearFactor>(KeyVector{X(0), X(1)}, discreteKeys,
 | 
			
		||||
                                            factors);
 | 
			
		||||
 | 
			
		||||
  DiscreteKey mode(M(0), 2);
 | 
			
		||||
  nfg.emplace_shared<DiscreteDistribution>(mode, "1/1");
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue