HybridBayesNet and HybridBayesTree both use similar pruning functions

release/4.3a0
Varun Agrawal 2022-10-11 12:36:58 -04:00
parent 2225ecf442
commit 5e99cd7095
4 changed files with 15 additions and 110 deletions

View File

@ -22,14 +22,6 @@
namespace gtsam {
/* ************************************************************************* */
/// Convert the DiscreteKey vector into a std::set for set-algebra operations.
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
  // Range-construct directly; equivalent to default-construct + insert.
  return std::set<DiscreteKey>(dkeys.begin(), dkeys.end());
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree;
@ -49,63 +41,6 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
return boost::make_shared<DecisionTreeFactor>(dtFactor);
}
/**
 * @brief Helper function to get the pruner functional.
 *
 * @param probDecisionTree The probability decision tree of only discrete keys.
 * @param discreteFactorKeySet Set of DiscreteKeys in probDecisionTree.
 * Pre-computed for efficiency.
 * @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture.
 * @return std::function<GaussianConditional::shared_ptr(
 * const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
 */
std::function<GaussianConditional::shared_ptr(
    const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree,
           const std::set<DiscreteKey> &discreteFactorKeySet,
           const std::set<DiscreteKey> &gaussianMixtureKeySet) {
  // BUGFIX: capture by VALUE, not [&]. The returned lambda outlives this
  // function's scope, so capturing the reference parameters by reference
  // would leave them dangling and make every later call undefined behavior.
  auto pruner = [probDecisionTree, discreteFactorKeySet, gaussianMixtureKeySet](
                    const Assignment<Key> &choices,
                    const GaussianConditional::shared_ptr &conditional)
      -> GaussianConditional::shared_ptr {
    // Typecast so we can use the assignment to index the decision tree.
    const DiscreteValues values(choices);

    if (gaussianMixtureKeySet == discreteFactorKeySet) {
      // Case where the gaussian mixture has exactly the same discrete keys
      // as the decision tree: index the tree directly. A zero-probability
      // leaf means this branch was pruned away, so return a null pointer.
      return ((*probDecisionTree)(values) == 0.0) ? nullptr : conditional;
    } else {
      // The mixture is missing some of the tree's discrete keys. Enumerate
      // every assignment to the missing keys and keep the conditional if any
      // completed assignment still has non-zero probability.
      std::vector<DiscreteKey> set_diff;
      std::set_difference(
          discreteFactorKeySet.begin(), discreteFactorKeySet.end(),
          gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
          std::back_inserter(set_diff));

      const std::vector<DiscreteValues> assignments =
          DiscreteValues::CartesianProduct(set_diff);
      for (const DiscreteValues &assignment : assignments) {
        DiscreteValues augmented_values(values);
        augmented_values.insert(assignment.begin(), assignment.end());

        // If any one of the sub-branches are non-zero,
        // we need this conditional.
        if ((*probDecisionTree)(augmented_values) > 0.0) {
          return conditional;
        }
      }
      // If we are here, it means that all the sub-branches are 0,
      // so we prune.
      return nullptr;
    }
  };
  return pruner;
}
/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Get the decision tree of only the discrete keys
@ -114,8 +49,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
boost::make_shared<DecisionTreeFactor>(
discreteConditionals->prune(maxNrLeaves));
auto discreteFactorKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
/* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
@ -130,35 +63,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
GaussianMixture::shared_ptr gaussianMixture =
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());
if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
if (gaussianMixture) {
// We may have mixtures with less discrete keys than discreteFactor so
// we skip those since the label assignment does not exist.
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
// Get the pruner function.
auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet);
// Run the pruning to get a new, pruned tree
GaussianMixture::Conditionals prunedTree =
gaussianMixture->conditionals().apply(pruner);
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
// reverse keys to get a natural ordering
std::reverse(discreteKeys.begin(), discreteKeys.end());
// Convert from boost::iterator_range to KeyVector
// so we can pass it to constructor.
KeyVector frontals(gaussianMixture->frontals().begin(),
gaussianMixture->frontals().end()),
parents(gaussianMixture->parents().begin(),
gaussianMixture->parents().end());
// Create the new gaussian mixture and add it to the bayes net.
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
frontals, parents, discreteKeys, prunedTree);
// Make a copy of the gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*discreteFactor);
// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(

View File

@ -149,16 +149,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
this->roots_.at(0)->conditional()->inner());
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDiscreteFactor.root_;
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {
/// The discrete decision tree after pruning.
DecisionTreeFactor prunedDiscreteFactor;
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
DecisionTreeFactor prunedDecisionTree;
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
const HybridBayesTree::sharedNode& parentClique)
: prunedDiscreteFactor(prunedDiscreteFactor) {}
: prunedDecisionTree(prunedDecisionTree) {}
/**
* @brief A function used during tree traversal that operates on each node
@ -178,19 +178,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture();
// Check if the number of discrete keys match,
// else we get an assignment error.
// TODO(Varun) Update prune method to handle assignment subset?
if (gaussianMixture->discreteKeys() ==
parentData.prunedDiscreteFactor.discreteKeys()) {
gaussianMixture->prune(parentData.prunedDiscreteFactor);
}
gaussianMixture->prune(parentData.prunedDecisionTree);
}
return parentData;
}
};
HybridPrunerData rootData(prunedDiscreteFactor, 0);
HybridPrunerData rootData(prunedDecisionTree, 0);
{
treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP

View File

@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
EXPECT_LONGS_EQUAL(
2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(

View File

@ -363,7 +363,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
EXPECT_LONGS_EQUAL(
2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
4, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(
5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents());
EXPECT_LONGS_EQUAL(