Refactor with uniform dynamic pointer cast API

release/4.3a0
Frank Dellaert 2022-12-28 13:52:59 -05:00
parent c984a5ffa2
commit 1134d1c88e
6 changed files with 39 additions and 55 deletions

View File

@ -36,7 +36,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor. // Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional()); DecisionTreeFactor f(*conditional->asDiscrete());
dtFactor = dtFactor * f; dtFactor = dtFactor * f;
} }
} }
@ -108,7 +108,7 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridConditional::shared_ptr conditional = this->at(i); HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
// std::cout << demangle(typeid(conditional).name()) << std::endl; // std::cout << demangle(typeid(conditional).name()) << std::endl;
auto discrete = conditional->asDiscreteConditional(); auto discrete = conditional->asDiscrete();
KeyVector frontals(discrete->frontals().begin(), KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end()); discrete->frontals().end());
@ -151,13 +151,10 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Go through all the conditionals in the // Go through all the conditionals in the
// Bayes Net and prune them as per decisionTree. // Bayes Net and prune them as per decisionTree.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isHybrid()) { if (auto gm = conditional->asMixture()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
// Make a copy of the Gaussian mixture and prune it! // Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture = auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
boost::make_shared<GaussianMixture>(*gaussianMixture); prunedGaussianMixture->prune(*decisionTree); // imperative :-(
prunedGaussianMixture->prune(*decisionTree);
// Type-erase and add to the pruned Bayes Net fragment. // Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back( prunedBayesNetFragment.push_back(
@ -184,7 +181,7 @@ GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
/* ************************************************************************* */ /* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return at(i)->asDiscreteConditional(); return at(i)->asDiscrete();
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -192,16 +189,13 @@ GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const { const DiscreteValues &assignment) const {
GaussianBayesNet gbn; GaussianBayesNet gbn;
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isHybrid()) { if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment. // If conditional is hybrid, select based on assignment.
GaussianMixture gm = *conditional->asMixture(); gbn.push_back((*gm)(assignment));
gbn.push_back(gm(assignment)); } else if (auto gc = conditional->asGaussian()) {
} else if (conditional->isContinuous()) {
// If continuous only, add Gaussian conditional. // If continuous only, add Gaussian conditional.
gbn.push_back((conditional->asGaussian())); gbn.push_back(gc);
} else if (auto dc = conditional->asDiscrete()) {
} else if (conditional->isDiscrete()) {
// If conditional is discrete-only, we simply continue. // If conditional is discrete-only, we simply continue.
continue; continue;
} }
@ -216,7 +210,7 @@ HybridValues HybridBayesNet::optimize() const {
DiscreteBayesNet discrete_bn; DiscreteBayesNet discrete_bn;
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional()); discrete_bn.push_back(conditional->asDiscrete());
} }
} }
@ -238,26 +232,23 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {
const DiscreteValues &discreteValues = values.discrete(); const DiscreteValues &discreteValues = values.discrete();
const VectorValues &continuousValues = values.continuous(); const VectorValues &continuousValues = values.continuous();
double probability = 1.0; double logDensity = 0.0, probability = 1.0;
// Iterate over each conditional. // Iterate over each conditional.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isHybrid()) { if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and evaluate. const auto component = (*gm)(discreteValues);
const GaussianMixture::shared_ptr gm = conditional->asMixture(); logDensity += component->logDensity(continuousValues);
const auto conditional = (*gm)(discreteValues); } else if (auto gc = conditional->asGaussian()) {
probability *= conditional->evaluate(continuousValues);
} else if (conditional->isContinuous()) {
// If continuous only, evaluate the probability and multiply. // If continuous only, evaluate the probability and multiply.
probability *= conditional->asGaussian()->evaluate(continuousValues); logDensity += gc->logDensity(continuousValues);
} else if (conditional->isDiscrete()) { } else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability. // Conditional is discrete-only, so return its probability.
probability *= probability *= dc->operator()(discreteValues);
conditional->asDiscreteConditional()->operator()(discreteValues);
} }
} }
return probability; return probability * exp(logDensity);
} }
/* ************************************************************************* */ /* ************************************************************************* */
@ -267,7 +258,7 @@ HybridValues HybridBayesNet::sample(const HybridValues &given,
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete Bayes net. // If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscreteConditional()); dbn.push_back(conditional->asDiscrete());
} }
} }
// Sample a discrete assignment. // Sample a discrete assignment.
@ -309,23 +300,20 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
// Iterate over each conditional. // Iterate over each conditional.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isHybrid()) { if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute error. // If conditional is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = conditional->asMixture();
AlgebraicDecisionTree<Key> conditional_error = AlgebraicDecisionTree<Key> conditional_error =
gm->error(continuousValues); gm->error(continuousValues);
error_tree = error_tree + conditional_error; error_tree = error_tree + conditional_error;
} else if (auto gc = conditional->asGaussian()) {
} else if (conditional->isContinuous()) {
// If continuous only, get the (double) error // If continuous only, get the (double) error
// and add it to the error_tree // and add it to the error_tree
double error = conditional->asGaussian()->error(continuousValues); double error = gc->error(continuousValues);
// Add the computed error to every leaf of the error tree. // Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply( error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; }); [error](double leaf_value) { return leaf_value + error; });
} else if (auto dc = conditional->asDiscrete()) {
} else if (conditional->isDiscrete()) {
// Conditional is discrete-only, we skip. // Conditional is discrete-only, we skip.
continue; continue;
} }

View File

@ -49,7 +49,7 @@ HybridValues HybridBayesTree::optimize() const {
// The root should be discrete only, we compute the MPE // The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) { if (root_conditional->isDiscrete()) {
dbn.push_back(root_conditional->asDiscreteConditional()); dbn.push_back(root_conditional->asDiscrete());
mpe = DiscreteFactorGraph(dbn).optimize(); mpe = DiscreteFactorGraph(dbn).optimize();
} else { } else {
throw std::runtime_error( throw std::runtime_error(
@ -147,7 +147,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */ /* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) { void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree = auto decisionTree =
this->roots_.at(0)->conditional()->asDiscreteConditional(); this->roots_.at(0)->conditional()->asDiscrete();
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_; decisionTree->root_ = prunedDecisionTree.root_;

View File

@ -317,8 +317,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
prunedDecisionTree->nrLeaves()); prunedDecisionTree->nrLeaves());
auto original_discrete_conditionals = auto original_discrete_conditionals = *(hybridBayesNet->at(4)->asDiscrete());
*(hybridBayesNet->at(4)->asDiscreteConditional());
// Prune! // Prune!
hybridBayesNet->prune(maxNrLeaves); hybridBayesNet->prune(maxNrLeaves);
@ -338,8 +337,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
}; };
// Get the pruned discrete conditionals as an AlgebraicDecisionTree // Get the pruned discrete conditionals as an AlgebraicDecisionTree
auto pruned_discrete_conditionals = auto pruned_discrete_conditionals = hybridBayesNet->at(4)->asDiscrete();
hybridBayesNet->at(4)->asDiscreteConditional();
auto discrete_conditional_tree = auto discrete_conditional_tree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>( boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
pruned_discrete_conditionals); pruned_discrete_conditionals);

View File

@ -133,7 +133,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
auto result = auto result =
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)})); hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)}));
auto dc = result->at(2)->asDiscreteConditional(); auto dc = result->at(2)->asDiscrete();
DiscreteValues dv; DiscreteValues dv;
dv[M(1)] = 0; dv[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(1, dc->operator()(dv), 1e-3); EXPECT_DOUBLES_EQUAL(1, dc->operator()(dv), 1e-3);

View File

@ -111,8 +111,7 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// Run update step // Run update step
isam.update(graph1); isam.update(graph1);
auto discreteConditional_m0 = auto discreteConditional_m0 = isam[M(0)]->conditional()->asDiscrete();
isam[M(0)]->conditional()->asDiscreteConditional();
EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)})); EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)}));
/********************************************************/ /********************************************************/
@ -170,10 +169,10 @@ TEST(HybridGaussianElimination, IncrementalInference) {
DiscreteValues m00; DiscreteValues m00;
m00[M(0)] = 0, m00[M(1)] = 0; m00[M(0)] = 0, m00[M(1)] = 0;
DiscreteConditional decisionTree = DiscreteConditional decisionTree =
*(*discreteBayesTree)[M(1)]->conditional()->asDiscreteConditional(); *(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
double m00_prob = decisionTree(m00); double m00_prob = decisionTree(m00);
auto discreteConditional = isam[M(1)]->conditional()->asDiscreteConditional(); auto discreteConditional = isam[M(1)]->conditional()->asDiscrete();
// Test if the probability values are as expected with regression tests. // Test if the probability values are as expected with regression tests.
DiscreteValues assignment; DiscreteValues assignment;
@ -535,7 +534,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
// The final discrete graph should not be empty since we have eliminated // The final discrete graph should not be empty since we have eliminated
// all continuous variables. // all continuous variables.
auto discreteTree = inc[M(3)]->conditional()->asDiscreteConditional(); auto discreteTree = inc[M(3)]->conditional()->asDiscrete();
EXPECT_LONGS_EQUAL(3, discreteTree->size()); EXPECT_LONGS_EQUAL(3, discreteTree->size());
// Test if the optimal discrete mode assignment is (1, 1, 1). // Test if the optimal discrete mode assignment is (1, 1, 1).

View File

@ -124,8 +124,7 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
isam.update(graph1, initial); isam.update(graph1, initial);
HybridGaussianISAM bayesTree = isam.bayesTree(); HybridGaussianISAM bayesTree = isam.bayesTree();
auto discreteConditional_m0 = auto discreteConditional_m0 = bayesTree[M(0)]->conditional()->asDiscrete();
bayesTree[M(0)]->conditional()->asDiscreteConditional();
EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)})); EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)}));
/********************************************************/ /********************************************************/
@ -187,11 +186,11 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
DiscreteValues m00; DiscreteValues m00;
m00[M(0)] = 0, m00[M(1)] = 0; m00[M(0)] = 0, m00[M(1)] = 0;
DiscreteConditional decisionTree = DiscreteConditional decisionTree =
*(*discreteBayesTree)[M(1)]->conditional()->asDiscreteConditional(); *(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
double m00_prob = decisionTree(m00); double m00_prob = decisionTree(m00);
auto discreteConditional = auto discreteConditional =
bayesTree[M(1)]->conditional()->asDiscreteConditional(); bayesTree[M(1)]->conditional()->asDiscrete();
// Test if the probability values are as expected with regression tests. // Test if the probability values are as expected with regression tests.
DiscreteValues assignment; DiscreteValues assignment;
@ -558,7 +557,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// The final discrete graph should not be empty since we have eliminated // The final discrete graph should not be empty since we have eliminated
// all continuous variables. // all continuous variables.
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional(); auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete();
EXPECT_LONGS_EQUAL(3, discreteTree->size()); EXPECT_LONGS_EQUAL(3, discreteTree->size());
// Test if the optimal discrete mode assignment is (1, 1, 1). // Test if the optimal discrete mode assignment is (1, 1, 1).