Change threshold name
parent
3c10913c70
commit
9bae03a6fa
|
@ -75,7 +75,7 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
|
|||
// been pruned before. Another, possibly faster approach is branch and bound
|
||||
// search to find the K-best leaves and then create a single pruned conditional.
|
||||
DiscreteBayesNet DiscreteBayesNet::prune(
|
||||
size_t maxNrLeaves, const std::optional<double>& deadModeThreshold,
|
||||
size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
|
||||
DiscreteValues* fixedValues) const {
|
||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||
DiscreteConditional joint;
|
||||
|
@ -89,13 +89,13 @@ DiscreteBayesNet DiscreteBayesNet::prune(
|
|||
DiscreteValues deadModesValues;
|
||||
// If we have a dead mode threshold and discrete variables left after pruning,
|
||||
// then we run dead mode removal.
|
||||
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
|
||||
if (marginalThreshold.has_value() && pruned.keys().size() > 0) {
|
||||
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
|
||||
for (auto dkey : pruned.discreteKeys()) {
|
||||
const Vector probabilities = marginals.marginalProbabilities(dkey);
|
||||
|
||||
int index = -1;
|
||||
auto threshold = (probabilities.array() > *deadModeThreshold);
|
||||
auto threshold = (probabilities.array() > *marginalThreshold);
|
||||
// If atleast 1 value is non-zero, then we can find the index
|
||||
// Else if all are zero, index would be set to 0 which is incorrect
|
||||
if (!threshold.isZero()) {
|
||||
|
|
|
@ -128,12 +128,12 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
|
|||
* @brief Prune the Bayes net
|
||||
*
|
||||
* @param maxNrLeaves The maximum number of leaves to keep.
|
||||
* @param deadModeThreshold If given, threshold on marginals to prune variables.
|
||||
* @param marginalThreshold If given, threshold on marginals to prune variables.
|
||||
* @param fixedValues If given, return the fixed values removed.
|
||||
* @return A new DiscreteBayesNet with pruned conditionals.
|
||||
*/
|
||||
DiscreteBayesNet prune(size_t maxNrLeaves,
|
||||
const std::optional<double>& deadModeThreshold = {},
|
||||
const std::optional<double>& marginalThreshold = {},
|
||||
DiscreteValues* fixedValues = nullptr) const;
|
||||
|
||||
///@}
|
||||
|
|
|
@ -43,7 +43,8 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(
|
||||
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
|
||||
size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
|
||||
DiscreteValues *fixedValues) const {
|
||||
#if GTSAM_HYBRID_TIMING
|
||||
gttic_(HybridPruning);
|
||||
#endif
|
||||
|
@ -52,14 +53,14 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
|
||||
// Prune discrete Bayes net
|
||||
DiscreteValues fixed;
|
||||
auto prunedBN = marginal.prune(maxNrLeaves, deadModeThreshold, &fixed);
|
||||
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
|
||||
|
||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||
DiscreteConditional pruned;
|
||||
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
|
||||
|
||||
// Set the fixed values if requested.
|
||||
if (deadModeThreshold && fixedValues) {
|
||||
if (marginalThreshold && fixedValues) {
|
||||
*fixedValues = fixed;
|
||||
}
|
||||
|
||||
|
@ -71,7 +72,7 @@ HybridBayesNet HybridBayesNet::prune(
|
|||
if (conditional->isDiscrete()) continue;
|
||||
|
||||
// No-op if not a HybridGaussianConditional.
|
||||
if (deadModeThreshold) conditional = conditional->restrict(fixed);
|
||||
if (marginalThreshold) conditional = conditional->restrict(fixed);
|
||||
|
||||
// Now decide on type what to do:
|
||||
if (auto hgc = conditional->asHybrid()) {
|
||||
|
|
|
@ -217,16 +217,18 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
|
||||
*
|
||||
* @param maxNrLeaves Continuous values at which to compute the error.
|
||||
* @param deadModeThreshold The threshold to check the mode marginals against.
|
||||
* If greater than this threshold, the mode gets assigned that value and is
|
||||
* considered "dead" for hybrid elimination.
|
||||
* The mode can then be removed since it only has a single possible
|
||||
* assignment.
|
||||
* @param marginalThreshold The threshold to check the mode marginals against.
|
||||
* @param fixedValues The fixed values resulting from dead mode removal.
|
||||
*
|
||||
* @note If marginal greater than this threshold, the mode gets assigned that
|
||||
* value and is considered "dead" for hybrid elimination. The mode can then be
|
||||
* removed since it only has a single possible assignment.
|
||||
|
||||
* @return A pruned HybridBayesNet
|
||||
*/
|
||||
HybridBayesNet prune(
|
||||
size_t maxNrLeaves,
|
||||
const std::optional<double> &deadModeThreshold = {}) const;
|
||||
HybridBayesNet prune(size_t maxNrLeaves,
|
||||
const std::optional<double> &marginalThreshold = {},
|
||||
DiscreteValues *fixedValues = nullptr) const;
|
||||
|
||||
/**
|
||||
* @brief Error method using HybridValues which returns specific error for
|
||||
|
|
|
@ -81,7 +81,7 @@ void HybridSmoother::update(const HybridGaussianFactorGraph &graph,
|
|||
if (maxNrLeaves) {
|
||||
// `pruneBayesNet` sets the leaves with 0 in discreteFactor to nullptr in
|
||||
// all the conditionals with the same keys in bayesNetFragment.
|
||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, deadModeThreshold_);
|
||||
bayesNetFragment = bayesNetFragment.prune(*maxNrLeaves, marginalThreshold_);
|
||||
}
|
||||
|
||||
// Add the partial bayes net to the posterior bayes net.
|
||||
|
|
|
@ -30,18 +30,19 @@ class GTSAM_EXPORT HybridSmoother {
|
|||
HybridGaussianFactorGraph remainingFactorGraph_;
|
||||
|
||||
/// The threshold above which we make a decision about a mode.
|
||||
std::optional<double> deadModeThreshold_;
|
||||
std::optional<double> marginalThreshold_;
|
||||
DiscreteValues fixedValues_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Constructor
|
||||
*
|
||||
* @param removeDeadModes Flag indicating whether to remove dead modes.
|
||||
* @param deadModeThreshold The threshold above which a mode gets assigned a
|
||||
* @param marginalThreshold The threshold above which a mode gets assigned a
|
||||
* value and is considered "dead". 0.99 is a good starting value.
|
||||
*/
|
||||
HybridSmoother(const std::optional<double> deadModeThreshold = {})
|
||||
: deadModeThreshold_(deadModeThreshold) {}
|
||||
HybridSmoother(const std::optional<double> marginalThreshold = {})
|
||||
: marginalThreshold_(marginalThreshold) {}
|
||||
|
||||
/**
|
||||
* Given new factors, perform an incremental update.
|
||||
|
|
Loading…
Reference in New Issue