From 98cdf1193facd7715cd37de002fb50762a3d2e2a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 29 Jan 2025 17:54:29 -0500 Subject: [PATCH] Fix pruning --- gtsam/hybrid/HybridBayesNet.cpp | 90 +++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 27 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index e64284a94..2efb8030e 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -49,6 +49,9 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { // search to find the K-best leaves and then create a single pruned conditional. HybridBayesNet HybridBayesNet::prune( size_t maxNrLeaves, const std::optional &deadModeThreshold) const { +#if GTSAM_HYBRID_TIMING + gttic_(HybridPruning); +#endif // Collect all the discrete conditionals. Could be small if already pruned. const DiscreteBayesNet marginal = discreteMarginal(); @@ -69,6 +72,10 @@ HybridBayesNet HybridBayesNet::prune( // If we have a dead mode threshold and discrete variables left after pruning, // then we run dead mode removal. if (deadModeThreshold.has_value() && pruned.keys().size() > 0) { +#if GTSAM_HYBRID_TIMING + gttic_(DeadModeRemoval); +#endif + DiscreteMarginals marginals(DiscreteFactorGraph{pruned}); for (auto dkey : pruned.discreteKeys()) { Vector probabilities = marginals.marginalProbabilities(dkey); @@ -89,24 +96,11 @@ HybridBayesNet HybridBayesNet::prune( // Remove the modes (imperative) pruned.removeDiscreteModes(deadModesValues); - /* - If the pruned discrete conditional has any keys left, - we add it to the HybridBayesNet. - If not, it means it is an orphan so we don't add this pruned joint, - and instead add only the marginals below. - */ - if (pruned.keys().size() > 0) { - result.emplace_shared(pruned); - } + GTSAM_PRINT(deadModesValues); - // Add the marginals for future factors - for (auto &&[key, _] : deadModesValues) { - result.push_back( - std::dynamic_pointer_cast(marginals(key))); - } - - } else { - result.emplace_shared(pruned); +#if GTSAM_HYBRID_TIMING + gttoc_(DeadModeRemoval); +#endif } /* To prune, we visitWith every leaf in the HybridGaussianConditional. @@ -122,20 +116,37 @@ HybridBayesNet HybridBayesNet::prune( if (auto hgc = conditional->asHybrid()) { // Prune the hybrid Gaussian conditional! auto prunedHybridGaussianConditional = hgc->prune(pruned); + if (!prunedHybridGaussianConditional) { + GTSAM_PRINT(marginal); + GTSAM_PRINT(pruned); + throw std::runtime_error( + "A HybridGaussianConditional had all its conditionals pruned"); + } if (deadModeThreshold.has_value()) { - KeyVector deadKeys, conditionalDiscreteKeys; - for (const auto &kv : deadModesValues) { - deadKeys.push_back(kv.first); + const auto &discreteParents = + prunedHybridGaussianConditional->discreteKeys(); + DiscreteValues deadParentValues; + DiscreteKeys liveParents; + for (const auto &key : discreteParents) { + auto it = deadModesValues.find(key.first); + if (it != deadModesValues.end()) + deadParentValues[key.first] = it->second; + else + liveParents.emplace_back(key); } - for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) { - conditionalDiscreteKeys.push_back(dkey.first); - } - // The discrete keys in the conditional are the same as the keys in the - // dead modes, then we just get the corresponding Gaussian conditional. - if (deadKeys == conditionalDiscreteKeys) { + // If so then we just get the corresponding Gaussian conditional: + if (deadParentValues.size() == discreteParents.size()) { + // print on how many discreteParents we are choosing: result.push_back( - prunedHybridGaussianConditional->choose(deadModesValues)); + prunedHybridGaussianConditional->choose(deadParentValues)); + } else if (liveParents.size() > 0) { + auto newTree = prunedHybridGaussianConditional->factors(); + for (auto &&[key, value] : deadModesValues) { + newTree = newTree.choose(key, value); + } + result.emplace_shared(liveParents, + newTree); } else { // Add as-is result.push_back(prunedHybridGaussianConditional); @@ -152,6 +163,31 @@ HybridBayesNet HybridBayesNet::prune( // We ignore DiscreteConditional as they are already pruned and added. } +#if GTSAM_HYBRID_TIMING + gttoc_(HybridPruning); +#endif + + if (deadModeThreshold.has_value()) { + /* + If the pruned discrete conditional has any keys left, + we add it to the HybridBayesNet. + If not, it means it is an orphan so we don't add this pruned joint, + and instead add only the marginals below. + */ + if (pruned.keys().size() > 0) { + result.emplace_shared(pruned); + } + + // Add the marginals for future factors + // for (auto &&[key, _] : deadModesValues) { + // result.push_back( + // std::dynamic_pointer_cast(marginals(key))); + // } + + } else { + result.emplace_shared(pruned); + } + return result; }