Merge branch 'develop' into hybrid/tests
						commit
						153c12e18a
					
				|  | @ -8,7 +8,7 @@ | ||||||
| 
 | 
 | ||||||
| /**
 | /**
 | ||||||
|  * @file   HybridBayesNet.cpp |  * @file   HybridBayesNet.cpp | ||||||
|  * @brief  A bayes net of Gaussian Conditionals indexed by discrete keys. |  * @brief  A Bayes net of Gaussian Conditionals indexed by discrete keys. | ||||||
|  * @author Fan Jiang |  * @author Fan Jiang | ||||||
|  * @author Varun Agrawal |  * @author Varun Agrawal | ||||||
|  * @author Shangjie Xue |  * @author Shangjie Xue | ||||||
|  | @ -20,18 +20,20 @@ | ||||||
| #include <gtsam/hybrid/HybridBayesNet.h> | #include <gtsam/hybrid/HybridBayesNet.h> | ||||||
| #include <gtsam/hybrid/HybridValues.h> | #include <gtsam/hybrid/HybridValues.h> | ||||||
| 
 | 
 | ||||||
|  | // In Wrappers we have no access to this so have a default ready
 | ||||||
|  | static std::mt19937_64 kRandomNumberGenerator(42); | ||||||
|  | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { | DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { | ||||||
|   AlgebraicDecisionTree<Key> decisionTree; |   AlgebraicDecisionTree<Key> decisionTree; | ||||||
| 
 | 
 | ||||||
|   // The canonical decision tree factor which will get the discrete conditionals
 |   // The canonical decision tree factor which will get
 | ||||||
|   // added to it.
 |   // the discrete conditionals added to it.
 | ||||||
|   DecisionTreeFactor dtFactor; |   DecisionTreeFactor dtFactor; | ||||||
| 
 | 
 | ||||||
|   for (size_t i = 0; i < this->size(); i++) { |   for (auto &&conditional : *this) { | ||||||
|     HybridConditional::shared_ptr conditional = this->at(i); |  | ||||||
|     if (conditional->isDiscrete()) { |     if (conditional->isDiscrete()) { | ||||||
|       // Convert to a DecisionTreeFactor and add it to the main factor.
 |       // Convert to a DecisionTreeFactor and add it to the main factor.
 | ||||||
|       DecisionTreeFactor f(*conditional->asDiscreteConditional()); |       DecisionTreeFactor f(*conditional->asDiscreteConditional()); | ||||||
|  | @ -53,7 +55,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc( | ||||||
|     const DecisionTreeFactor &decisionTree, |     const DecisionTreeFactor &decisionTree, | ||||||
|     const HybridConditional &conditional) { |     const HybridConditional &conditional) { | ||||||
|   // Get the discrete keys as sets for the decision tree
 |   // Get the discrete keys as sets for the decision tree
 | ||||||
|   // and the gaussian mixture.
 |   // and the Gaussian mixture.
 | ||||||
|   auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); |   auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); | ||||||
|   auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys()); |   auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys()); | ||||||
| 
 | 
 | ||||||
|  | @ -62,7 +64,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc( | ||||||
|                     double probability) -> double { |                     double probability) -> double { | ||||||
|     // typecast so we can use this to get probability value
 |     // typecast so we can use this to get probability value
 | ||||||
|     DiscreteValues values(choices); |     DiscreteValues values(choices); | ||||||
|     // Case where the gaussian mixture has the same
 |     // Case where the Gaussian mixture has the same
 | ||||||
|     // discrete keys as the decision tree.
 |     // discrete keys as the decision tree.
 | ||||||
|     if (conditionalKeySet == decisionTreeKeySet) { |     if (conditionalKeySet == decisionTreeKeySet) { | ||||||
|       if (decisionTree(values) == 0) { |       if (decisionTree(values) == 0) { | ||||||
|  | @ -101,6 +103,7 @@ void HybridBayesNet::updateDiscreteConditionals( | ||||||
|     const DecisionTreeFactor::shared_ptr &prunedDecisionTree) { |     const DecisionTreeFactor::shared_ptr &prunedDecisionTree) { | ||||||
|   KeyVector prunedTreeKeys = prunedDecisionTree->keys(); |   KeyVector prunedTreeKeys = prunedDecisionTree->keys(); | ||||||
| 
 | 
 | ||||||
|  |   // Loop with index since we need it later.
 | ||||||
|   for (size_t i = 0; i < this->size(); i++) { |   for (size_t i = 0; i < this->size(); i++) { | ||||||
|     HybridConditional::shared_ptr conditional = this->at(i); |     HybridConditional::shared_ptr conditional = this->at(i); | ||||||
|     if (conditional->isDiscrete()) { |     if (conditional->isDiscrete()) { | ||||||
|  | @ -153,7 +156,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { | ||||||
|     if (conditional->isHybrid()) { |     if (conditional->isHybrid()) { | ||||||
|       GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); |       GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); | ||||||
| 
 | 
 | ||||||
|       // Make a copy of the gaussian mixture and prune it!
 |       // Make a copy of the Gaussian mixture and prune it!
 | ||||||
|       auto prunedGaussianMixture = |       auto prunedGaussianMixture = | ||||||
|           boost::make_shared<GaussianMixture>(*gaussianMixture); |           boost::make_shared<GaussianMixture>(*gaussianMixture); | ||||||
|       prunedGaussianMixture->prune(*decisionTree); |       prunedGaussianMixture->prune(*decisionTree); | ||||||
|  | @ -173,35 +176,35 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const { | GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const { | ||||||
|   return factors_.at(i)->asMixture(); |   return at(i)->asMixture(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const { | GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const { | ||||||
|   return factors_.at(i)->asGaussian(); |   return at(i)->asGaussian(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { | DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { | ||||||
|   return factors_.at(i)->asDiscreteConditional(); |   return at(i)->asDiscreteConditional(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| GaussianBayesNet HybridBayesNet::choose( | GaussianBayesNet HybridBayesNet::choose( | ||||||
|     const DiscreteValues &assignment) const { |     const DiscreteValues &assignment) const { | ||||||
|   GaussianBayesNet gbn; |   GaussianBayesNet gbn; | ||||||
|   for (size_t idx = 0; idx < size(); idx++) { |   for (auto &&conditional : *this) { | ||||||
|     if (factors_.at(idx)->isHybrid()) { |     if (conditional->isHybrid()) { | ||||||
|       // If factor is hybrid, select based on assignment.
 |       // If conditional is hybrid, select based on assignment.
 | ||||||
|       GaussianMixture gm = *this->atMixture(idx); |       GaussianMixture gm = *conditional->asMixture(); | ||||||
|       gbn.push_back(gm(assignment)); |       gbn.push_back(gm(assignment)); | ||||||
| 
 | 
 | ||||||
|     } else if (factors_.at(idx)->isContinuous()) { |     } else if (conditional->isContinuous()) { | ||||||
|       // If continuous only, add gaussian conditional.
 |       // If continuous only, add Gaussian conditional.
 | ||||||
|       gbn.push_back((this->atGaussian(idx))); |       gbn.push_back((conditional->asGaussian())); | ||||||
| 
 | 
 | ||||||
|     } else if (factors_.at(idx)->isDiscrete()) { |     } else if (conditional->isDiscrete()) { | ||||||
|       // If factor at `idx` is discrete-only, we simply continue.
 |       // If conditional is discrete-only, we simply continue.
 | ||||||
|       continue; |       continue; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  | @ -213,7 +216,7 @@ GaussianBayesNet HybridBayesNet::choose( | ||||||
| HybridValues HybridBayesNet::optimize() const { | HybridValues HybridBayesNet::optimize() const { | ||||||
|   // Solve for the MPE
 |   // Solve for the MPE
 | ||||||
|   DiscreteBayesNet discrete_bn; |   DiscreteBayesNet discrete_bn; | ||||||
|   for (auto &conditional : factors_) { |   for (auto &&conditional : *this) { | ||||||
|     if (conditional->isDiscrete()) { |     if (conditional->isDiscrete()) { | ||||||
|       discrete_bn.push_back(conditional->asDiscreteConditional()); |       discrete_bn.push_back(conditional->asDiscreteConditional()); | ||||||
|     } |     } | ||||||
|  | @ -238,6 +241,41 @@ VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { | ||||||
|   return gbn.optimize(); |   return gbn.optimize(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | HybridValues HybridBayesNet::sample(const HybridValues &given, | ||||||
|  |                                     std::mt19937_64 *rng) const { | ||||||
|  |   DiscreteBayesNet dbn; | ||||||
|  |   for (auto &&conditional : *this) { | ||||||
|  |     if (conditional->isDiscrete()) { | ||||||
|  |       // If conditional is discrete-only, we add to the discrete Bayes net.
 | ||||||
|  |       dbn.push_back(conditional->asDiscreteConditional()); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   // Sample a discrete assignment.
 | ||||||
|  |   const DiscreteValues assignment = dbn.sample(given.discrete()); | ||||||
|  |   // Select the continuous Bayes net corresponding to the assignment.
 | ||||||
|  |   GaussianBayesNet gbn = choose(assignment); | ||||||
|  |   // Sample from the Gaussian Bayes net.
 | ||||||
|  |   VectorValues sample = gbn.sample(given.continuous(), rng); | ||||||
|  |   return {assignment, sample}; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const { | ||||||
|  |   HybridValues given; | ||||||
|  |   return sample(given, rng); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | HybridValues HybridBayesNet::sample(const HybridValues &given) const { | ||||||
|  |   return sample(given, &kRandomNumberGenerator); | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ************************************************************************* */ | ||||||
|  | HybridValues HybridBayesNet::sample() const { | ||||||
|  |   return sample(&kRandomNumberGenerator); | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| double HybridBayesNet::error(const VectorValues &continuousValues, | double HybridBayesNet::error(const VectorValues &continuousValues, | ||||||
|                              const DiscreteValues &discreteValues) const { |                              const DiscreteValues &discreteValues) const { | ||||||
|  | @ -248,34 +286,28 @@ double HybridBayesNet::error(const VectorValues &continuousValues, | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| AlgebraicDecisionTree<Key> HybridBayesNet::error( | AlgebraicDecisionTree<Key> HybridBayesNet::error( | ||||||
|     const VectorValues &continuousValues) const { |     const VectorValues &continuousValues) const { | ||||||
|   AlgebraicDecisionTree<Key> error_tree; |   AlgebraicDecisionTree<Key> error_tree(0.0); | ||||||
| 
 | 
 | ||||||
|   // Iterate over each factor.
 |   // Iterate over each conditional.
 | ||||||
|   for (size_t idx = 0; idx < size(); idx++) { |   for (auto &&conditional : *this) { | ||||||
|     AlgebraicDecisionTree<Key> conditional_error; |     if (conditional->isHybrid()) { | ||||||
|  |       // If conditional is hybrid, select based on assignment and compute error.
 | ||||||
|  |       GaussianMixture::shared_ptr gm = conditional->asMixture(); | ||||||
|  |       AlgebraicDecisionTree<Key> conditional_error = | ||||||
|  |           gm->error(continuousValues); | ||||||
| 
 | 
 | ||||||
|     if (factors_.at(idx)->isHybrid()) { |       error_tree = error_tree + conditional_error; | ||||||
|       // If factor is hybrid, select based on assignment and compute error.
 |  | ||||||
|       GaussianMixture::shared_ptr gm = this->atMixture(idx); |  | ||||||
|       conditional_error = gm->error(continuousValues); |  | ||||||
| 
 | 
 | ||||||
|       // Assign for the first index, add error for subsequent ones.
 |     } else if (conditional->isContinuous()) { | ||||||
|       if (idx == 0) { |  | ||||||
|         error_tree = conditional_error; |  | ||||||
|       } else { |  | ||||||
|         error_tree = error_tree + conditional_error; |  | ||||||
|       } |  | ||||||
| 
 |  | ||||||
|     } else if (factors_.at(idx)->isContinuous()) { |  | ||||||
|       // If continuous only, get the (double) error
 |       // If continuous only, get the (double) error
 | ||||||
|       // and add it to the error_tree
 |       // and add it to the error_tree
 | ||||||
|       double error = this->atGaussian(idx)->error(continuousValues); |       double error = conditional->asGaussian()->error(continuousValues); | ||||||
|       // Add the computed error to every leaf of the error tree.
 |       // Add the computed error to every leaf of the error tree.
 | ||||||
|       error_tree = error_tree.apply( |       error_tree = error_tree.apply( | ||||||
|           [error](double leaf_value) { return leaf_value + error; }); |           [error](double leaf_value) { return leaf_value + error; }); | ||||||
| 
 | 
 | ||||||
|     } else if (factors_.at(idx)->isDiscrete()) { |     } else if (conditional->isDiscrete()) { | ||||||
|       // If factor at `idx` is discrete-only, we skip.
 |       // Conditional is discrete-only, we skip.
 | ||||||
|       continue; |       continue; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  |  | ||||||
|  | @ -8,7 +8,7 @@ | ||||||
| 
 | 
 | ||||||
| /**
 | /**
 | ||||||
|  * @file    HybridBayesNet.h |  * @file    HybridBayesNet.h | ||||||
|  * @brief   A bayes net of Gaussian Conditionals indexed by discrete keys. |  * @brief   A Bayes net of Gaussian Conditionals indexed by discrete keys. | ||||||
|  * @author  Varun Agrawal |  * @author  Varun Agrawal | ||||||
|  * @author  Fan Jiang |  * @author  Fan Jiang | ||||||
|  * @author  Frank Dellaert |  * @author  Frank Dellaert | ||||||
|  | @ -43,7 +43,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | ||||||
|   /// @name Standard Constructors
 |   /// @name Standard Constructors
 | ||||||
|   /// @{
 |   /// @{
 | ||||||
| 
 | 
 | ||||||
|   /** Construct empty bayes net */ |   /** Construct empty Bayes net */ | ||||||
|   HybridBayesNet() = default; |   HybridBayesNet() = default; | ||||||
| 
 | 
 | ||||||
|   /// @}
 |   /// @}
 | ||||||
|  | @ -120,7 +120,47 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> { | ||||||
|    */ |    */ | ||||||
|   DecisionTreeFactor::shared_ptr discreteConditionals() const; |   DecisionTreeFactor::shared_ptr discreteConditionals() const; | ||||||
| 
 | 
 | ||||||
|  public: |   /**
 | ||||||
|  |    * @brief Sample from an incomplete BayesNet, given missing variables. | ||||||
|  |    * | ||||||
|  |    * Example: | ||||||
|  |    *   std::mt19937_64 rng(42); | ||||||
|  |    *   VectorValues given = ...; | ||||||
|  |    *   auto sample = bn.sample(given, &rng); | ||||||
|  |    * | ||||||
|  |    * @param given Values of missing variables. | ||||||
|  |    * @param rng The pseudo-random number generator. | ||||||
|  |    * @return HybridValues | ||||||
|  |    */ | ||||||
|  |   HybridValues sample(const HybridValues &given, std::mt19937_64 *rng) const; | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Sample using ancestral sampling. | ||||||
|  |    * | ||||||
|  |    * Example: | ||||||
|  |    *   std::mt19937_64 rng(42); | ||||||
|  |    *   auto sample = bn.sample(&rng); | ||||||
|  |    * | ||||||
|  |    * @param rng The pseudo-random number generator. | ||||||
|  |    * @return HybridValues | ||||||
|  |    */ | ||||||
|  |   HybridValues sample(std::mt19937_64 *rng) const; | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Sample from an incomplete BayesNet, use default rng. | ||||||
|  |    * | ||||||
|  |    * @param given Values of missing variables. | ||||||
|  |    * @return HybridValues | ||||||
|  |    */ | ||||||
|  |   HybridValues sample(const HybridValues &given) const; | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * @brief Sample using ancestral sampling, use default rng. | ||||||
|  |    * | ||||||
|  |    * @return HybridValues | ||||||
|  |    */ | ||||||
|  |   HybridValues sample() const; | ||||||
|  | 
 | ||||||
|   /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
 |   /// Prune the Hybrid Bayes Net such that we have at most maxNrLeaves leaves.
 | ||||||
|   HybridBayesNet prune(size_t maxNrLeaves); |   HybridBayesNet prune(size_t maxNrLeaves); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -297,6 +297,75 @@ TEST(HybridBayesNet, Serialization) { | ||||||
|   EXPECT(equalsBinary<HybridBayesNet>(hbn)); |   EXPECT(equalsBinary<HybridBayesNet>(hbn)); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /* ****************************************************************************/ | ||||||
|  | // Test HybridBayesNet sampling.
 | ||||||
|  | TEST(HybridBayesNet, Sampling) { | ||||||
|  |   HybridNonlinearFactorGraph nfg; | ||||||
|  | 
 | ||||||
|  |   auto noise_model = noiseModel::Diagonal::Sigmas(Vector1(1.0)); | ||||||
|  |   auto zero_motion = | ||||||
|  |       boost::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model); | ||||||
|  |   auto one_motion = | ||||||
|  |       boost::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model); | ||||||
|  |   std::vector<NonlinearFactor::shared_ptr> factors = {zero_motion, one_motion}; | ||||||
|  |   nfg.emplace_nonlinear<PriorFactor<double>>(X(0), 0.0, noise_model); | ||||||
|  |   nfg.emplace_hybrid<MixtureFactor>( | ||||||
|  |       KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors); | ||||||
|  | 
 | ||||||
|  |   DiscreteKey mode(M(0), 2); | ||||||
|  |   auto discrete_prior = boost::make_shared<DiscreteDistribution>(mode, "1/1"); | ||||||
|  |   nfg.push_discrete(discrete_prior); | ||||||
|  | 
 | ||||||
|  |   Values initial; | ||||||
|  |   double z0 = 0.0, z1 = 1.0; | ||||||
|  |   initial.insert<double>(X(0), z0); | ||||||
|  |   initial.insert<double>(X(1), z1); | ||||||
|  | 
 | ||||||
|  |   // Create the factor graph from the nonlinear factor graph.
 | ||||||
|  |   HybridGaussianFactorGraph::shared_ptr fg = nfg.linearize(initial); | ||||||
|  |   // Eliminate into BN
 | ||||||
|  |   Ordering ordering = fg->getHybridOrdering(); | ||||||
|  |   HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering); | ||||||
|  | 
 | ||||||
|  |   // Set up sampling
 | ||||||
|  |   std::mt19937_64 gen(11); | ||||||
|  | 
 | ||||||
|  |   // Initialize containers for computing the mean values.
 | ||||||
|  |   vector<double> discrete_samples; | ||||||
|  |   VectorValues average_continuous; | ||||||
|  | 
 | ||||||
|  |   size_t num_samples = 1000; | ||||||
|  |   for (size_t i = 0; i < num_samples; i++) { | ||||||
|  |     // Sample
 | ||||||
|  |     HybridValues sample = bn->sample(&gen); | ||||||
|  | 
 | ||||||
|  |     discrete_samples.push_back(sample.discrete()[M(0)]); | ||||||
|  | 
 | ||||||
|  |     if (i == 0) { | ||||||
|  |       average_continuous.insert(sample.continuous()); | ||||||
|  |     } else { | ||||||
|  |       average_continuous += sample.continuous(); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   EXPECT_LONGS_EQUAL(2, average_continuous.size()); | ||||||
|  |   EXPECT_LONGS_EQUAL(num_samples, discrete_samples.size()); | ||||||
|  | 
 | ||||||
|  |   // Regressions don't work across platforms :-(
 | ||||||
|  |   // // regression for specific RNG seed
 | ||||||
|  |   // double discrete_sum =
 | ||||||
|  |   //     std::accumulate(discrete_samples.begin(), discrete_samples.end(),
 | ||||||
|  |   //                     decltype(discrete_samples)::value_type(0));
 | ||||||
|  |   // EXPECT_DOUBLES_EQUAL(0.477, discrete_sum / num_samples, 1e-9);
 | ||||||
|  | 
 | ||||||
|  |   // VectorValues expected;
 | ||||||
|  |   // expected.insert({X(0), Vector1(-0.0131207162712)});
 | ||||||
|  |   // expected.insert({X(1), Vector1(-0.499026377568)});
 | ||||||
|  |   // // regression for specific RNG seed
 | ||||||
|  |   // EXPECT(assert_equal(expected, average_continuous.scale(1.0 /
 | ||||||
|  |   // num_samples)));
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| int main() { | int main() { | ||||||
|   TestResult tr; |   TestResult tr; | ||||||
|  |  | ||||||
|  | @ -64,8 +64,9 @@ namespace gtsam { | ||||||
|     return sample(result, rng); |     return sample(result, rng); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   VectorValues GaussianBayesNet::sample(VectorValues result, |   VectorValues GaussianBayesNet::sample(const VectorValues& given, | ||||||
|                                         std::mt19937_64* rng) const { |                                         std::mt19937_64* rng) const { | ||||||
|  |     VectorValues result(given); | ||||||
|     // sample each node in reverse topological sort order (parents first)
 |     // sample each node in reverse topological sort order (parents first)
 | ||||||
|     for (auto cg : boost::adaptors::reverse(*this)) { |     for (auto cg : boost::adaptors::reverse(*this)) { | ||||||
|       const VectorValues sampled = cg->sample(result, rng); |       const VectorValues sampled = cg->sample(result, rng); | ||||||
|  | @ -79,7 +80,7 @@ namespace gtsam { | ||||||
|     return sample(&kRandomNumberGenerator); |     return sample(&kRandomNumberGenerator); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   VectorValues GaussianBayesNet::sample(VectorValues given) const { |   VectorValues GaussianBayesNet::sample(const VectorValues& given) const { | ||||||
|     return sample(given, &kRandomNumberGenerator); |     return sample(given, &kRandomNumberGenerator); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -110,13 +110,13 @@ namespace gtsam { | ||||||
|      *   VectorValues given = ...; |      *   VectorValues given = ...; | ||||||
|      *   auto sample = gbn.sample(given, &rng); |      *   auto sample = gbn.sample(given, &rng); | ||||||
|      */ |      */ | ||||||
|     VectorValues sample(VectorValues given, std::mt19937_64* rng) const; |     VectorValues sample(const VectorValues& given, std::mt19937_64* rng) const; | ||||||
| 
 | 
 | ||||||
|     /// Sample using ancestral sampling, use default rng
 |     /// Sample using ancestral sampling, use default rng
 | ||||||
|     VectorValues sample() const; |     VectorValues sample() const; | ||||||
| 
 | 
 | ||||||
|     /// Sample from an incomplete BayesNet, use default rng
 |     /// Sample from an incomplete BayesNet, use default rng
 | ||||||
|     VectorValues sample(VectorValues given) const; |     VectorValues sample(const VectorValues& given) const; | ||||||
| 
 | 
 | ||||||
|     /**
 |     /**
 | ||||||
|      * Return ordering corresponding to a topological sort. |      * Return ordering corresponding to a topological sort. | ||||||
|  |  | ||||||
|  | @ -299,14 +299,12 @@ double GaussianConditional::logDeterminant() const { | ||||||
|           "GaussianConditional::sample can only be called on single variable " |           "GaussianConditional::sample can only be called on single variable " | ||||||
|           "conditionals"); |           "conditionals"); | ||||||
|     } |     } | ||||||
|     if (!model_) { | 
 | ||||||
|       throw std::invalid_argument( |  | ||||||
|           "GaussianConditional::sample can only be called if a diagonal noise " |  | ||||||
|           "model was specified at construction."); |  | ||||||
|     } |  | ||||||
|     VectorValues solution = solve(parentsValues); |     VectorValues solution = solve(parentsValues); | ||||||
|     Key key = firstFrontalKey(); |     Key key = firstFrontalKey(); | ||||||
|     const Vector& sigmas = model_->sigmas(); |     // The vector of sigma values for sampling.
 | ||||||
|  |     // If no model, initialize sigmas to 1, else to model sigmas
 | ||||||
|  |     const Vector& sigmas = (!model_) ? Vector::Ones(rows()) : model_->sigmas(); | ||||||
|     solution[key] += Sampler::sampleDiagonal(sigmas, rng); |     solution[key] += Sampler::sampleDiagonal(sigmas, rng); | ||||||
|     return solution; |     return solution; | ||||||
|   } |   } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue