Further optimize the implementation of BayesTreeMarginalizationHelper:
Now we won't re-emilinate any unnecessary nodes (we re-emilinated whole subtrees in the previous commits, which is not optimal)release/4.3a0
							parent
							
								
									14c3467520
								
							
						
					
					
						commit
						1a5e711f0e
					
				|  | @ -21,6 +21,7 @@ | ||||||
| #pragma once | #pragma once | ||||||
| 
 | 
 | ||||||
| #include <unordered_map> | #include <unordered_map> | ||||||
|  | #include <deque> | ||||||
| #include <gtsam/inference/BayesTree.h> | #include <gtsam/inference/BayesTree.h> | ||||||
| #include <gtsam/inference/BayesTreeCliqueBase.h> | #include <gtsam/inference/BayesTreeCliqueBase.h> | ||||||
| #include <gtsam/base/debug.h> | #include <gtsam/base/debug.h> | ||||||
|  | @ -50,9 +51,12 @@ public: | ||||||
|    * 2. Or it has a child node depending on a marginalizable variable AND the |    * 2. Or it has a child node depending on a marginalizable variable AND the | ||||||
|    *    subtree rooted at that child contains non-marginalizables. |    *    subtree rooted at that child contains non-marginalizables. | ||||||
|    *  |    *  | ||||||
|    * In addition, the subtrees under the aforementioned cliques that require |    * In addition, for any descendant node depending on a marginalizable | ||||||
|    * re-elimination, which contain non-marginalizable variables in their root |    * variable, if the subtree rooted at that descendant contains | ||||||
|    * node, also need to be re-eliminated. |    * non-marginalizable variables (i.e., it lies on a path from one of the | ||||||
|  |    * aforementioned cliques that require re-elimination to a node containing | ||||||
|  |    * non-marginalizable variables at the leaf side), then it also needs to | ||||||
|  |    * be re-eliminated. | ||||||
|    *  |    *  | ||||||
|    * @param[in] bayesTree The Bayes tree |    * @param[in] bayesTree The Bayes tree | ||||||
|    * @param[in] marginalizableKeys Keys to be marginalized |    * @param[in] marginalizableKeys Keys to be marginalized | ||||||
|  | @ -66,7 +70,7 @@ public: | ||||||
|     std::set<Key> additionalKeys; |     std::set<Key> additionalKeys; | ||||||
|     std::set<Key> marginalizableKeySet( |     std::set<Key> marginalizableKeySet( | ||||||
|         marginalizableKeys.begin(), marginalizableKeys.end()); |         marginalizableKeys.begin(), marginalizableKeys.end()); | ||||||
|     std::set<sharedClique> dependentSubtrees; |     std::set<sharedClique> dependentCliques; | ||||||
|     CachedSearch cachedSearch; |     CachedSearch cachedSearch; | ||||||
| 
 | 
 | ||||||
|     // Check each clique that contains a marginalizable key
 |     // Check each clique that contains a marginalizable key
 | ||||||
|  | @ -77,17 +81,14 @@ public: | ||||||
|         // Add frontal variables from current clique
 |         // Add frontal variables from current clique
 | ||||||
|         addCliqueToKeySet(clique, &additionalKeys); |         addCliqueToKeySet(clique, &additionalKeys); | ||||||
| 
 | 
 | ||||||
|         // Then gather dependent subtrees to be added later
 |         // Then add the dependent cliques
 | ||||||
|         gatherDependentSubtrees( |         for (const sharedClique& dependent : | ||||||
|             clique, marginalizableKeySet, &dependentSubtrees, &cachedSearch); |              gatherDependentCliques(clique, marginalizableKeySet, &cachedSearch)) { | ||||||
|  |           addCliqueToKeySet(dependent, &additionalKeys); | ||||||
|  |         } | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // Add the remaining dependent cliques
 |  | ||||||
|     for (const sharedClique& subtree : dependentSubtrees) { |  | ||||||
|       addSubtreeToKeySet(subtree, &additionalKeys); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     if (debug) { |     if (debug) { | ||||||
|       std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; |       std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; | ||||||
|       for (const Key& key : additionalKeys) { |       for (const Key& key : additionalKeys) { | ||||||
|  | @ -219,53 +220,53 @@ public: | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * Gather all subtrees that depend on a marginalizable key and contain |    * Gather all dependent nodes that lie on a path from the root clique | ||||||
|    * non-marginalizable variables in their root. |    * to a clique containing a non-marginalizable variable at the leaf side. | ||||||
|    * |    * | ||||||
|    * @param[in] rootClique The starting clique |    * @param[in] rootClique The root clique | ||||||
|    * @param[in] marginalizableKeys Set of keys to be marginalized |    * @param[in] marginalizableKeys Set of keys to be marginalized | ||||||
|    * @param[out] dependentSubtrees Pointer to set storing dependent cliques |  | ||||||
|    */ |    */ | ||||||
|   static void gatherDependentSubtrees( |   static std::set<sharedClique> gatherDependentCliques( | ||||||
|       const sharedClique& rootClique, |       const sharedClique& rootClique, | ||||||
|       const std::set<Key>& marginalizableKeys, |       const std::set<Key>& marginalizableKeys, | ||||||
|       std::set<sharedClique>* dependentSubtrees, |  | ||||||
|       CachedSearch* cache) { |       CachedSearch* cache) { | ||||||
|     for (Key key : rootClique->conditional()->frontals()) { |     std::vector<sharedClique> dependentChildren; | ||||||
|       if (marginalizableKeys.count(key)) { |     dependentChildren.reserve(rootClique->children.size()); | ||||||
|         // Find children that depend on this key
 |     for (const sharedClique& child : rootClique->children) { | ||||||
|         for (const sharedClique& child : rootClique->children) { |       if (hasDependency(child, marginalizableKeys)) { | ||||||
|           if (!dependentSubtrees->count(child) && |         dependentChildren.push_back(child); | ||||||
|               hasDependency(child, key)) { |  | ||||||
|             getSubtreesContainingNonMarginalizables( |  | ||||||
|                 child, marginalizableKeys, cache, dependentSubtrees); |  | ||||||
|           } |  | ||||||
|         } |  | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|  |     return gatherDependentCliquesFromChildren(dependentChildren, marginalizableKeys, cache); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|    * Gather all subtrees that contain non-marginalizable variables in its root. |    * A helper function for the above gatherDependentCliques(). | ||||||
|    */ |    */ | ||||||
|   static void getSubtreesContainingNonMarginalizables( |   static std::set<sharedClique> gatherDependentCliquesFromChildren( | ||||||
|       const sharedClique& rootClique, |       const std::vector<sharedClique>& dependentChildren, | ||||||
|       const std::set<Key>& marginalizableKeys, |       const std::set<Key>& marginalizableKeys, | ||||||
|       CachedSearch* cache, |       CachedSearch* cache) { | ||||||
|       std::set<sharedClique>* subtreesContainingNonMarginalizables) { |     std::deque<sharedClique> descendants( | ||||||
|     // If the root clique itself contains non-marginalizable variables, we
 |         dependentChildren.begin(), dependentChildren.end()); | ||||||
|     // just add it to subtreesContainingNonMarginalizables;    
 |     std::set<sharedClique> dependentCliques; | ||||||
|     if (!isWholeCliqueMarginalizable(rootClique, marginalizableKeys, cache)) { |     while (!descendants.empty()) { | ||||||
|       subtreesContainingNonMarginalizables->insert(rootClique); |       sharedClique descendant = descendants.front(); | ||||||
|       return; |       descendants.pop_front(); | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     // Otherwise, we need to recursively check the children
 |       // If the subtree rooted at this descendant contains non-marginalizables,
 | ||||||
|     for (const sharedClique& child : rootClique->children) { |       // it must lie on a path from the root clique to a clique containing
 | ||||||
|       getSubtreesContainingNonMarginalizables( |       // non-marginalizables at the leaf side.
 | ||||||
|           child, marginalizableKeys, cache, |       if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { | ||||||
|           subtreesContainingNonMarginalizables); |         dependentCliques.insert(descendant); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       // Add all children of the current descendant to the set descendants.
 | ||||||
|  |       for (const sharedClique& child : descendant->children) { | ||||||
|  |         descendants.push_back(child); | ||||||
|  |       } | ||||||
|     } |     } | ||||||
|  |     return dependentCliques; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /**
 |   /**
 | ||||||
|  | @ -282,28 +283,6 @@ public: | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /**
 |  | ||||||
|    * Add all frontal variables from a subtree to a key set. |  | ||||||
|    * |  | ||||||
|    * @param[in] subRoot Root clique of the subtree |  | ||||||
|    * @param[out] additionalKeys Pointer to the output key set |  | ||||||
|    */ |  | ||||||
|   static void addSubtreeToKeySet( |  | ||||||
|       const sharedClique& subRoot, |  | ||||||
|       std::set<Key>* additionalKeys) { |  | ||||||
|     std::set<sharedClique> cliques; |  | ||||||
|     cliques.insert(subRoot); |  | ||||||
|     while(!cliques.empty()) { |  | ||||||
|       auto begin = cliques.begin(); |  | ||||||
|       sharedClique clique = *begin; |  | ||||||
|       cliques.erase(begin); |  | ||||||
|       addCliqueToKeySet(clique, additionalKeys); |  | ||||||
|       for (const sharedClique& child : clique->children) { |  | ||||||
|         cliques.insert(child); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /**
 |   /**
 | ||||||
|    * Check if the clique depends on the given key. |    * Check if the clique depends on the given key. | ||||||
|    *  |    *  | ||||||
|  | @ -322,6 +301,19 @@ public: | ||||||
|       return false; |       return false; | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
|  | 
 | ||||||
|  |   /**
 | ||||||
|  |    * Check if the clique depends on any of the given keys. | ||||||
|  |    */ | ||||||
|  |   static bool hasDependency( | ||||||
|  |       const sharedClique& clique, const std::set<Key>& keys) { | ||||||
|  |     for (Key key : keys) { | ||||||
|  |       if (hasDependency(clique, key)) { | ||||||
|  |         return true; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
| }; | }; | ||||||
| // BayesTreeMarginalizationHelper
 | // BayesTreeMarginalizationHelper
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue