HybridBayesNet::optimize
							parent
							
								
									a6101b2d8f
								
							
						
					
					
						commit
						f0df82ac04
					
				|  | @ -16,8 +16,8 @@ | ||||||
|  */ |  */ | ||||||
| 
 | 
 | ||||||
| #include <gtsam/hybrid/HybridBayesNet.h> | #include <gtsam/hybrid/HybridBayesNet.h> | ||||||
| #include <gtsam/hybrid/HybridValues.h> |  | ||||||
| #include <gtsam/hybrid/HybridLookupDAG.h> | #include <gtsam/hybrid/HybridLookupDAG.h> | ||||||
|  | #include <gtsam/hybrid/HybridValues.h> | ||||||
| 
 | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
|  | @ -112,13 +112,12 @@ HybridBayesNet HybridBayesNet::prune( | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const { | GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const { | ||||||
|   return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner()); |   return factors_.at(i)->asMixture(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { | DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { | ||||||
|   return boost::dynamic_pointer_cast<DiscreteConditional>( |   return factors_.at(i)->asDiscreteConditional(); | ||||||
|       factors_.at(i)->inner()); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
|  | @ -138,4 +137,10 @@ HybridValues HybridBayesNet::optimize() const { | ||||||
|   return dag.argmax(); |   return dag.argmax(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /* *******************************************************************************/ | ||||||
|  | VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { | ||||||
|  |   GaussianBayesNet gbn = this->choose(assignment); | ||||||
|  |   return gbn.optimize(); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| }  // namespace gtsam
 | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -72,6 +72,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | ||||||
|   /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
 |   /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
 | ||||||
|   /// put this method there?
 |   /// put this method there?
 | ||||||
|   HybridValues optimize() const; |   HybridValues optimize() const; | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Given the discrete assignment, return the optimized estimate for the | ||||||
|  |    * selected Gaussian BayesNet. | ||||||
|  |    * | ||||||
|  |    * @param assignment An assignment of discrete values. | ||||||
|  |    * @return Values | ||||||
|  |    */ | ||||||
|  |   VectorValues optimize(const DiscreteValues &assignment) const; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| }  // namespace gtsam
 | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -69,7 +69,7 @@ class GTSAM_EXPORT HybridConditional | ||||||
|       BaseConditional;  ///< Typedef to our conditional base class
 |       BaseConditional;  ///< Typedef to our conditional base class
 | ||||||
| 
 | 
 | ||||||
|  protected: |  protected: | ||||||
|   // Type-erased pointer to the inner type
 |   /// Type-erased pointer to the inner type
 | ||||||
|   boost::shared_ptr<Factor> inner_; |   boost::shared_ptr<Factor> inner_; | ||||||
| 
 | 
 | ||||||
|  public: |  public: | ||||||
|  | @ -127,8 +127,7 @@ class GTSAM_EXPORT HybridConditional | ||||||
|    * @param gaussianMixture Gaussian Mixture Conditional used to create the |    * @param gaussianMixture Gaussian Mixture Conditional used to create the | ||||||
|    * HybridConditional. |    * HybridConditional. | ||||||
|    */ |    */ | ||||||
|   HybridConditional( |   HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture); | ||||||
|       boost::shared_ptr<GaussianMixture> gaussianMixture); |  | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * @brief Return HybridConditional as a GaussianMixture |    * @brief Return HybridConditional as a GaussianMixture | ||||||
|  | @ -168,10 +167,10 @@ class GTSAM_EXPORT HybridConditional | ||||||
|   /// Get the type-erased pointer to the inner type
 |   /// Get the type-erased pointer to the inner type
 | ||||||
|   boost::shared_ptr<Factor> inner() { return inner_; } |   boost::shared_ptr<Factor> inner() { return inner_; } | ||||||
| 
 | 
 | ||||||
| };  // DiscreteConditional
 | };  // HybridConditional
 | ||||||
| 
 | 
 | ||||||
| // traits
 | // traits
 | ||||||
| template <> | template <> | ||||||
| struct traits<HybridConditional> : public Testable<DiscreteConditional> {}; | struct traits<HybridConditional> : public Testable<HybridConditional> {}; | ||||||
| 
 | 
 | ||||||
| }  // namespace gtsam
 | }  // namespace gtsam
 | ||||||
|  |  | ||||||
|  | @ -85,6 +85,40 @@ TEST(HybridBayesNet, Choose) { | ||||||
|                       *gbn.at(3))); |                       *gbn.at(3))); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /* ****************************************************************************/ | ||||||
|  | // Test bayes net optimize
 | ||||||
|  | TEST(HybridBayesNet, Optimize) { | ||||||
|  |   Switching s(4); | ||||||
|  | 
 | ||||||
|  |   Ordering ordering; | ||||||
|  |   for (auto&& kvp : s.linearizationPoint) { | ||||||
|  |     ordering += kvp.key; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   HybridBayesNet::shared_ptr hybridBayesNet; | ||||||
|  |   HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; | ||||||
|  |   std::tie(hybridBayesNet, remainingFactorGraph) = | ||||||
|  |       s.linearizedFactorGraph.eliminatePartialSequential(ordering); | ||||||
|  | 
 | ||||||
|  |   DiscreteValues assignment; | ||||||
|  |   assignment[M(1)] = 1; | ||||||
|  |   assignment[M(2)] = 1; | ||||||
|  |   assignment[M(3)] = 1; | ||||||
|  | 
 | ||||||
|  |   VectorValues delta = hybridBayesNet->optimize(assignment); | ||||||
|  | 
 | ||||||
|  |   // The linearization point has the same value as the key index,
 | ||||||
|  |   // e.g. X(1) = 1, X(2) = 2,
 | ||||||
|  |   // but the factors specify X(k) = k-1, so delta should be -1.
 | ||||||
|  |   VectorValues expected_delta; | ||||||
|  |   expected_delta.insert(make_pair(X(1), -Vector1::Ones())); | ||||||
|  |   expected_delta.insert(make_pair(X(2), -Vector1::Ones())); | ||||||
|  |   expected_delta.insert(make_pair(X(3), -Vector1::Ones())); | ||||||
|  |   expected_delta.insert(make_pair(X(4), -Vector1::Ones())); | ||||||
|  | 
 | ||||||
|  |   EXPECT(assert_equal(expected_delta, delta)); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| int main() { | int main() { | ||||||
|   TestResult tr; |   TestResult tr; | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue