rename variables to be agnostic to underlying data structure
parent
4de2d46012
commit
9531506492
|
@ -228,19 +228,19 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
|
|||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
*
|
||||
* @param decisionTree The probability decision tree of only discrete keys.
|
||||
* @param discreteProbs The probabilities of only discrete keys.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
*/
|
||||
std::function<GaussianConditional::shared_ptr(
|
||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
||||
GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
|
||||
// Get the discrete keys as sets for the decision tree
|
||||
// and the gaussian mixture.
|
||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
||||
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||
|
||||
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet](
|
||||
auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
|
||||
const Assignment<Key> &choices,
|
||||
const GaussianConditional::shared_ptr &conditional)
|
||||
-> GaussianConditional::shared_ptr {
|
||||
|
@ -249,8 +249,8 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
|||
|
||||
// Case where the gaussian mixture has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (gaussianMixtureKeySet == decisionTreeKeySet) {
|
||||
if (decisionTree(values) == 0.0) {
|
||||
if (gaussianMixtureKeySet == discreteProbsKeySet) {
|
||||
if (discreteProbs(values) == 0.0) {
|
||||
// empty aka null pointer
|
||||
std::shared_ptr<GaussianConditional> null;
|
||||
return null;
|
||||
|
@ -259,10 +259,10 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
|||
}
|
||||
} else {
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
||||
gaussianMixtureKeySet.begin(),
|
||||
gaussianMixtureKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
std::set_difference(
|
||||
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
|
||||
gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
|
||||
const std::vector<DiscreteValues> assignments =
|
||||
DiscreteValues::CartesianProduct(set_diff);
|
||||
|
@ -272,7 +272,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
|||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this conditional.
|
||||
if (decisionTree(augmented_values) > 0.0) {
|
||||
if (discreteProbs(augmented_values) > 0.0) {
|
||||
return conditional;
|
||||
}
|
||||
}
|
||||
|
@ -285,12 +285,12 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) {
|
||||
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
|
||||
void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
|
||||
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
|
||||
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
|
||||
// Functional which loops over all assignments and create a set of
|
||||
// GaussianConditionals
|
||||
auto pruner = prunerFunc(decisionTree);
|
||||
auto pruner = prunerFunc(discreteProbs);
|
||||
|
||||
auto pruned_conditionals = conditionals_.apply(pruner);
|
||||
conditionals_.root_ = pruned_conditionals.root_;
|
||||
|
|
|
@ -74,13 +74,13 @@ class GTSAM_EXPORT GaussianMixture
|
|||
/**
|
||||
* @brief Helper function to get the pruner functor.
|
||||
*
|
||||
* @param decisionTree The pruned discrete probability decision tree.
|
||||
* @param discreteProbs The pruned discrete probabilities.
|
||||
* @return std::function<GaussianConditional::shared_ptr(
|
||||
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
*/
|
||||
std::function<GaussianConditional::shared_ptr(
|
||||
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
|
||||
prunerFunc(const DecisionTreeFactor &decisionTree);
|
||||
prunerFunc(const DecisionTreeFactor &discreteProbs);
|
||||
|
||||
public:
|
||||
/// @name Constructors
|
||||
|
@ -234,12 +234,11 @@ class GTSAM_EXPORT GaussianMixture
|
|||
|
||||
/**
|
||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||
* `decisionTree`.
|
||||
* `discreteProbs`.
|
||||
*
|
||||
* @param decisionTree A pruned decision tree of discrete keys where the
|
||||
* leaves are probabilities.
|
||||
* @param discreteProbs A pruned set of probabilities for the discrete keys.
|
||||
*/
|
||||
void prune(const DecisionTreeFactor &decisionTree);
|
||||
void prune(const DecisionTreeFactor &discreteProbs);
|
||||
|
||||
/**
|
||||
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
|
||||
|
|
|
@ -39,41 +39,41 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
||||
AlgebraicDecisionTree<Key> decisionTree;
|
||||
AlgebraicDecisionTree<Key> discreteProbs;
|
||||
|
||||
// The canonical decision tree factor which will get
|
||||
// the discrete conditionals added to it.
|
||||
DecisionTreeFactor dtFactor;
|
||||
DecisionTreeFactor discreteProbsFactor;
|
||||
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
// Convert to a DecisionTreeFactor and add it to the main factor.
|
||||
DecisionTreeFactor f(*conditional->asDiscrete());
|
||||
dtFactor = dtFactor * f;
|
||||
discreteProbsFactor = discreteProbsFactor * f;
|
||||
}
|
||||
}
|
||||
return std::make_shared<DecisionTreeFactor>(dtFactor);
|
||||
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/**
|
||||
* @brief Helper function to get the pruner functional.
|
||||
*
|
||||
* @param prunedDecisionTree The prob. decision tree of only discrete keys.
|
||||
* @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
|
||||
* @param conditional Conditional to prune. Used to get full assignment.
|
||||
* @return std::function<double(const Assignment<Key> &, double)>
|
||||
*/
|
||||
std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
||||
const DecisionTreeFactor &prunedDecisionTree,
|
||||
const DecisionTreeFactor &prunedDiscreteProbs,
|
||||
const HybridConditional &conditional) {
|
||||
// Get the discrete keys as sets for the decision tree
|
||||
// and the Gaussian mixture.
|
||||
std::set<DiscreteKey> decisionTreeKeySet =
|
||||
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
|
||||
std::set<DiscreteKey> discreteProbsKeySet =
|
||||
DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
|
||||
std::set<DiscreteKey> conditionalKeySet =
|
||||
DiscreteKeysAsSet(conditional.discreteKeys());
|
||||
|
||||
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
|
||||
auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
|
||||
const Assignment<Key> &choices,
|
||||
double probability) -> double {
|
||||
// This corresponds to 0 probability
|
||||
|
@ -83,8 +83,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
DiscreteValues values(choices);
|
||||
// Case where the Gaussian mixture has the same
|
||||
// discrete keys as the decision tree.
|
||||
if (conditionalKeySet == decisionTreeKeySet) {
|
||||
if (prunedDecisionTree(values) == 0) {
|
||||
if (conditionalKeySet == discreteProbsKeySet) {
|
||||
if (prunedDiscreteProbs(values) == 0) {
|
||||
return pruned_prob;
|
||||
} else {
|
||||
return probability;
|
||||
|
@ -114,11 +114,12 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
}
|
||||
|
||||
// Now we generate the full assignment by enumerating
|
||||
// over all keys in the prunedDecisionTree.
|
||||
// over all keys in the prunedDiscreteProbs.
|
||||
// First we find the differing keys
|
||||
std::vector<DiscreteKey> set_diff;
|
||||
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
|
||||
conditionalKeySet.begin(), conditionalKeySet.end(),
|
||||
std::set_difference(discreteProbsKeySet.begin(),
|
||||
discreteProbsKeySet.end(), conditionalKeySet.begin(),
|
||||
conditionalKeySet.end(),
|
||||
std::back_inserter(set_diff));
|
||||
|
||||
// Now enumerate over all assignments of the differing keys
|
||||
|
@ -130,7 +131,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
|
||||
// If any one of the sub-branches are non-zero,
|
||||
// we need this probability.
|
||||
if (prunedDecisionTree(augmented_values) > 0.0) {
|
||||
if (prunedDiscreteProbs(augmented_values) > 0.0) {
|
||||
return probability;
|
||||
}
|
||||
}
|
||||
|
@ -144,8 +145,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
|
|||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesNet::updateDiscreteConditionals(
|
||||
const DecisionTreeFactor &prunedDecisionTree) {
|
||||
KeyVector prunedTreeKeys = prunedDecisionTree.keys();
|
||||
const DecisionTreeFactor &prunedDiscreteProbs) {
|
||||
KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
|
||||
|
||||
// Loop with index since we need it later.
|
||||
for (size_t i = 0; i < this->size(); i++) {
|
||||
|
@ -157,7 +158,7 @@ void HybridBayesNet::updateDiscreteConditionals(
|
|||
auto discreteTree =
|
||||
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
|
||||
DecisionTreeFactor::ADT prunedDiscreteTree =
|
||||
discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional));
|
||||
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
|
||||
|
||||
// Create the new (hybrid) conditional
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
|
@ -175,10 +176,12 @@ void HybridBayesNet::updateDiscreteConditionals(
|
|||
/* ************************************************************************* */
|
||||
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
||||
// Get the decision tree of only the discrete keys
|
||||
auto discreteConditionals = this->discreteConditionals();
|
||||
const auto decisionTree = discreteConditionals->prune(maxNrLeaves);
|
||||
DecisionTreeFactor::shared_ptr discreteConditionals =
|
||||
this->discreteConditionals();
|
||||
const DecisionTreeFactor prunedDiscreteProbs =
|
||||
discreteConditionals->prune(maxNrLeaves);
|
||||
|
||||
this->updateDiscreteConditionals(decisionTree);
|
||||
this->updateDiscreteConditionals(prunedDiscreteProbs);
|
||||
|
||||
/* To Prune, we visitWith every leaf in the GaussianMixture.
|
||||
* For each leaf, using the assignment we can check the discrete decision tree
|
||||
|
@ -190,12 +193,12 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
|||
HybridBayesNet prunedBayesNetFragment;
|
||||
|
||||
// Go through all the conditionals in the
|
||||
// Bayes Net and prune them as per decisionTree.
|
||||
// Bayes Net and prune them as per prunedDiscreteProbs.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// Make a copy of the Gaussian mixture and prune it!
|
||||
auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
|
||||
prunedGaussianMixture->prune(decisionTree); // imperative :-(
|
||||
prunedGaussianMixture->prune(prunedDiscreteProbs); // imperative :-(
|
||||
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
prunedBayesNetFragment.push_back(prunedGaussianMixture);
|
||||
|
|
|
@ -224,9 +224,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
|||
/**
|
||||
* @brief Update the discrete conditionals with the pruned versions.
|
||||
*
|
||||
* @param prunedDecisionTree
|
||||
* @param prunedDiscreteProbs
|
||||
*/
|
||||
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree);
|
||||
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
|
||||
|
||||
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||
/** Serialization function */
|
||||
|
|
|
@ -173,19 +173,18 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||
auto decisionTree =
|
||||
this->roots_.at(0)->conditional()->asDiscrete();
|
||||
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
|
||||
|
||||
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDecisionTree.root_;
|
||||
DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
|
||||
discreteProbs->root_ = prunedDiscreteProbs.root_;
|
||||
|
||||
/// Helper struct for pruning the hybrid bayes tree.
|
||||
struct HybridPrunerData {
|
||||
/// The discrete decision tree after pruning.
|
||||
DecisionTreeFactor prunedDecisionTree;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree,
|
||||
DecisionTreeFactor prunedDiscreteProbs;
|
||||
HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
|
||||
const HybridBayesTree::sharedNode& parentClique)
|
||||
: prunedDecisionTree(prunedDecisionTree) {}
|
||||
: prunedDiscreteProbs(prunedDiscreteProbs) {}
|
||||
|
||||
/**
|
||||
* @brief A function used during tree traversal that operates on each node
|
||||
|
@ -205,13 +204,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
|||
if (conditional->isHybrid()) {
|
||||
auto gaussianMixture = conditional->asMixture();
|
||||
|
||||
gaussianMixture->prune(parentData.prunedDecisionTree);
|
||||
gaussianMixture->prune(parentData.prunedDiscreteProbs);
|
||||
}
|
||||
return parentData;
|
||||
}
|
||||
};
|
||||
|
||||
HybridPrunerData rootData(prunedDecisionTree, 0);
|
||||
HybridPrunerData rootData(prunedDiscreteProbs, 0);
|
||||
{
|
||||
treeTraversal::no_op visitorPost;
|
||||
// Limits OpenMP threads since we're mixing TBB and OpenMP
|
||||
|
|
|
@ -190,7 +190,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
/* ************************************************************************ */
|
||||
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
||||
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
||||
// otherwise create a GFG with a single (null) factor, which doesn't register as null.
|
||||
// otherwise create a GFG with a single (null) factor,
|
||||
// which doesn't register as null.
|
||||
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
||||
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
|
||||
bool hasNull =
|
||||
|
@ -246,10 +247,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
// Perform elimination!
|
||||
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
|
||||
|
||||
#ifdef HYBRID_TIMING
|
||||
tictoc_print_();
|
||||
#endif
|
||||
|
||||
// Separate out decision tree into conditionals and remaining factors.
|
||||
const auto [conditionals, newFactors] = unzip(eliminationResults);
|
||||
|
||||
|
|
|
@ -112,8 +112,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
public:
|
||||
using Base = HybridFactorGraph;
|
||||
using This = HybridGaussianFactorGraph; ///< this class
|
||||
using BaseEliminateable =
|
||||
EliminateableFactorGraph<This>; ///< for elimination
|
||||
///< for elimination
|
||||
using BaseEliminateable = EliminateableFactorGraph<This>;
|
||||
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This
|
||||
|
||||
using Values = gtsam::Values; ///< backwards compatibility
|
||||
|
@ -148,7 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
/// @name Standard Interface
|
||||
/// @{
|
||||
|
||||
using Base::error; // Expose error(const HybridValues&) method..
|
||||
/// Expose error(const HybridValues&) method.
|
||||
using Base::error;
|
||||
|
||||
/**
|
||||
* @brief Compute error for each discrete assignment,
|
||||
|
|
Loading…
Reference in New Issue