rename variables to be agnostic to underlying data structure

release/4.3a0
Varun Agrawal 2023-06-28 16:13:54 -04:00
parent 4de2d46012
commit 9531506492
7 changed files with 61 additions and 62 deletions

View File

@ -228,19 +228,19 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
/** /**
* @brief Helper function to get the pruner functional. * @brief Helper function to get the pruner functional.
* *
* @param decisionTree The probability decision tree of only discrete keys. * @param discreteProbs The probabilities of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr( * @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)> * const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/ */
std::function<GaussianConditional::shared_ptr( std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)> const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) { GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
// Get the discrete keys as sets for the decision tree // Get the discrete keys as sets for the decision tree
// and the gaussian mixture. // and the gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys()); auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
auto pruner = [decisionTree, decisionTreeKeySet, gaussianMixtureKeySet]( auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
const Assignment<Key> &choices, const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional) const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr { -> GaussianConditional::shared_ptr {
@ -249,8 +249,8 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// Case where the gaussian mixture has the same // Case where the gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
if (gaussianMixtureKeySet == decisionTreeKeySet) { if (gaussianMixtureKeySet == discreteProbsKeySet) {
if (decisionTree(values) == 0.0) { if (discreteProbs(values) == 0.0) {
// empty aka null pointer // empty aka null pointer
std::shared_ptr<GaussianConditional> null; std::shared_ptr<GaussianConditional> null;
return null; return null;
@ -259,10 +259,10 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
} }
} else { } else {
std::vector<DiscreteKey> set_diff; std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), std::set_difference(
gaussianMixtureKeySet.begin(), discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
gaussianMixtureKeySet.end(), gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
std::back_inserter(set_diff)); std::back_inserter(set_diff));
const std::vector<DiscreteValues> assignments = const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff); DiscreteValues::CartesianProduct(set_diff);
@ -272,7 +272,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
// If any one of the sub-branches is non-zero, // If any one of the sub-branches is non-zero,
// we need this conditional. // we need this conditional.
if (decisionTree(augmented_values) > 0.0) { if (discreteProbs(augmented_values) > 0.0) {
return conditional; return conditional;
} }
} }
@ -285,12 +285,12 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &decisionTree) {
} }
/* *******************************************************************************/ /* *******************************************************************************/
void GaussianMixture::prune(const DecisionTreeFactor &decisionTree) { void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys()); auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys()); auto gmKeySet = DiscreteKeysAsSet(this->discreteKeys());
// Functional which loops over all assignments and creates a set of // Functional which loops over all assignments and creates a set of
// GaussianConditionals // GaussianConditionals
auto pruner = prunerFunc(decisionTree); auto pruner = prunerFunc(discreteProbs);
auto pruned_conditionals = conditionals_.apply(pruner); auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_; conditionals_.root_ = pruned_conditionals.root_;

View File

@ -74,13 +74,13 @@ class GTSAM_EXPORT GaussianMixture
/** /**
* @brief Helper function to get the pruner functor. * @brief Helper function to get the pruner functor.
* *
* @param decisionTree The pruned discrete probability decision tree. * @param discreteProbs The pruned discrete probabilities.
* @return std::function<GaussianConditional::shared_ptr( * @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)> * const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/ */
std::function<GaussianConditional::shared_ptr( std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)> const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
prunerFunc(const DecisionTreeFactor &decisionTree); prunerFunc(const DecisionTreeFactor &discreteProbs);
public: public:
/// @name Constructors /// @name Constructors
@ -234,12 +234,11 @@ class GTSAM_EXPORT GaussianMixture
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`. * `discreteProbs`.
* *
* @param decisionTree A pruned decision tree of discrete keys where the * @param discreteProbs A pruned set of probabilities for the discrete keys.
* leaves are probabilities.
*/ */
void prune(const DecisionTreeFactor &decisionTree); void prune(const DecisionTreeFactor &discreteProbs);
/** /**
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while

View File

@ -39,41 +39,41 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
/* ************************************************************************* */ /* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> decisionTree; AlgebraicDecisionTree<Key> discreteProbs;
// The canonical decision tree factor which will get // The canonical decision tree factor which will get
// the discrete conditionals added to it. // the discrete conditionals added to it.
DecisionTreeFactor dtFactor; DecisionTreeFactor discreteProbsFactor;
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor. // Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscrete()); DecisionTreeFactor f(*conditional->asDiscrete());
dtFactor = dtFactor * f; discreteProbsFactor = discreteProbsFactor * f;
} }
} }
return std::make_shared<DecisionTreeFactor>(dtFactor); return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
} }
/* ************************************************************************* */ /* ************************************************************************* */
/** /**
* @brief Helper function to get the pruner functional. * @brief Helper function to get the pruner functional.
* *
* @param prunedDecisionTree The prob. decision tree of only discrete keys. * @param prunedDiscreteProbs The prob. decision tree of only discrete keys.
* @param conditional Conditional to prune. Used to get full assignment. * @param conditional Conditional to prune. Used to get full assignment.
* @return std::function<double(const Assignment<Key> &, double)> * @return std::function<double(const Assignment<Key> &, double)>
*/ */
std::function<double(const Assignment<Key> &, double)> prunerFunc( std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &prunedDecisionTree, const DecisionTreeFactor &prunedDiscreteProbs,
const HybridConditional &conditional) { const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree // Get the discrete keys as sets for the decision tree
// and the Gaussian mixture. // and the Gaussian mixture.
std::set<DiscreteKey> decisionTreeKeySet = std::set<DiscreteKey> discreteProbsKeySet =
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys()); DiscreteKeysAsSet(prunedDiscreteProbs.discreteKeys());
std::set<DiscreteKey> conditionalKeySet = std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys()); DiscreteKeysAsSet(conditional.discreteKeys());
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet]( auto pruner = [prunedDiscreteProbs, discreteProbsKeySet, conditionalKeySet](
const Assignment<Key> &choices, const Assignment<Key> &choices,
double probability) -> double { double probability) -> double {
// This corresponds to 0 probability // This corresponds to 0 probability
@ -83,8 +83,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
DiscreteValues values(choices); DiscreteValues values(choices);
// Case where the Gaussian mixture has the same // Case where the Gaussian mixture has the same
// discrete keys as the decision tree. // discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) { if (conditionalKeySet == discreteProbsKeySet) {
if (prunedDecisionTree(values) == 0) { if (prunedDiscreteProbs(values) == 0) {
return pruned_prob; return pruned_prob;
} else { } else {
return probability; return probability;
@ -114,11 +114,12 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
} }
// Now we generate the full assignment by enumerating // Now we generate the full assignment by enumerating
// over all keys in the prunedDecisionTree. // over all keys in the prunedDiscreteProbs.
// First we find the differing keys // First we find the differing keys
std::vector<DiscreteKey> set_diff; std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(), std::set_difference(discreteProbsKeySet.begin(),
conditionalKeySet.begin(), conditionalKeySet.end(), discreteProbsKeySet.end(), conditionalKeySet.begin(),
conditionalKeySet.end(),
std::back_inserter(set_diff)); std::back_inserter(set_diff));
// Now enumerate over all assignments of the differing keys // Now enumerate over all assignments of the differing keys
@ -130,7 +131,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
// If any one of the sub-branches is non-zero, // If any one of the sub-branches is non-zero,
// we need this probability. // we need this probability.
if (prunedDecisionTree(augmented_values) > 0.0) { if (prunedDiscreteProbs(augmented_values) > 0.0) {
return probability; return probability;
} }
} }
@ -144,8 +145,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals( void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor &prunedDecisionTree) { const DecisionTreeFactor &prunedDiscreteProbs) {
KeyVector prunedTreeKeys = prunedDecisionTree.keys(); KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
// Loop with index since we need it later. // Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) { for (size_t i = 0; i < this->size(); i++) {
@ -157,7 +158,7 @@ void HybridBayesNet::updateDiscreteConditionals(
auto discreteTree = auto discreteTree =
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete); std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
DecisionTreeFactor::ADT prunedDiscreteTree = DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(prunedDecisionTree, *conditional)); discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));
// Create the new (hybrid) conditional // Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(), KeyVector frontals(discrete->frontals().begin(),
@ -175,10 +176,12 @@ void HybridBayesNet::updateDiscreteConditionals(
/* ************************************************************************* */ /* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys // Get the decision tree of only the discrete keys
auto discreteConditionals = this->discreteConditionals(); DecisionTreeFactor::shared_ptr discreteConditionals =
const auto decisionTree = discreteConditionals->prune(maxNrLeaves); this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);
this->updateDiscreteConditionals(decisionTree); this->updateDiscreteConditionals(prunedDiscreteProbs);
/* To Prune, we visitWith every leaf in the GaussianMixture. /* To Prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree * For each leaf, using the assignment we can check the discrete decision tree
@ -190,12 +193,12 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
HybridBayesNet prunedBayesNetFragment; HybridBayesNet prunedBayesNetFragment;
// Go through all the conditionals in the // Go through all the conditionals in the
// Bayes Net and prune them as per decisionTree. // Bayes Net and prune them as per prunedDiscreteProbs.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) { if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it! // Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm); auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
prunedGaussianMixture->prune(decisionTree); // imperative :-( prunedGaussianMixture->prune(prunedDiscreteProbs); // imperative :-(
// Type-erase and add to the pruned Bayes Net fragment. // Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(prunedGaussianMixture); prunedBayesNetFragment.push_back(prunedGaussianMixture);

View File

@ -224,9 +224,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/** /**
* @brief Update the discrete conditionals with the pruned versions. * @brief Update the discrete conditionals with the pruned versions.
* *
* @param prunedDecisionTree * @param prunedDiscreteProbs
*/ */
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDecisionTree); void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */ /** Serialization function */

View File

@ -173,19 +173,18 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) { void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
this->roots_.at(0)->conditional()->asDiscrete();
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); DecisionTreeFactor prunedDiscreteProbs = discreteProbs->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_; discreteProbs->root_ = prunedDiscreteProbs.root_;
/// Helper struct for pruning the hybrid bayes tree. /// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData { struct HybridPrunerData {
/// The discrete decision tree after pruning. /// The discrete decision tree after pruning.
DecisionTreeFactor prunedDecisionTree; DecisionTreeFactor prunedDiscreteProbs;
HybridPrunerData(const DecisionTreeFactor& prunedDecisionTree, HybridPrunerData(const DecisionTreeFactor& prunedDiscreteProbs,
const HybridBayesTree::sharedNode& parentClique) const HybridBayesTree::sharedNode& parentClique)
: prunedDecisionTree(prunedDecisionTree) {} : prunedDiscreteProbs(prunedDiscreteProbs) {}
/** /**
* @brief A function used during tree traversal that operates on each node * @brief A function used during tree traversal that operates on each node
@ -205,13 +204,13 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
if (conditional->isHybrid()) { if (conditional->isHybrid()) {
auto gaussianMixture = conditional->asMixture(); auto gaussianMixture = conditional->asMixture();
gaussianMixture->prune(parentData.prunedDecisionTree); gaussianMixture->prune(parentData.prunedDiscreteProbs);
} }
return parentData; return parentData;
} }
}; };
HybridPrunerData rootData(prunedDecisionTree, 0); HybridPrunerData rootData(prunedDiscreteProbs, 0);
{ {
treeTraversal::no_op visitorPost; treeTraversal::no_op visitorPost;
// Limits OpenMP threads since we're mixing TBB and OpenMP // Limits OpenMP threads since we're mixing TBB and OpenMP

View File

@ -190,7 +190,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
/* ************************************************************************ */ /* ************************************************************************ */
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert // If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will // that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor, which doesn't register as null. // otherwise create a GFG with a single (null) factor,
// which doesn't register as null.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) { GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GaussianFactorGraph &graph) { auto emptyGaussian = [](const GaussianFactorGraph &graph) {
bool hasNull = bool hasNull =
@ -246,10 +247,6 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Perform elimination! // Perform elimination!
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate); DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
#ifdef HYBRID_TIMING
tictoc_print_();
#endif
// Separate out decision tree into conditionals and remaining factors. // Separate out decision tree into conditionals and remaining factors.
const auto [conditionals, newFactors] = unzip(eliminationResults); const auto [conditionals, newFactors] = unzip(eliminationResults);

View File

@ -112,8 +112,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
public: public:
using Base = HybridFactorGraph; using Base = HybridFactorGraph;
using This = HybridGaussianFactorGraph; ///< this class using This = HybridGaussianFactorGraph; ///< this class
using BaseEliminateable = ///< for elimination
EliminateableFactorGraph<This>; ///< for elimination using BaseEliminateable = EliminateableFactorGraph<This>;
using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This using shared_ptr = std::shared_ptr<This>; ///< shared_ptr to This
using Values = gtsam::Values; ///< backwards compatibility using Values = gtsam::Values; ///< backwards compatibility
@ -148,7 +148,8 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
/// @name Standard Interface /// @name Standard Interface
/// @{ /// @{
using Base::error; // Expose error(const HybridValues&) method.. /// Expose error(const HybridValues&) method.
using Base::error;
/** /**
* @brief Compute error for each discrete assignment, * @brief Compute error for each discrete assignment,