Change threshold name

release/4.3a0
Frank Dellaert 2025-01-30 08:57:04 -05:00
parent 3c10913c70
commit 9bae03a6fa
6 changed files with 26 additions and 22 deletions

View File

@ -75,7 +75,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
DiscreteBayesNet DiscreteBayesNet::prune(
size_t maxNrLeaves, const std::optional<double>& deadModeThreshold,
size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
DiscreteValues* fixedValues) const {
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
@ -89,13 +89,13 @@ DiscreteBayesNet DiscreteBayesNet::prune(
DiscreteValues deadModesValues;
// If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal.
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
if (marginalThreshold.has_value() && pruned.keys().size() > 0) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
const Vector probabilities = marginals.marginalProbabilities(dkey);
int index = -1;
auto threshold = (probabilities.array() > *deadModeThreshold);
auto threshold = (probabilities.array() > *marginalThreshold);
// If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) {

View File

@ -128,12 +128,12 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
* @brief Prune the Bayes net
*
* @param maxNrLeaves The maximum number of leaves to keep.
* @param deadModeThreshold If given, threshold on marginals to prune variables.
* @param marginalThreshold If given, threshold on marginals to prune variables.
* @param fixedValues If given, return the fixed values removed.
* @return A new DiscreteBayesNet with pruned conditionals.
*/
DiscreteBayesNet prune(size_t maxNrLeaves,
const std::optional<double>& deadModeThreshold = {},
const std::optional<double>& marginalThreshold = {},
DiscreteValues* fixedValues = nullptr) const;
///@}

View File

@ -43,7 +43,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
DiscreteValues *fixedValues) const {
#if GTSAM_HYBRID_TIMING
gttic_(HybridPruning);
#endif
@ -52,14 +53,14 @@ HybridBayesNet HybridBayesNet::prune(
// Prune discrete Bayes net
DiscreteValues fixed;
auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed);
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional pruned;
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
// Set the fixed values if requested.
if (deadModeThreshold && fixedValues) {
if (marginalThreshold && fixedValues) {
*fixedValues = fixed;
}
@ -71,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune(
if (conditional->isDiscrete()) continue;
// No-op if not a HybridGaussianConditional.
if (deadModeThreshold) conditional = conditional->restrict(fixed);
if (marginalThreshold) conditional = conditional->restrict(fixed);
// Now decide on type what to do:
if (auto hgc = conditional->asHybrid()) {

View File

@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
*
* @param maxNrLeaves Continuous values at which to compute the error.
* @param deadModeThreshold The threshold to check the mode marginals against.
* If greater than this threshold, the mode gets assigned that value and is
* considered "dead" for hybrid elimination.
* The mode can then be removed since it only has a single possible
* assignment.
* @param marginalThreshold The threshold to check the mode marginals against.
* @param fixedValues The fixed values resulting from dead mode removal.
*
* @note If marginal greater than this threshold, the mode gets assigned that
* value and is considered "dead" for hybrid elimination. The mode can then be
* removed since it only has a single possible assignment.
* @return A pruned HybridBayesNet
*/
HybridBayesNet prune(
size_t maxNrLeaves,
const std::optional<double> &deadModeThreshold = {}) const;
HybridBayesNet prune(size_t maxNrLeaves,
const std::optional<double> &marginalThreshold = {},
DiscreteValues *fixedValues = nullptr) const;
/**
* @brief Error method using HybridValues which returns specific error for

View File

@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
if (maxNrLeaves) {
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
// all the conditionals with the same keys in bayesNetFragment.
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_);
}
// Add the partial bayes net to the posterior bayes net.

View File

@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother {
HybridGaussianFactorGraph remainingFactorGraph_;
/// The threshold above which we make a decision about a mode.
std::optional<double> deadModeThreshold_;
std::optional<double> marginalThreshold_;
DiscreteValues fixedValues_;
public:
/**
* @brief Constructor
*
* @param removeDeadModes Flag indicating whether to remove dead modes.
* @param deadModeThreshold The threshold above which a mode gets assigned a
* @param marginalThreshold The threshold above which a mode gets assigned a
* value and is considered "dead". 0.99 is a good starting value.
*/
HybridSmoother(const std::optional<double> deadModeThreshold = {})
: deadModeThreshold_(deadModeThreshold) {}
HybridSmoother(const std::optional<double> marginalThreshold = {})
: marginalThreshold_(marginalThreshold) {}
/**
* Given new factors, perform an incremental update.