pass DiscreteConditional& for pruning instead of shared_ptr
							parent
							
								
									b7bddde82b
								
							
						
					
					
						commit
						d6bc1e11a6
					
				|  | @ -53,8 +53,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { | |||
|   DiscreteConditional joint; | ||||
|   for (auto &&conditional : marginal) { | ||||
|     // The last discrete conditional may be a TableDistribution
 | ||||
|     if (auto dtc = | ||||
|             std::dynamic_pointer_cast<TableDistribution>(conditional)) { | ||||
|     if (auto dtc = std::dynamic_pointer_cast<TableDistribution>(conditional)) { | ||||
|       DiscreteConditional dc(dtc->nrFrontals(), dtc->toDecisionTreeFactor()); | ||||
|       joint = joint * dc; | ||||
|     } else { | ||||
|  | @ -81,7 +80,7 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { | |||
|   for (auto &&conditional : *this) { | ||||
|     if (auto hgc = conditional->asHybrid()) { | ||||
|       // Prune the hybrid Gaussian conditional!
 | ||||
|       auto prunedHybridGaussianConditional = hgc->prune(pruned); | ||||
|       auto prunedHybridGaussianConditional = hgc->prune(*pruned); | ||||
| 
 | ||||
|       // Type-erase and add to the pruned Bayes Net fragment.
 | ||||
|       result.push_back(prunedHybridGaussianConditional); | ||||
|  |  | |||
|  | @ -236,7 +236,7 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { | |||
|         if (!hybridGaussianCond->pruned()) { | ||||
|           // Imperative
 | ||||
|           clique->conditional() = std::make_shared<HybridConditional>( | ||||
|               hybridGaussianCond->prune(parentData.prunedDiscreteProbs)); | ||||
|               hybridGaussianCond->prune(*parentData.prunedDiscreteProbs)); | ||||
|         } | ||||
|       } | ||||
|       return parentData; | ||||
|  |  | |||
|  | @ -304,18 +304,18 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) { | |||
| 
 | ||||
| /* *******************************************************************************/ | ||||
| HybridGaussianConditional::shared_ptr HybridGaussianConditional::prune( | ||||
|     const DiscreteConditional::shared_ptr &discreteProbs) const { | ||||
|   // Find keys in discreteProbs->keys() but not in this->keys():
 | ||||
|     const DiscreteConditional &discreteProbs) const { | ||||
|   // Find keys in discreteProbs.keys() but not in this->keys():
 | ||||
|   std::set<Key> mine(this->keys().begin(), this->keys().end()); | ||||
|   std::set<Key> theirs(discreteProbs->keys().begin(), | ||||
|                        discreteProbs->keys().end()); | ||||
|   std::set<Key> theirs(discreteProbs.keys().begin(), | ||||
|                        discreteProbs.keys().end()); | ||||
|   std::vector<Key> diff; | ||||
|   std::set_difference(theirs.begin(), theirs.end(), mine.begin(), mine.end(), | ||||
|                       std::back_inserter(diff)); | ||||
| 
 | ||||
|   // Find maximum probability value for every combination of our keys.
 | ||||
|   Ordering keys(diff); | ||||
|   auto max = discreteProbs->max(keys); | ||||
|   auto max = discreteProbs.max(keys); | ||||
| 
 | ||||
|   // Check the max value for every combination of our keys.
 | ||||
|   // If the max value is 0.0, we can prune the corresponding conditional.
 | ||||
|  |  | |||
|  | @ -236,7 +236,7 @@ class GTSAM_EXPORT HybridGaussianConditional | |||
|    * @return Shared pointer to possibly a pruned HybridGaussianConditional | ||||
|    */ | ||||
|   HybridGaussianConditional::shared_ptr prune( | ||||
|       const DiscreteConditional::shared_ptr &discreteProbs) const; | ||||
|       const DiscreteConditional &discreteProbs) const; | ||||
| 
 | ||||
|   /// Return true if the conditional has already been pruned.
 | ||||
|   bool pruned() const { return pruned_; } | ||||
|  |  | |||
|  | @ -261,8 +261,8 @@ TEST(HybridGaussianConditional, Prune) { | |||
|       potentials[i] = 1; | ||||
|       const DecisionTreeFactor decisionTreeFactor(keys, potentials); | ||||
|       // Prune the HybridGaussianConditional
 | ||||
|       const auto pruned = hgc.prune(std::make_shared<DiscreteConditional>( | ||||
|           keys.size(), decisionTreeFactor)); | ||||
|       const auto pruned = | ||||
|           hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); | ||||
|       // Check that the pruned HybridGaussianConditional has 1 conditional
 | ||||
|       EXPECT_LONGS_EQUAL(1, pruned->nrComponents()); | ||||
|     } | ||||
|  | @ -272,8 +272,8 @@ TEST(HybridGaussianConditional, Prune) { | |||
|                                          0, 0, 0.5, 0}; | ||||
|     const DecisionTreeFactor decisionTreeFactor(keys, potentials); | ||||
| 
 | ||||
|     const auto pruned = hgc.prune( | ||||
|         std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor)); | ||||
|     const auto pruned = | ||||
|         hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); | ||||
| 
 | ||||
|     // Check that the pruned HybridGaussianConditional has 2 conditionals
 | ||||
|     EXPECT_LONGS_EQUAL(2, pruned->nrComponents()); | ||||
|  | @ -288,8 +288,8 @@ TEST(HybridGaussianConditional, Prune) { | |||
|                                          0,   0, 0.5, 0}; | ||||
|     const DecisionTreeFactor decisionTreeFactor(keys, potentials); | ||||
| 
 | ||||
|     const auto pruned = hgc.prune( | ||||
|         std::make_shared<DiscreteConditional>(keys.size(), decisionTreeFactor)); | ||||
|     const auto pruned = | ||||
|         hgc.prune(DiscreteConditional(keys.size(), decisionTreeFactor)); | ||||
| 
 | ||||
|     // Check that the pruned HybridGaussianConditional has 3 conditionals
 | ||||
|     EXPECT_LONGS_EQUAL(3, pruned->nrComponents()); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue