HybridBayesNet and HybridBayesTree both use similar pruning functions
parent
2225ecf442
commit
5e99cd7095
|
|
@ -22,14 +22,6 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************* */
|
||||
/// Return the DiscreteKey vector as a set.
|
||||
static std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &dkeys) {
|
||||
std::set<DiscreteKey> s;
|
||||
s.insert(dkeys.begin(), dkeys.end());
|
||||
return s;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||
AlgebraicDecisionTree<Key> decisionTree;
|
||||
|
|
@ -49,63 +41,6 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
|||
return boost::make_shared<DecisionTreeFactor>(dtFactor);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
*
|
||||
* @param probDecisionTree The probability decision tree of only discrete keys.
|
||||
* @param discreteFactorKeySet Set of DiscreteKeys in probDecisionTree.
|
||||
* Pre-computed for efficiency.
|
||||
* @param gaussianMixtureKeySet Set of DiscreteKeys in the GaussianMixture.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
*/
|
||||
std::function<GaussianConditional::shared_ptr(
|
||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
PrunerFunc(const DecisionTreeFactor::shared_ptr &probDecisionTree,
|
||||
const std::set<DiscreteKey> &discreteFactorKeySet,
|
||||
const std::set<DiscreteKey> &gaussianMixtureKeySet) {
|
||||
auto pruner = [&](const Assignment<Key> &choices,
|
||||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
// typecast so we can use this to get probability value
|
||||
DiscreteValues values(choices);
|
||||
|
||||
// Case where the gaussian mixture has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (gaussianMixtureKeySet == discreteFactorKeySet) {
|
||||
if ((*probDecisionTree)(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
boost::shared_ptr<GaussianConditional> null;
|
||||
return null;
|
||||
} else {
|
||||
return conditional;
|
||||
}
|
||||
} else {
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(
|
||||
discreteFactorKeySet.begin(), discreteFactorKeySet.end(),
|
||||
gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
for (const DiscreteValues &assignment : assignments) {
|
||||
DiscreteValues augmented_values(values);
|
||||
augmented_values.insert(assignment.begin(), assignment.end());
|
||||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this conditional.
|
||||
if ((*probDecisionTree)(augmented_values) > 0.0) {
|
||||
return conditional;
|
||||
}
|
||||
}
|
||||
// If we are here, it means that all the sub-branches are 0,
|
||||
// so we prune.
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
return pruner;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||
// Get the decision tree of only the discrete keys
|
||||
|
|
@ -114,8 +49,6 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
|||
boost::make_shared<DecisionTreeFactor>(
|
||||
discreteConditionals->prune(maxNrLeaves));
|
||||
|
||||
auto discreteFactorKeySet = DiscreteKeysAsSet(discreteFactor->discreteKeys());
|
||||
|
||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
* for 0.0 probability, then just set the leaf to a nullptr.
|
||||
|
|
@ -130,35 +63,13 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
|||
for (size_t i = 0; i < this->size(); i++) {
|
||||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
|
||||
GaussianMixture::shared_ptr gaussianMixture =
|
||||
boost::dynamic_pointer_cast<GaussianMixture>(conditional->inner());
|
||||
if (conditional->isHybrid()) {
|
||||
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
|
||||
|
||||
if (gaussianMixture) {
|
||||
// We may have mixtures with less discrete keys than discreteFactor so
|
||||
// we skip those since the label assignment does not exist.
|
||||
auto gmKeySet = DiscreteKeysAsSet(gaussianMixture->discreteKeys());
|
||||
|
||||
// Get the pruner function.
|
||||
auto pruner = PrunerFunc(discreteFactor, discreteFactorKeySet, gmKeySet);
|
||||
|
||||
// Run the pruning to get a new, pruned tree
|
||||
GaussianMixture::Conditionals prunedTree =
|
||||
gaussianMixture->conditionals().apply(pruner);
|
||||
|
||||
DiscreteKeys discreteKeys = gaussianMixture->discreteKeys();
|
||||
// reverse keys to get a natural ordering
|
||||
std::reverse(discreteKeys.begin(), discreteKeys.end());
|
||||
|
||||
// Convert from boost::iterator_range to KeyVector
|
||||
// so we can pass it to constructor.
|
||||
KeyVector frontals(gaussianMixture->frontals().begin(),
|
||||
gaussianMixture->frontals().end()),
|
||||
parents(gaussianMixture->parents().begin(),
|
||||
gaussianMixture->parents().end());
|
||||
|
||||
// Create the new gaussian mixture and add it to the bayes net.
|
||||
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(
|
||||
frontals, parents, discreteKeys, prunedTree);
|
||||
// Make a copy of the gaussian mixture and prune it!
|
||||
auto prunedGaussianMixture =
|
||||
boost::make_shared<GaussianMixture>(*gaussianMixture);
|
||||
prunedGaussianMixture->prune(*discreteFactor);
|
||||
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
prunedBayesNetFragment.push_back(
|
||||
|
|
|
|||
|
|
@ -149,16 +149,16 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
|||
auto decisionTree = boost::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
this->roots_.at(0)->conditional()->inner());
|
||||
|
||||
DecisionTreeFactor prunedDiscreteFactor = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDiscreteFactor.root_;
|
||||
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDecisionTree.root_;
|
||||
|
||||
/// Helper struct for pruning the hybrid bayes tree.
|
||||
struct HybridPrunerData {
|
||||
/// The discrete decision tree after pruning.
|
||||
DecisionTreeFactor prunedDiscreteFactor;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteFactor,
|
||||
DecisionTreeFactor prunedDecisionTree;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
|
||||
const HybridBayesTree::sharedNode& parentClique)
|
||||
: prunedDiscreteFactor(prunedDiscreteFactor) {}
|
||||
: prunedDecisionTree(prunedDecisionTree) {}
|
||||
|
||||
/**
|
||||
* @brief A function used during tree traversal that operates on each node
|
||||
|
|
@ -178,19 +178,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
|||
if (conditional->isHybrid()) {
|
||||
auto gaussianMixture = conditional->asMixture();
|
||||
|
||||
// Check if the number of discrete keys match,
|
||||
// else we get an assignment error.
|
||||
// TODO(Varun) Update prune method to handle assignment subset?
|
||||
if (gaussianMixture->discreteKeys() ==
|
||||
parentData.prunedDiscreteFactor.discreteKeys()) {
|
||||
gaussianMixture->prune(parentData.prunedDiscreteFactor);
|
||||
}
|
||||
gaussianMixture->prune(parentData.prunedDecisionTree);
|
||||
}
|
||||
return parentData;
|
||||
}
|
||||
};
|
||||
|
||||
HybridPrunerData rootData(prunedDiscreteFactor, 0);
|
||||
HybridPrunerData rootData(prunedDecisionTree, 0);
|
||||
{
|
||||
treeTraversal::no_op visitorPost;
|
||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ TEST(HybridGaussianElimination, Incremental_approximate) {
|
|||
EXPECT_LONGS_EQUAL(
|
||||
2, incrementalHybrid[X(1)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
4, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
|
||||
3, incrementalHybrid[X(2)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
5, incrementalHybrid[X(3)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
|
|
|
|||
|
|
@ -363,7 +363,7 @@ TEST(HybridNonlinearISAM, Incremental_approximate) {
|
|||
EXPECT_LONGS_EQUAL(
|
||||
2, bayesTree[X(1)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
4, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
|
||||
3, bayesTree[X(2)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
5, bayesTree[X(3)]->conditional()->asMixture()->nrComponents());
|
||||
EXPECT_LONGS_EQUAL(
|
||||
|
|
|
|||
Loading…
Reference in New Issue