pass DiscreteConditional& for pruning instead of shared_ptr
							parent
							
								
									b7bddde82b
								
							
						
					
					
						commit
						d6bc1e11a6
					
				|  | @ -53,8 +53,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { | ||||||
|   DiscreteConditional joint; |   DiscreteConditional joint; | ||||||
|   for (auto &&conditional : marginal) { |   for (auto &&conditional : marginal) { | ||||||
|     // The last discrete conditional may be a TableDistribution
 |     // The last discrete conditional may be a TableDistribution
 | ||||||
|     if (auto dtc = |     if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(conditional)) { | ||||||
|             std::dynamic_pointer_cast<TableDistribution>(conditional)) { |  | ||||||
|       DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor()); |       DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor()); | ||||||
|       joint = joint * dc; |       joint = joint * dc; | ||||||
|     } else { |     } else { | ||||||
|  | @ -81,7 +80,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { | ||||||
|   for (auto &&conditional : *this) { |   for (auto &&conditional : *this) { | ||||||
|     if (auto hgc = conditional->asHybrid()) { |     if (auto hgc = conditional->asHybrid()) { | ||||||
|       // Prune the hybrid Gaussian conditional!
 |       // Prune the hybrid Gaussian conditional!
 | ||||||
|       auto prunedHybridGaussianConditional = hgc->prune(pruned); |       auto prunedHybridGaussianConditional = hgc->prune(*pruned); | ||||||
| 
 | 
 | ||||||
|       // Type-erase and add to the pruned Bayes Net fragment.
 |       // Type-erase and add to the pruned Bayes Net fragment.
 | ||||||
|       result.push_back(prunedHybridGaussianConditional); |       result.push_back(prunedHybridGaussianConditional); | ||||||
|  |  | ||||||
|  | @ -236,7 +236,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { | ||||||
|         if (!hybridGaussianCond->pruned()) { |         if (!hybridGaussianCond->pruned()) { | ||||||
|           // Imperative
 |           // Imperative
 | ||||||
|           clique->conditional() = std::make_shared<HybridConditional>( |           clique->conditional() = std::make_shared<HybridConditional>( | ||||||
|               hybridGaussianCond->prune(parentData.prunedDiscreteProbs)); |               hybridGaussianCond->prune(*parentData.prunedDiscreteProbs)); | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
|       return parentData; |       return parentData; | ||||||
|  |  | ||||||
|  | @ -304,18 +304,18 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { | ||||||
| 
 | 
 | ||||||
| /* *******************************************************************************/ | /* *******************************************************************************/ | ||||||
| HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( | HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( | ||||||
|     const DiscreteConditional::shared_ptr &discreteProbs) const { |     const DiscreteConditional &discreteProbs) const { | ||||||
|   // Find keys in discreteProbs->keys() but not in this->keys():
 |   // Find keys in discreteProbs.keys() but not in this->keys():
 | ||||||
|   std::set<Key> mine(this->keys().begin(), this->keys().end()); |   std::set<Key> mine(this->keys().begin(), this->keys().end()); | ||||||
|   std::set<Key> theirs(discreteProbs->keys().begin(), |   std::set<Key> theirs(discreteProbs.keys().begin(), | ||||||
|                        discreteProbs->keys().end()); |                        discreteProbs.keys().end()); | ||||||
|   std::vector<Key> diff; |   std::vector<Key> diff; | ||||||
|   std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), |   std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), | ||||||
|                       std::back_inserter(diff)); |                       std::back_inserter(diff)); | ||||||
| 
 | 
 | ||||||
|   // Find maximum probability value for every combination of our keys.
 |   // Find maximum probability value for every combination of our keys.
 | ||||||
|   Ordering keys(diff); |   Ordering keys(diff); | ||||||
|   auto max = discreteProbs->max(keys); |   auto max = discreteProbs.max(keys); | ||||||
| 
 | 
 | ||||||
|   // Check the max value for every combination of our keys.
 |   // Check the max value for every combination of our keys.
 | ||||||
|   // If the max value is 0.0, we can prune the corresponding conditional.
 |   // If the max value is 0.0, we can prune the corresponding conditional.
 | ||||||
|  |  | ||||||
|  | @ -236,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional | ||||||
|    * @return Shared pointer to possibly a pruned HybridGaussianConditional |    * @return Shared pointer to possibly a pruned HybridGaussianConditional | ||||||
|    */ |    */ | ||||||
|   HybridGaussianConditional::shared_ptr prune( |   HybridGaussianConditional::shared_ptr prune( | ||||||
|       const DiscreteConditional::shared_ptr &discreteProbs) const; |       const DiscreteConditional &discreteProbs) const; | ||||||
| 
 | 
 | ||||||
|   /// Return true if the conditional has already been pruned.
 |   /// Return true if the conditional has already been pruned.
 | ||||||
|   bool pruned() const { return pruned_; } |   bool pruned() const { return pruned_; } | ||||||
|  |  | ||||||
|  | @ -261,8 +261,8 @@ TEST(HybridGaussianConditional, Prune) { | ||||||
|       potentials[i] = 1; |       potentials[i] = 1; | ||||||
|       const DecisionTreeFactor decisionTreeFactor(keys, potentials); |       const DecisionTreeFactor decisionTreeFactor(keys, potentials); | ||||||
|       // Prune the HybridGaussianConditional
 |       // Prune the HybridGaussianConditional
 | ||||||
|       const auto pruned = hgc.prune(std::make_shared<DiscreteConditional>( |       const auto pruned = | ||||||
|           keys.size(), decisionTreeFactor)); |           hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); | ||||||
|       // Check that the pruned HybridGaussianConditional has 1 conditional
 |       // Check that the pruned HybridGaussianConditional has 1 conditional
 | ||||||
|       EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); |       EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); | ||||||
|     } |     } | ||||||
|  | @ -272,8 +272,8 @@ TEST(HybridGaussianConditional, Prune) { | ||||||
|                                          0, 0, 0.5, 0}; |                                          0, 0, 0.5, 0}; | ||||||
|     const DecisionTreeFactor decisionTreeFactor(keys, potentials); |     const DecisionTreeFactor decisionTreeFactor(keys, potentials); | ||||||
| 
 | 
 | ||||||
|     const auto pruned = hgc.prune( |     const auto pruned = | ||||||
|         std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor)); |         hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); | ||||||
| 
 | 
 | ||||||
|     // Check that the pruned HybridGaussianConditional has 2 conditionals
 |     // Check that the pruned HybridGaussianConditional has 2 conditionals
 | ||||||
|     EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); |     EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); | ||||||
|  | @ -288,8 +288,8 @@ TEST(HybridGaussianConditional, Prune) { | ||||||
|                                          0,   0, 0.5, 0}; |                                          0,   0, 0.5, 0}; | ||||||
|     const DecisionTreeFactor decisionTreeFactor(keys, potentials); |     const DecisionTreeFactor decisionTreeFactor(keys, potentials); | ||||||
| 
 | 
 | ||||||
|     const auto pruned = hgc.prune( |     const auto pruned = | ||||||
|         std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor)); |         hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); | ||||||
| 
 | 
 | ||||||
|     // Check that the pruned HybridGaussianConditional has 3 conditionals
 |     // Check that the pruned HybridGaussianConditional has 3 conditionals
 | ||||||
|     EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); |     EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue