add improved versions of push_back for HybridBayesNet
							parent
							
								
									b54ed7209e
								
							
						
					
					
						commit
						351f0bd3a5
					
				|  | @ -33,6 +33,18 @@ namespace gtsam { | ||||||
|  * @ingroup hybrid |  * @ingroup hybrid | ||||||
|  */ |  */ | ||||||
| class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | ||||||
|  |   template <typename T> | ||||||
|  |   struct is_shared_ptr : std::false_type {}; | ||||||
|  |   template <typename T> | ||||||
|  |   struct is_shared_ptr<std::shared_ptr<T>> : std::true_type {}; | ||||||
|  | 
 | ||||||
|  |   /// Helper templates for checking if a type is a shared pointer or not
 | ||||||
|  |   template <typename T> | ||||||
|  |   using IsSharedPtr = typename std::enable_if<is_shared_ptr<T>::value>::type; | ||||||
|  |   template <typename T> | ||||||
|  |   using IsNotSharedPtr = | ||||||
|  |       typename std::enable_if<!is_shared_ptr<T>::value>::type; | ||||||
|  | 
 | ||||||
|  public: |  public: | ||||||
|   using Base = BayesNet<HybridConditional>; |   using Base = BayesNet<HybridConditional>; | ||||||
|   using This = HybridBayesNet; |   using This = HybridBayesNet; | ||||||
|  | @ -70,20 +82,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | ||||||
|     factors_.push_back(conditional); |     factors_.push_back(conditional); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /**
 |  | ||||||
|    * Preferred: add a conditional directly using a pointer. |  | ||||||
|    * |  | ||||||
|    * Examples: |  | ||||||
|    *   hbn.emplace_back(new GaussianMixture(...))); |  | ||||||
|    *   hbn.emplace_back(new GaussianConditional(...))); |  | ||||||
|    *   hbn.emplace_back(new DiscreteConditional(...))); |  | ||||||
|    */ |  | ||||||
|   template <class Conditional> |  | ||||||
|   void emplace_back(Conditional *conditional) { |  | ||||||
|     factors_.push_back(std::make_shared<HybridConditional>( |  | ||||||
|         std::shared_ptr<Conditional>(conditional))); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /**
 |   /**
 | ||||||
|    * Add a conditional using a shared_ptr, using implicit conversion to |    * Add a conditional using a shared_ptr, using implicit conversion to | ||||||
|    * a HybridConditional. |    * a HybridConditional. | ||||||
|  | @ -101,6 +99,54 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | ||||||
|         std::make_shared<HybridConditional>(std::move(conditional))); |         std::make_shared<HybridConditional>(std::move(conditional))); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Add a conditional to the Bayes net. | ||||||
|  |    * Implicitly convert to a HybridConditional. | ||||||
|  |    * | ||||||
|  |    * E.g. | ||||||
|  |    * hbn.push_back(std::make_shared<DiscreteConditional>(m, "1/1")); | ||||||
|  |    * | ||||||
|  |    * @tparam CONDITIONAL Type of conditional. This is shared_ptr version. | ||||||
|  |    * @param conditional The conditional as a shared pointer. | ||||||
|  |    * @return IsSharedPtr<CONDITIONAL> | ||||||
|  |    */ | ||||||
|  |   template <class CONDITIONAL> | ||||||
|  |   IsSharedPtr<CONDITIONAL> push_back(const CONDITIONAL &conditional) { | ||||||
|  |     factors_.push_back(std::make_shared<HybridConditional>(conditional)); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Add a conditional to the Bayes net. | ||||||
|  |    * Implicitly convert to a HybridConditional. | ||||||
|  |    * | ||||||
|  |    * E.g. | ||||||
|  |    * hbn.push_back(DiscreteConditional(m, "1/1")); | ||||||
|  |    * hbn.push_back(GaussianConditional(X(0), Vector1(0.0), I_1x1)); | ||||||
|  |    * | ||||||
|  |    * @tparam CONDITIONAL Type of conditional. This is const ref version. | ||||||
|  |    * @param conditional The conditional as a const reference. | ||||||
|  |    * @return IsSharedPtr<CONDITIONAL> | ||||||
|  |    */ | ||||||
|  |   template <class CONDITIONAL> | ||||||
|  |   IsNotSharedPtr<CONDITIONAL> push_back(const CONDITIONAL &conditional) { | ||||||
|  |     auto cond_shared_ptr = std::make_shared<CONDITIONAL>(conditional); | ||||||
|  |     push_back(cond_shared_ptr); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * Preferred: add a conditional directly using a pointer. | ||||||
|  |    * | ||||||
|  |    * Examples: | ||||||
|  |    *   hbn.emplace_back(new GaussianMixture(...))); | ||||||
|  |    *   hbn.emplace_back(new GaussianConditional(...))); | ||||||
|  |    *   hbn.emplace_back(new DiscreteConditional(...))); | ||||||
|  |    */ | ||||||
|  |   template <class Conditional> | ||||||
|  |   void emplace_back(Conditional *conditional) { | ||||||
|  |     factors_.push_back(std::make_shared<HybridConditional>( | ||||||
|  |         std::shared_ptr<Conditional>(conditional))); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete |    * @brief Get the Gaussian Bayes Net which corresponds to a specific discrete | ||||||
|    * value assignment. |    * value assignment. | ||||||
|  |  | ||||||
|  | @ -221,12 +221,12 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel) { | ||||||
|   auto c0 = make_shared<GaussianConditional>(z, Vector1(mu0), I_1x1, model), |   auto c0 = make_shared<GaussianConditional>(z, Vector1(mu0), I_1x1, model), | ||||||
|        c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model); |        c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model); | ||||||
| 
 | 
 | ||||||
|   auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1}); |   GaussianMixture gm({z}, {}, {m}, {c0, c1}); | ||||||
|   auto mixing = new DiscreteConditional(m, "0.5/0.5"); |   DiscreteConditional mixing(m, "0.5/0.5"); | ||||||
| 
 | 
 | ||||||
|   HybridBayesNet hbn; |   HybridBayesNet hbn; | ||||||
|   hbn.emplace_back(gm); |   hbn.push_back(gm); | ||||||
|   hbn.emplace_back(mixing); |   hbn.push_back(mixing); | ||||||
| 
 | 
 | ||||||
|   // The result should be a sigmoid.
 |   // The result should be a sigmoid.
 | ||||||
|   // So should be m = 0.5 at z=3.0 - 1.0=2.0
 |   // So should be m = 0.5 at z=3.0 - 1.0=2.0
 | ||||||
|  | @ -237,7 +237,7 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel) { | ||||||
|   HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); |   HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); | ||||||
| 
 | 
 | ||||||
|   HybridBayesNet expected; |   HybridBayesNet expected; | ||||||
|   expected.emplace_back(new DiscreteConditional(m, "0.5/0.5")); |   expected.push_back(DiscreteConditional(m, "0.5/0.5")); | ||||||
| 
 | 
 | ||||||
|   EXPECT(assert_equal(expected, *bn)); |   EXPECT(assert_equal(expected, *bn)); | ||||||
| } | } | ||||||
|  | @ -265,12 +265,12 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel2) { | ||||||
|   auto c0 = make_shared<GaussianConditional>(z, Vector1(mu0), I_1x1, model0), |   auto c0 = make_shared<GaussianConditional>(z, Vector1(mu0), I_1x1, model0), | ||||||
|        c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model1); |        c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model1); | ||||||
| 
 | 
 | ||||||
|   auto gm = new GaussianMixture({z}, {}, {m}, {c0, c1}); |   GaussianMixture gm({z}, {}, {m}, {c0, c1}); | ||||||
|   auto mixing = new DiscreteConditional(m, "0.5/0.5"); |   DiscreteConditional mixing(m, "0.5/0.5"); | ||||||
| 
 | 
 | ||||||
|   HybridBayesNet hbn; |   HybridBayesNet hbn; | ||||||
|   hbn.emplace_back(gm); |   hbn.push_back(gm); | ||||||
|   hbn.emplace_back(mixing); |   hbn.push_back(mixing); | ||||||
| 
 | 
 | ||||||
|   // The result should be a sigmoid leaning towards model1
 |   // The result should be a sigmoid leaning towards model1
 | ||||||
|   // since it has the tighter covariance.
 |   // since it has the tighter covariance.
 | ||||||
|  | @ -281,8 +281,7 @@ TEST(GaussianMixtureFactor, GaussianMixtureModel2) { | ||||||
|   HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); |   HybridBayesNet::shared_ptr bn = gfg.eliminateSequential(); | ||||||
| 
 | 
 | ||||||
|   HybridBayesNet expected; |   HybridBayesNet expected; | ||||||
|   expected.emplace_back( |   expected.push_back(DiscreteConditional(m, "0.338561851224/0.661438148776")); | ||||||
|       new DiscreteConditional(m, "0.338561851224/0.661438148776")); |  | ||||||
| 
 | 
 | ||||||
|   EXPECT(assert_equal(expected, *bn)); |   EXPECT(assert_equal(expected, *bn)); | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue