add improved versions of push_back for HybridBayesNet
							parent
							
								
									b54ed7209e
								
							
						
					
					
						commit
						351f0bd3a5
					
				|  | @ -33,6 +33,18 @@ namespace gtsam { | |||
|  * @ingroup hybrid | ||||
|  */ | ||||
| class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | ||||
|   template <typename T> | ||||
|   struct is_shared_ptr : std::false_type {}; | ||||
|   template <typename T> | ||||
|   struct is_shared_ptr<std::shared_ptr<T>> : std::true_type {}; | ||||
| 
 | ||||
|   /// Helper templates for checking if a type is a shared pointer or not
 | ||||
|   template <typename T> | ||||
|   using IsSharedPtr = typename std::enable_if<is_shared_ptr<T>::value>::type; | ||||
|   template <typename T> | ||||
|   using IsNotSharedPtr = | ||||
|       typename std::enable_if<!is_shared_ptr<T>::value>::type; | ||||
| 
 | ||||
|  public: | ||||
|   using Base = BayesNet<HybridConditional>; | ||||
|   using This = HybridBayesNet; | ||||
|  | @ -70,20 +82,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | |||
|     factors_.push_back(conditional); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Preferred: add a conditional directly using a pointer. | ||||
|    * | ||||
|    * Examples: | ||||
|    *   hbn.emplace_back(new GaussianMixture(...))); | ||||
|    *   hbn.emplace_back(new GaussianConditional(...))); | ||||
|    *   hbn.emplace_back(new DiscreteConditional(...))); | ||||
|    */ | ||||
|   template <class Conditional> | ||||
|   void emplace_back(Conditional *conditional) { | ||||
|     factors_.push_back(std::make_shared<HybridConditional>( | ||||
|         std::shared_ptr<Conditional>(conditional))); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Add a conditional using a shared_ptr, using implicit conversion to | ||||
|    * a HybridConditional. | ||||
|  | @ -101,6 +99,54 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | |||
|         std::make_shared<HybridConditional>(std::move(conditional))); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Add a conditional to the Bayes net. | ||||
|    * Implicitly convert to a HybridConditional. | ||||
|    * | ||||
|    * E.g. | ||||
|    * hbn.push_back(std::make_shared<DiscreteConditional>(m, "1/1")); | ||||
|    * | ||||
|    * @tparam CONDITIONAL Type of conditional. This is shared_ptr version. | ||||
|    * @param conditional The conditional as a shared pointer. | ||||
|    * @return IsSharedPtr<CONDITIONAL> | ||||
|    */ | ||||
|   template <class CONDITIONAL> | ||||
|   IsSharedPtr<CONDITIONAL> push_back(const CONDITIONAL &conditional) { | ||||
|     factors_.push_back(std::make_shared<HybridConditional>(conditional)); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Add a conditional to the Bayes net. | ||||
|    * Implicitly convert to a HybridConditional. | ||||
|    * | ||||
|    * E.g. | ||||
|    * hbn.push_back(DiscreteConditional(m, "1/1")); | ||||
|    * hbn.push_back(GaussianConditional(X(0), Vector1(0.0), I_1x1)); | ||||
|    * | ||||
|    * @tparam CONDITIONAL Type of conditional. This is const ref version. | ||||
|    * @param conditional The conditional as a const reference. | ||||
|    * @return IsSharedPtr<CONDITIONAL> | ||||
|    */ | ||||
|   template <class CONDITIONAL> | ||||
|   IsNotSharedPtr<CONDITIONAL> push_back(const CONDITIONAL &conditional) { | ||||
|     auto cond_shared_ptr = std::make_shared<CONDITIONAL>(conditional); | ||||
|     push_back(cond_shared_ptr); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * Preferred: add a conditional directly using a pointer. | ||||
|    * | ||||
|    * Examples: | ||||
|    *   hbn.emplace_back(new GaussianMixture(...))); | ||||
|    *   hbn.emplace_back(new GaussianConditional(...))); | ||||
|    *   hbn.emplace_back(new DiscreteConditional(...))); | ||||
|    */ | ||||
|   template <class Conditional> | ||||
|   void emplace_back(Conditional *conditional) { | ||||
|     factors_.push_back(std::make_shared<HybridConditional>( | ||||
|         std::shared_ptr<Conditional>(conditional))); | ||||
|   } | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete | ||||
|    * value assignment. | ||||
|  |  | |||
|  | @ -221,12 +221,12 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel) { | |||
|   auto c0 = make_shared<GaussianConditional>(z, Vector1(mu0), I_1x1, model), | ||||
|        c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model); | ||||
| 
 | ||||
|   auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1}); | ||||
|   auto mixing = new DiscreteConditional(m, "0.5/0.5"); | ||||
|   GaussianMixture gm({z}, {}, {m}, {c0, c1}); | ||||
|   DiscreteConditional mixing(m, "0.5/0.5"); | ||||
| 
 | ||||
|   HybridBayesNet hbn; | ||||
|   hbn.emplace_back(gm); | ||||
|   hbn.emplace_back(mixing); | ||||
|   hbn.push_back(gm); | ||||
|   hbn.push_back(mixing); | ||||
| 
 | ||||
|   // The result should be a sigmoid.
 | ||||
|   // So should be m = 0.5 at z=3.0 - 1.0=2.0
 | ||||
|  | @ -237,7 +237,7 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel) { | |||
|   HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); | ||||
| 
 | ||||
|   HybridBayesNet expected; | ||||
|   expected.emplace_back(new DiscreteConditional(m, "0.5/0.5")); | ||||
|   expected.push_back(DiscreteConditional(m, "0.5/0.5")); | ||||
| 
 | ||||
|   EXPECT(assert_equal(expected, *bn)); | ||||
| } | ||||
|  | @ -265,12 +265,12 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel2) { | |||
|   auto c0 = make_shared<GaussianConditional>(z, Vector1(mu0), I_1x1, model0), | ||||
|        c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model1); | ||||
| 
 | ||||
|   auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1}); | ||||
|   auto mixing = new DiscreteConditional(m, "0.5/0.5"); | ||||
|   GaussianMixture gm({z}, {}, {m}, {c0, c1}); | ||||
|   DiscreteConditional mixing(m, "0.5/0.5"); | ||||
| 
 | ||||
|   HybridBayesNet hbn; | ||||
|   hbn.emplace_back(gm); | ||||
|   hbn.emplace_back(mixing); | ||||
|   hbn.push_back(gm); | ||||
|   hbn.push_back(mixing); | ||||
| 
 | ||||
|   // The result should be a sigmoid leaning towards model1
 | ||||
|   // since it has the tighter covariance.
 | ||||
|  | @ -281,8 +281,7 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel2) { | |||
|   HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); | ||||
| 
 | ||||
|   HybridBayesNet expected; | ||||
|   expected.emplace_back( | ||||
|       new DiscreteConditional(m, "0.338561851224/0.661438148776")); | ||||
|   expected.push_back(DiscreteConditional(m, "0.338561851224/0.661438148776")); | ||||
| 
 | ||||
|   EXPECT(assert_equal(expected, *bn)); | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue