Refactor with uniform dynamic pointer cast API

release/4.3a0
Frank Dellaert 2022-12-28 13:52:59 -05:00
parent c984a5ffa2
commit 1134d1c88e
6 changed files with 39 additions and 55 deletions

View File

@ -36,7 +36,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscreteConditional());
DecisionTreeFactor f(*conditional->asDiscrete());
dtFactor = dtFactor * f;
}
}
@ -108,7 +108,7 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
// std::cout << demangle(typeid(conditional).name()) << std::endl;
auto discrete = conditional->asDiscreteConditional();
auto discrete = conditional->asDiscrete();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
@ -151,13 +151,10 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Go through all the conditionals in the
// Bayes Net and prune them as per decisionTree.
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
if (auto gm = conditional->asMixture()) {
// Make a copy of the Gaussian mixture and prune it!
auto prunedGaussianMixture =
boost::make_shared<GaussianMixture>(*gaussianMixture);
prunedGaussianMixture->prune(*decisionTree);
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
prunedGaussianMixture->prune(*decisionTree); // imperative :-(
// Type-erase and add to the pruned Bayes Net fragment.
prunedBayesNetFragment.push_back(
@ -184,7 +181,7 @@ GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return at(i)->asDiscreteConditional();
return at(i)->asDiscrete();
}
/* ************************************************************************* */
@ -192,16 +189,13 @@ GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment.
GaussianMixture gm = *conditional->asMixture();
gbn.push_back(gm(assignment));
} else if (conditional->isContinuous()) {
gbn.push_back((*gm)(assignment));
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, add Gaussian conditional.
gbn.push_back((conditional->asGaussian()));
} else if (conditional->isDiscrete()) {
gbn.push_back(gc);
} else if (auto dc = conditional->asDiscrete()) {
// If conditional is discrete-only, we simply continue.
continue;
}
@ -216,7 +210,7 @@ HybridValues HybridBayesNet::optimize() const {
DiscreteBayesNet discrete_bn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
discrete_bn.push_back(conditional->asDiscrete());
}
}
@ -238,26 +232,23 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {
const DiscreteValues &discreteValues = values.discrete();
const VectorValues &continuousValues = values.continuous();
double probability = 1.0;
double logDensity = 0.0, probability = 1.0;
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
// If conditional is hybrid, select based on assignment and evaluate.
const GaussianMixture::shared_ptr gm = conditional->asMixture();
const auto conditional = (*gm)(discreteValues);
probability *= conditional->evaluate(continuousValues);
} else if (conditional->isContinuous()) {
if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues);
logDensity += component->logDensity(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply.
probability *= conditional->asGaussian()->evaluate(continuousValues);
} else if (conditional->isDiscrete()) {
logDensity += gc->logDensity(continuousValues);
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability.
probability *=
conditional->asDiscreteConditional()->operator()(discreteValues);
probability *= dc->operator()(discreteValues);
}
}
return probability;
return probability * exp(logDensity);
}
/* ************************************************************************* */
@ -267,7 +258,7 @@ HybridValues HybridBayesNet::sample(const HybridValues &given,
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// If conditional is discrete-only, we add to the discrete Bayes net.
dbn.push_back(conditional->asDiscreteConditional());
dbn.push_back(conditional->asDiscrete());
}
}
// Sample a discrete assignment.
@ -309,23 +300,20 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (conditional->isHybrid()) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute error.
GaussianMixture::shared_ptr gm = conditional->asMixture();
AlgebraicDecisionTree<Key> conditional_error =
gm->error(continuousValues);
error_tree = error_tree + conditional_error;
} else if (conditional->isContinuous()) {
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, get the (double) error
// and add it to the error_tree
double error = conditional->asGaussian()->error(continuousValues);
double error = gc->error(continuousValues);
// Add the computed error to every leaf of the error tree.
error_tree = error_tree.apply(
[error](double leaf_value) { return leaf_value + error; });
} else if (conditional->isDiscrete()) {
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, we skip.
continue;
}

View File

@ -49,7 +49,7 @@ HybridValues HybridBayesTree::optimize() const {
// The root should be discrete only, we compute the MPE
if (root_conditional->isDiscrete()) {
dbn.push_back(root_conditional->asDiscreteConditional());
dbn.push_back(root_conditional->asDiscrete());
mpe = DiscreteFactorGraph(dbn).optimize();
} else {
throw std::runtime_error(
@ -147,7 +147,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto decisionTree =
this->roots_.at(0)->conditional()->asDiscreteConditional();
this->roots_.at(0)->conditional()->asDiscrete();
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;

View File

@ -317,8 +317,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
prunedDecisionTree->nrLeaves());
auto original_discrete_conditionals =
*(hybridBayesNet->at(4)->asDiscreteConditional());
auto original_discrete_conditionals = *(hybridBayesNet->at(4)->asDiscrete());
// Prune!
hybridBayesNet->prune(maxNrLeaves);
@ -338,8 +337,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
};
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
auto pruned_discrete_conditionals =
hybridBayesNet->at(4)->asDiscreteConditional();
auto pruned_discrete_conditionals = hybridBayesNet->at(4)->asDiscrete();
auto discrete_conditional_tree =
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
pruned_discrete_conditionals);

View File

@ -133,7 +133,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
auto result =
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)}));
auto dc = result->at(2)->asDiscreteConditional();
auto dc = result->at(2)->asDiscrete();
DiscreteValues dv;
dv[M(1)] = 0;
EXPECT_DOUBLES_EQUAL(1, dc->operator()(dv), 1e-3);

View File

@ -111,8 +111,7 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// Run update step
isam.update(graph1);
auto discreteConditional_m0 =
isam[M(0)]->conditional()->asDiscreteConditional();
auto discreteConditional_m0 = isam[M(0)]->conditional()->asDiscrete();
EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)}));
/********************************************************/
@ -170,10 +169,10 @@ TEST(HybridGaussianElimination, IncrementalInference) {
DiscreteValues m00;
m00[M(0)] = 0, m00[M(1)] = 0;
DiscreteConditional decisionTree =
*(*discreteBayesTree)[M(1)]->conditional()->asDiscreteConditional();
*(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
double m00_prob = decisionTree(m00);
auto discreteConditional = isam[M(1)]->conditional()->asDiscreteConditional();
auto discreteConditional = isam[M(1)]->conditional()->asDiscrete();
// Test if the probability values are as expected with regression tests.
DiscreteValues assignment;
@ -535,7 +534,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
// The final discrete graph should not be empty since we have eliminated
// all continuous variables.
auto discreteTree = inc[M(3)]->conditional()->asDiscreteConditional();
auto discreteTree = inc[M(3)]->conditional()->asDiscrete();
EXPECT_LONGS_EQUAL(3, discreteTree->size());
// Test if the optimal discrete mode assignment is (1, 1, 1).

View File

@ -124,8 +124,7 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
isam.update(graph1, initial);
HybridGaussianISAM bayesTree = isam.bayesTree();
auto discreteConditional_m0 =
bayesTree[M(0)]->conditional()->asDiscreteConditional();
auto discreteConditional_m0 = bayesTree[M(0)]->conditional()->asDiscrete();
EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)}));
/********************************************************/
@ -187,11 +186,11 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
DiscreteValues m00;
m00[M(0)] = 0, m00[M(1)] = 0;
DiscreteConditional decisionTree =
*(*discreteBayesTree)[M(1)]->conditional()->asDiscreteConditional();
*(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
double m00_prob = decisionTree(m00);
auto discreteConditional =
bayesTree[M(1)]->conditional()->asDiscreteConditional();
bayesTree[M(1)]->conditional()->asDiscrete();
// Test if the probability values are as expected with regression tests.
DiscreteValues assignment;
@ -558,7 +557,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
// The final discrete graph should not be empty since we have eliminated
// all continuous variables.
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional();
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete();
EXPECT_LONGS_EQUAL(3, discreteTree->size());
// Test if the optimal discrete mode assignment is (1, 1, 1).