From 1134d1c88e1fb28daa50a3467b7c97408a1ab7b5 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Wed, 28 Dec 2022 13:52:59 -0500 Subject: [PATCH] Refactor with uniform dynamic pointer cast API --- gtsam/hybrid/HybridBayesNet.cpp | 64 ++++++++----------- gtsam/hybrid/HybridBayesTree.cpp | 4 +- gtsam/hybrid/tests/testHybridBayesNet.cpp | 6 +- .../tests/testHybridGaussianFactorGraph.cpp | 2 +- gtsam/hybrid/tests/testHybridGaussianISAM.cpp | 9 ++- .../hybrid/tests/testHybridNonlinearISAM.cpp | 9 ++- 6 files changed, 39 insertions(+), 55 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 6024add07..c598b7d62 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -36,7 +36,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { for (auto &&conditional : *this) { if (conditional->isDiscrete()) { // Convert to a DecisionTreeFactor and add it to the main factor. - DecisionTreeFactor f(*conditional->asDiscreteConditional()); + DecisionTreeFactor f(*conditional->asDiscrete()); dtFactor = dtFactor * f; } } @@ -108,7 +108,7 @@ void HybridBayesNet::updateDiscreteConditionals( HybridConditional::shared_ptr conditional = this->at(i); if (conditional->isDiscrete()) { // std::cout << demangle(typeid(conditional).name()) << std::endl; - auto discrete = conditional->asDiscreteConditional(); + auto discrete = conditional->asDiscrete(); KeyVector frontals(discrete->frontals().begin(), discrete->frontals().end()); @@ -151,13 +151,10 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { // Go through all the conditionals in the // Bayes Net and prune them as per decisionTree. for (auto &&conditional : *this) { - if (conditional->isHybrid()) { - GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture(); - + if (auto gm = conditional->asMixture()) { // Make a copy of the Gaussian mixture and prune it! - auto prunedGaussianMixture = - boost::make_shared(*gaussianMixture); - prunedGaussianMixture->prune(*decisionTree); + auto prunedGaussianMixture = boost::make_shared(*gm); + prunedGaussianMixture->prune(*decisionTree); // imperative :-( // Type-erase and add to the pruned Bayes Net fragment. prunedBayesNetFragment.push_back( @@ -184,7 +181,7 @@ GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const { /* ************************************************************************* */ DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { - return at(i)->asDiscreteConditional(); + return at(i)->asDiscrete(); } /* ************************************************************************* */ @@ -192,16 +189,13 @@ GaussianBayesNet HybridBayesNet::choose( const DiscreteValues &assignment) const { GaussianBayesNet gbn; for (auto &&conditional : *this) { - if (conditional->isHybrid()) { + if (auto gm = conditional->asMixture()) { // If conditional is hybrid, select based on assignment. - GaussianMixture gm = *conditional->asMixture(); - gbn.push_back(gm(assignment)); - - } else if (conditional->isContinuous()) { + gbn.push_back((*gm)(assignment)); + } else if (auto gc = conditional->asGaussian()) { // If continuous only, add Gaussian conditional. - gbn.push_back((conditional->asGaussian())); - - } else if (conditional->isDiscrete()) { + gbn.push_back(gc); + } else if (auto dc = conditional->asDiscrete()) { // If conditional is discrete-only, we simply continue. continue; } @@ -216,7 +210,7 @@ HybridValues HybridBayesNet::optimize() const { DiscreteBayesNet discrete_bn; for (auto &&conditional : *this) { if (conditional->isDiscrete()) { - discrete_bn.push_back(conditional->asDiscreteConditional()); + discrete_bn.push_back(conditional->asDiscrete()); } } @@ -238,26 +232,23 @@ double HybridBayesNet::evaluate(const HybridValues &values) const { const DiscreteValues &discreteValues = values.discrete(); const VectorValues &continuousValues = values.continuous(); - double probability = 1.0; + double logDensity = 0.0, probability = 1.0; // Iterate over each conditional. for (auto &&conditional : *this) { - if (conditional->isHybrid()) { - // If conditional is hybrid, select based on assignment and evaluate. - const GaussianMixture::shared_ptr gm = conditional->asMixture(); - const auto conditional = (*gm)(discreteValues); - probability *= conditional->evaluate(continuousValues); - } else if (conditional->isContinuous()) { + if (auto gm = conditional->asMixture()) { + const auto component = (*gm)(discreteValues); + logDensity += component->logDensity(continuousValues); + } else if (auto gc = conditional->asGaussian()) { // If continuous only, evaluate the probability and multiply. - probability *= conditional->asGaussian()->evaluate(continuousValues); - } else if (conditional->isDiscrete()) { + logDensity += gc->logDensity(continuousValues); + } else if (auto dc = conditional->asDiscrete()) { // Conditional is discrete-only, so return its probability. - probability *= - conditional->asDiscreteConditional()->operator()(discreteValues); + probability *= dc->operator()(discreteValues); } } - return probability; + return probability * exp(logDensity); } /* ************************************************************************* */ @@ -267,7 +258,7 @@ HybridValues HybridBayesNet::sample(const HybridValues &given, for (auto &&conditional : *this) { if (conditional->isDiscrete()) { // If conditional is discrete-only, we add to the discrete Bayes net. - dbn.push_back(conditional->asDiscreteConditional()); + dbn.push_back(conditional->asDiscrete()); } } // Sample a discrete assignment. @@ -309,23 +300,20 @@ AlgebraicDecisionTree HybridBayesNet::error( // Iterate over each conditional. for (auto &&conditional : *this) { - if (conditional->isHybrid()) { + if (auto gm = conditional->asMixture()) { // If conditional is hybrid, select based on assignment and compute error. - GaussianMixture::shared_ptr gm = conditional->asMixture(); AlgebraicDecisionTree conditional_error = gm->error(continuousValues); error_tree = error_tree + conditional_error; - - } else if (conditional->isContinuous()) { + } else if (auto gc = conditional->asGaussian()) { // If continuous only, get the (double) error // and add it to the error_tree - double error = conditional->asGaussian()->error(continuousValues); + double error = gc->error(continuousValues); // Add the computed error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); - - } else if (conditional->isDiscrete()) { + } else if (auto dc = conditional->asDiscrete()) { // Conditional is discrete-only, we skip. continue; } diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp index 8fdedab44..ed70a0aa9 100644 --- a/gtsam/hybrid/HybridBayesTree.cpp +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -49,7 +49,7 @@ HybridValues HybridBayesTree::optimize() const { // The root should be discrete only, we compute the MPE if (root_conditional->isDiscrete()) { - dbn.push_back(root_conditional->asDiscreteConditional()); + dbn.push_back(root_conditional->asDiscrete()); mpe = DiscreteFactorGraph(dbn).optimize(); } else { throw std::runtime_error( @@ -147,7 +147,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const { /* ************************************************************************* */ void HybridBayesTree::prune(const size_t maxNrLeaves) { auto decisionTree = - this->roots_.at(0)->conditional()->asDiscreteConditional(); + this->roots_.at(0)->conditional()->asDiscrete(); DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); decisionTree->root_ = prunedDecisionTree.root_; diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index f90152abe..d22087f47 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -317,8 +317,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, prunedDecisionTree->nrLeaves()); - auto original_discrete_conditionals = - *(hybridBayesNet->at(4)->asDiscreteConditional()); + auto original_discrete_conditionals = *(hybridBayesNet->at(4)->asDiscrete()); // Prune! hybridBayesNet->prune(maxNrLeaves); @@ -338,8 +337,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { }; // Get the pruned discrete conditionals as an AlgebraicDecisionTree - auto pruned_discrete_conditionals = - hybridBayesNet->at(4)->asDiscreteConditional(); + auto pruned_discrete_conditionals = hybridBayesNet->at(4)->asDiscrete(); auto discrete_conditional_tree = boost::dynamic_pointer_cast( pruned_discrete_conditionals); diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 7877461b6..55e4c28ad 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -133,7 +133,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) { auto result = hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)})); - auto dc = result->at(2)->asDiscreteConditional(); + auto dc = result->at(2)->asDiscrete(); DiscreteValues dv; dv[M(1)] = 0; EXPECT_DOUBLES_EQUAL(1, dc->operator()(dv), 1e-3); diff --git a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp index 18ce7f10e..11bd3b415 100644 --- a/gtsam/hybrid/tests/testHybridGaussianISAM.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianISAM.cpp @@ -111,8 +111,7 @@ TEST(HybridGaussianElimination, IncrementalInference) { // Run update step isam.update(graph1); - auto discreteConditional_m0 = - isam[M(0)]->conditional()->asDiscreteConditional(); + auto discreteConditional_m0 = isam[M(0)]->conditional()->asDiscrete(); EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)})); /********************************************************/ @@ -170,10 +169,10 @@ TEST(HybridGaussianElimination, IncrementalInference) { DiscreteValues m00; m00[M(0)] = 0, m00[M(1)] = 0; DiscreteConditional decisionTree = - *(*discreteBayesTree)[M(1)]->conditional()->asDiscreteConditional(); + *(*discreteBayesTree)[M(1)]->conditional()->asDiscrete(); double m00_prob = decisionTree(m00); - auto discreteConditional = isam[M(1)]->conditional()->asDiscreteConditional(); + auto discreteConditional = isam[M(1)]->conditional()->asDiscrete(); // Test if the probability values are as expected with regression tests. DiscreteValues assignment; @@ -535,7 +534,7 @@ TEST(HybridGaussianISAM, NonTrivial) { // The final discrete graph should not be empty since we have eliminated // all continuous variables. - auto discreteTree = inc[M(3)]->conditional()->asDiscreteConditional(); + auto discreteTree = inc[M(3)]->conditional()->asDiscrete(); EXPECT_LONGS_EQUAL(3, discreteTree->size()); // Test if the optimal discrete mode assignment is (1, 1, 1). diff --git a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp index 3bdb5ed1e..8801a8946 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearISAM.cpp @@ -124,8 +124,7 @@ TEST(HybridNonlinearISAM, IncrementalInference) { isam.update(graph1, initial); HybridGaussianISAM bayesTree = isam.bayesTree(); - auto discreteConditional_m0 = - bayesTree[M(0)]->conditional()->asDiscreteConditional(); + auto discreteConditional_m0 = bayesTree[M(0)]->conditional()->asDiscrete(); EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)})); /********************************************************/ @@ -187,11 +186,11 @@ TEST(HybridNonlinearISAM, IncrementalInference) { DiscreteValues m00; m00[M(0)] = 0, m00[M(1)] = 0; DiscreteConditional decisionTree = - *(*discreteBayesTree)[M(1)]->conditional()->asDiscreteConditional(); + *(*discreteBayesTree)[M(1)]->conditional()->asDiscrete(); double m00_prob = decisionTree(m00); auto discreteConditional = - bayesTree[M(1)]->conditional()->asDiscreteConditional(); + bayesTree[M(1)]->conditional()->asDiscrete(); // Test if the probability values are as expected with regression tests. DiscreteValues assignment; @@ -558,7 +557,7 @@ TEST(HybridNonlinearISAM, NonTrivial) { // The final discrete graph should not be empty since we have eliminated // all continuous variables. - auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional(); + auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete(); EXPECT_LONGS_EQUAL(3, discreteTree->size()); // Test if the optimal discrete mode assignment is (1, 1, 1).