HybridBayesNet and HybridBayesTree both use similar pruning functions
							parent
							
								
									2225ecf442
								
							
						
					
					
						commit
						5e99cd7095
					
				|  | @ -22,14 +22,6 @@ | ||||||
| 
 | 
 | ||||||
| namespace gtsam { | namespace gtsam { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ |  | ||||||
| /// Return the DiscreteKey vector as a set.
 |  | ||||||
| static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) { |  | ||||||
|   std::set<DiscreteKey> s; |  | ||||||
|   s.insert(dkeys.begin(), dkeys.end()); |  | ||||||
|   return s; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { | DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { | ||||||
|   AlgebraicDecisionTree<Key> decisionTree; |   AlgebraicDecisionTree<Key> decisionTree; | ||||||
|  | @ -49,63 +41,6 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { | ||||||
|   return boost::make_shared<DecisionTreeFactor>(dtFactor); |   return boost::make_shared<DecisionTreeFactor>(dtFactor); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /**
 |  | ||||||
|  * @brief Helper function to get the pruner functional. |  | ||||||
|  * |  | ||||||
|  * @param probDecisionTree The probability decision tree of only discrete keys. |  | ||||||
|  * @param discreteFactorKeySet Set of DiscreteKeys in probDecisionTree. |  | ||||||
|  * Pre-computed for efficiency. |  | ||||||
|  * @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture. |  | ||||||
|  * @return std::function<GaussianConditional::shared_ptr( |  | ||||||
|  * const Assignment<Key> &, const GaussianConditional::shared_ptr &)> |  | ||||||
|  */ |  | ||||||
| std::function<GaussianConditional::shared_ptr( |  | ||||||
|     const Assignment<Key> &, const GaussianConditional::shared_ptr &)> |  | ||||||
| PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree, |  | ||||||
|            const std::set<DiscreteKey> &discreteFactorKeySet, |  | ||||||
|            const std::set<DiscreteKey> &gaussianMixtureKeySet) { |  | ||||||
|   auto pruner = [&](const Assignment<Key> &choices, |  | ||||||
|                     const GaussianConditional::shared_ptr &conditional) |  | ||||||
|       -> GaussianConditional::shared_ptr { |  | ||||||
|     // typecast so we can use this to get probability value
 |  | ||||||
|     DiscreteValues values(choices); |  | ||||||
| 
 |  | ||||||
|     // Case where the gaussian mixture has the same
 |  | ||||||
|     // discrete keys as the decision tree.
 |  | ||||||
|     if (gaussianMixtureKeySet == discreteFactorKeySet) { |  | ||||||
|       if ((*probDecisionTree)(values) == 0.0) { |  | ||||||
|         // empty aka null pointer
 |  | ||||||
|         boost::shared_ptr<GaussianConditional> null; |  | ||||||
|         return null; |  | ||||||
|       } else { |  | ||||||
|         return conditional; |  | ||||||
|       } |  | ||||||
|     } else { |  | ||||||
|       std::vector<DiscreteKey> set_diff; |  | ||||||
|       std::set_difference( |  | ||||||
|           discreteFactorKeySet.begin(), discreteFactorKeySet.end(), |  | ||||||
|           gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(), |  | ||||||
|           std::back_inserter(set_diff)); |  | ||||||
|       const std::vector<DiscreteValues> assignments = |  | ||||||
|           DiscreteValues::CartesianProduct(set_diff); |  | ||||||
|       for (const DiscreteValues &assignment : assignments) { |  | ||||||
|         DiscreteValues augmented_values(values); |  | ||||||
|         augmented_values.insert(assignment.begin(), assignment.end()); |  | ||||||
| 
 |  | ||||||
|         // If any one of the sub-branches are non-zero,
 |  | ||||||
|         // we need this conditional.
 |  | ||||||
|         if ((*probDecisionTree)(augmented_values) > 0.0) { |  | ||||||
|           return conditional; |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|       // If we are here, it means that all the sub-branches are 0,
 |  | ||||||
|       // so we prune.
 |  | ||||||
|       return nullptr; |  | ||||||
|     } |  | ||||||
|   }; |  | ||||||
|   return pruner; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { | HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { | ||||||
|   // Get the decision tree of only the discrete keys
 |   // Get the decision tree of only the discrete keys
 | ||||||
|  | @ -114,8 +49,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { | ||||||
|       boost::make_shared<DecisionTreeFactor>( |       boost::make_shared<DecisionTreeFactor>( | ||||||
|           discreteConditionals->prune(maxNrLeaves)); |           discreteConditionals->prune(maxNrLeaves)); | ||||||
| 
 | 
 | ||||||
|   auto discreteFactorKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys()); |  | ||||||
| 
 |  | ||||||
|   /* To Prune, we visitWith every leaf in the GaussianMixture.
 |   /* To Prune, we visitWith every leaf in the GaussianMixture.
 | ||||||
|    * For each leaf, using the assignment we can check the discrete decision tree |    * For each leaf, using the assignment we can check the discrete decision tree | ||||||
|    * for 0.0 probability, then just set the leaf to a nullptr. |    * for 0.0 probability, then just set the leaf to a nullptr. | ||||||
|  | @ -130,35 +63,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const { | ||||||
|   for (size_t i = 0; i < this->size(); i++) { |   for (size_t i = 0; i < this->size(); i++) { | ||||||
|     HybridConditional::shared_ptr conditional = this->at(i); |     HybridConditional::shared_ptr conditional = this->at(i); | ||||||
| 
 | 
 | ||||||
|     GaussianMixture::shared_ptr gaussianMixture = |     if (conditional->isHybrid()) { | ||||||
|         boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner()); |       GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); | ||||||
| 
 | 
 | ||||||
|     if (gaussianMixture) { |       // Make a copy of the gaussian mixture and prune it!
 | ||||||
|       // We may have mixtures with less discrete keys than discreteFactor so
 |       auto prunedGaussianMixture = | ||||||
|       // we skip those since the label assignment does not exist.
 |           boost::make_shared<GaussianMixture>(*gaussianMixture); | ||||||
|       auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys()); |       prunedGaussianMixture->prune(*discreteFactor); | ||||||
| 
 |  | ||||||
|       // Get the pruner function.
 |  | ||||||
|       auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet); |  | ||||||
| 
 |  | ||||||
|       // Run the pruning to get a new, pruned tree
 |  | ||||||
|       GaussianMixture::Conditionals prunedTree = |  | ||||||
|           gaussianMixture->conditionals().apply(pruner); |  | ||||||
| 
 |  | ||||||
|       DiscreteKeys discreteKeys = gaussianMixture->discreteKeys(); |  | ||||||
|       // reverse keys to get a natural ordering
 |  | ||||||
|       std::reverse(discreteKeys.begin(), discreteKeys.end()); |  | ||||||
| 
 |  | ||||||
|       // Convert from boost::iterator_range to KeyVector
 |  | ||||||
|       // so we can pass it to constructor.
 |  | ||||||
|       KeyVector frontals(gaussianMixture->frontals().begin(), |  | ||||||
|                          gaussianMixture->frontals().end()), |  | ||||||
|           parents(gaussianMixture->parents().begin(), |  | ||||||
|                   gaussianMixture->parents().end()); |  | ||||||
| 
 |  | ||||||
|       // Create the new gaussian mixture and add it to the bayes net.
 |  | ||||||
|       auto prunedGaussianMixture = boost::make_shared<GaussianMixture>( |  | ||||||
|           frontals, parents, discreteKeys, prunedTree); |  | ||||||
| 
 | 
 | ||||||
|       // Type-erase and add to the pruned Bayes Net fragment.
 |       // Type-erase and add to the pruned Bayes Net fragment.
 | ||||||
|       prunedBayesNetFragment.push_back( |       prunedBayesNetFragment.push_back( | ||||||
|  |  | ||||||
|  | @ -149,16 +149,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { | ||||||
|   auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>( |   auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>( | ||||||
|       this->roots_.at(0)->conditional()->inner()); |       this->roots_.at(0)->conditional()->inner()); | ||||||
| 
 | 
 | ||||||
|   DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves); |   DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); | ||||||
|   decisionTree->root_ = prunedDiscreteFactor.root_; |   decisionTree->root_ = prunedDecisionTree.root_; | ||||||
| 
 | 
 | ||||||
|   /// Helper struct for pruning the hybrid bayes tree.
 |   /// Helper struct for pruning the hybrid bayes tree.
 | ||||||
|   struct HybridPrunerData { |   struct HybridPrunerData { | ||||||
|     /// The discrete decision tree after pruning.
 |     /// The discrete decision tree after pruning.
 | ||||||
|     DecisionTreeFactor prunedDiscreteFactor; |     DecisionTreeFactor prunedDecisionTree; | ||||||
|     HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor, |     HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree, | ||||||
|                      const HybridBayesTree::sharedNode& parentClique) |                      const HybridBayesTree::sharedNode& parentClique) | ||||||
|         : prunedDiscreteFactor(prunedDiscreteFactor) {} |         : prunedDecisionTree(prunedDecisionTree) {} | ||||||
| 
 | 
 | ||||||
|     /**
 |     /**
 | ||||||
|      * @brief A function used during tree traversal that operates on each node |      * @brief A function used during tree traversal that operates on each node | ||||||
|  | @ -178,19 +178,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) { | ||||||
|       if (conditional->isHybrid()) { |       if (conditional->isHybrid()) { | ||||||
|         auto gaussianMixture = conditional->asMixture(); |         auto gaussianMixture = conditional->asMixture(); | ||||||
| 
 | 
 | ||||||
|         // Check if the number of discrete keys match,
 |         gaussianMixture->prune(parentData.prunedDecisionTree); | ||||||
|         // else we get an assignment error.
 |  | ||||||
|         // TODO(Varun) Update prune method to handle assignment subset?
 |  | ||||||
|         if (gaussianMixture->discreteKeys() == |  | ||||||
|             parentData.prunedDiscreteFactor.discreteKeys()) { |  | ||||||
|           gaussianMixture->prune(parentData.prunedDiscreteFactor); |  | ||||||
|         } |  | ||||||
|       } |       } | ||||||
|       return parentData; |       return parentData; | ||||||
|     } |     } | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
|   HybridPrunerData rootData(prunedDiscreteFactor, 0); |   HybridPrunerData rootData(prunedDecisionTree, 0); | ||||||
|   { |   { | ||||||
|     treeTraversal::no_op visitorPost; |     treeTraversal::no_op visitorPost; | ||||||
|     // Limits OpenMP threads since we're mixing TBB and OpenMP
 |     // Limits OpenMP threads since we're mixing TBB and OpenMP
 | ||||||
|  |  | ||||||
|  | @ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) { | ||||||
|   EXPECT_LONGS_EQUAL( |   EXPECT_LONGS_EQUAL( | ||||||
|       2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents()); |       2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents()); | ||||||
|   EXPECT_LONGS_EQUAL( |   EXPECT_LONGS_EQUAL( | ||||||
|       4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); |       3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents()); | ||||||
|   EXPECT_LONGS_EQUAL( |   EXPECT_LONGS_EQUAL( | ||||||
|       5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents()); |       5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents()); | ||||||
|   EXPECT_LONGS_EQUAL( |   EXPECT_LONGS_EQUAL( | ||||||
|  |  | ||||||
|  | @ -363,7 +363,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) { | ||||||
|   EXPECT_LONGS_EQUAL( |   EXPECT_LONGS_EQUAL( | ||||||
|       2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents()); |       2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents()); | ||||||
|   EXPECT_LONGS_EQUAL( |   EXPECT_LONGS_EQUAL( | ||||||
|       4, bayesTree[X(2)]->conditional()->asMixture()->nrComponents()); |       3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents()); | ||||||
|   EXPECT_LONGS_EQUAL( |   EXPECT_LONGS_EQUAL( | ||||||
|       5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents()); |       5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents()); | ||||||
|   EXPECT_LONGS_EQUAL( |   EXPECT_LONGS_EQUAL( | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue