Refactor with uniform dynamic pointer cast API
parent
c984a5ffa2
commit
1134d1c88e
|
@ -36,7 +36,7 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
|
|||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
// Convert to a DecisionTreeFactor and add it to the main factor.
|
||||
DecisionTreeFactor f(*conditional->asDiscreteConditional());
|
||||
DecisionTreeFactor f(*conditional->asDiscrete());
|
||||
dtFactor = dtFactor * f;
|
||||
}
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ void HybridBayesNet::updateDiscreteConditionals(
|
|||
HybridConditional::shared_ptr conditional = this->at(i);
|
||||
if (conditional->isDiscrete()) {
|
||||
// std::cout << demangle(typeid(conditional).name()) << std::endl;
|
||||
auto discrete = conditional->asDiscreteConditional();
|
||||
auto discrete = conditional->asDiscrete();
|
||||
KeyVector frontals(discrete->frontals().begin(),
|
||||
discrete->frontals().end());
|
||||
|
||||
|
@ -151,13 +151,10 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
|
|||
// Go through all the conditionals in the
|
||||
// Bayes Net and prune them as per decisionTree.
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isHybrid()) {
|
||||
GaussianMixture::shared_ptr gaussianMixture = conditional->asMixture();
|
||||
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// Make a copy of the Gaussian mixture and prune it!
|
||||
auto prunedGaussianMixture =
|
||||
boost::make_shared<GaussianMixture>(*gaussianMixture);
|
||||
prunedGaussianMixture->prune(*decisionTree);
|
||||
auto prunedGaussianMixture = boost::make_shared<GaussianMixture>(*gm);
|
||||
prunedGaussianMixture->prune(*decisionTree); // imperative :-(
|
||||
|
||||
// Type-erase and add to the pruned Bayes Net fragment.
|
||||
prunedBayesNetFragment.push_back(
|
||||
|
@ -184,7 +181,7 @@ GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
||||
return at(i)->asDiscreteConditional();
|
||||
return at(i)->asDiscrete();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -192,16 +189,13 @@ GaussianBayesNet HybridBayesNet::choose(
|
|||
const DiscreteValues &assignment) const {
|
||||
GaussianBayesNet gbn;
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isHybrid()) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// If conditional is hybrid, select based on assignment.
|
||||
GaussianMixture gm = *conditional->asMixture();
|
||||
gbn.push_back(gm(assignment));
|
||||
|
||||
} else if (conditional->isContinuous()) {
|
||||
gbn.push_back((*gm)(assignment));
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous only, add Gaussian conditional.
|
||||
gbn.push_back((conditional->asGaussian()));
|
||||
|
||||
} else if (conditional->isDiscrete()) {
|
||||
gbn.push_back(gc);
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// If conditional is discrete-only, we simply continue.
|
||||
continue;
|
||||
}
|
||||
|
@ -216,7 +210,7 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
DiscreteBayesNet discrete_bn;
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_bn.push_back(conditional->asDiscreteConditional());
|
||||
discrete_bn.push_back(conditional->asDiscrete());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -238,26 +232,23 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {
|
|||
const DiscreteValues &discreteValues = values.discrete();
|
||||
const VectorValues &continuousValues = values.continuous();
|
||||
|
||||
double probability = 1.0;
|
||||
double logDensity = 0.0, probability = 1.0;
|
||||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isHybrid()) {
|
||||
// If conditional is hybrid, select based on assignment and evaluate.
|
||||
const GaussianMixture::shared_ptr gm = conditional->asMixture();
|
||||
const auto conditional = (*gm)(discreteValues);
|
||||
probability *= conditional->evaluate(continuousValues);
|
||||
} else if (conditional->isContinuous()) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
const auto component = (*gm)(discreteValues);
|
||||
logDensity += component->logDensity(continuousValues);
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous only, evaluate the probability and multiply.
|
||||
probability *= conditional->asGaussian()->evaluate(continuousValues);
|
||||
} else if (conditional->isDiscrete()) {
|
||||
logDensity += gc->logDensity(continuousValues);
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// Conditional is discrete-only, so return its probability.
|
||||
probability *=
|
||||
conditional->asDiscreteConditional()->operator()(discreteValues);
|
||||
probability *= dc->operator()(discreteValues);
|
||||
}
|
||||
}
|
||||
|
||||
return probability;
|
||||
return probability * exp(logDensity);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
@ -267,7 +258,7 @@ HybridValues HybridBayesNet::sample(const HybridValues &given,
|
|||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
// If conditional is discrete-only, we add to the discrete Bayes net.
|
||||
dbn.push_back(conditional->asDiscreteConditional());
|
||||
dbn.push_back(conditional->asDiscrete());
|
||||
}
|
||||
}
|
||||
// Sample a discrete assignment.
|
||||
|
@ -309,23 +300,20 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
|||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isHybrid()) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// If conditional is hybrid, select based on assignment and compute error.
|
||||
GaussianMixture::shared_ptr gm = conditional->asMixture();
|
||||
AlgebraicDecisionTree<Key> conditional_error =
|
||||
gm->error(continuousValues);
|
||||
|
||||
error_tree = error_tree + conditional_error;
|
||||
|
||||
} else if (conditional->isContinuous()) {
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous only, get the (double) error
|
||||
// and add it to the error_tree
|
||||
double error = conditional->asGaussian()->error(continuousValues);
|
||||
double error = gc->error(continuousValues);
|
||||
// Add the computed error to every leaf of the error tree.
|
||||
error_tree = error_tree.apply(
|
||||
[error](double leaf_value) { return leaf_value + error; });
|
||||
|
||||
} else if (conditional->isDiscrete()) {
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
// Conditional is discrete-only, we skip.
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ HybridValues HybridBayesTree::optimize() const {
|
|||
|
||||
// The root should be discrete only, we compute the MPE
|
||||
if (root_conditional->isDiscrete()) {
|
||||
dbn.push_back(root_conditional->asDiscreteConditional());
|
||||
dbn.push_back(root_conditional->asDiscrete());
|
||||
mpe = DiscreteFactorGraph(dbn).optimize();
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
|
@ -147,7 +147,7 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
/* ************************************************************************* */
|
||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||
auto decisionTree =
|
||||
this->roots_.at(0)->conditional()->asDiscreteConditional();
|
||||
this->roots_.at(0)->conditional()->asDiscrete();
|
||||
|
||||
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDecisionTree.root_;
|
||||
|
|
|
@ -317,8 +317,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
|
||||
prunedDecisionTree->nrLeaves());
|
||||
|
||||
auto original_discrete_conditionals =
|
||||
*(hybridBayesNet->at(4)->asDiscreteConditional());
|
||||
auto original_discrete_conditionals = *(hybridBayesNet->at(4)->asDiscrete());
|
||||
|
||||
// Prune!
|
||||
hybridBayesNet->prune(maxNrLeaves);
|
||||
|
@ -338,8 +337,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
|
|||
};
|
||||
|
||||
// Get the pruned discrete conditionals as an AlgebraicDecisionTree
|
||||
auto pruned_discrete_conditionals =
|
||||
hybridBayesNet->at(4)->asDiscreteConditional();
|
||||
auto pruned_discrete_conditionals = hybridBayesNet->at(4)->asDiscrete();
|
||||
auto discrete_conditional_tree =
|
||||
boost::dynamic_pointer_cast<DecisionTreeFactor::ADT>(
|
||||
pruned_discrete_conditionals);
|
||||
|
|
|
@ -133,7 +133,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) {
|
|||
auto result =
|
||||
hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {M(1)}));
|
||||
|
||||
auto dc = result->at(2)->asDiscreteConditional();
|
||||
auto dc = result->at(2)->asDiscrete();
|
||||
DiscreteValues dv;
|
||||
dv[M(1)] = 0;
|
||||
EXPECT_DOUBLES_EQUAL(1, dc->operator()(dv), 1e-3);
|
||||
|
|
|
@ -111,8 +111,7 @@ TEST(HybridGaussianElimination, IncrementalInference) {
|
|||
// Run update step
|
||||
isam.update(graph1);
|
||||
|
||||
auto discreteConditional_m0 =
|
||||
isam[M(0)]->conditional()->asDiscreteConditional();
|
||||
auto discreteConditional_m0 = isam[M(0)]->conditional()->asDiscrete();
|
||||
EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)}));
|
||||
|
||||
/********************************************************/
|
||||
|
@ -170,10 +169,10 @@ TEST(HybridGaussianElimination, IncrementalInference) {
|
|||
DiscreteValues m00;
|
||||
m00[M(0)] = 0, m00[M(1)] = 0;
|
||||
DiscreteConditional decisionTree =
|
||||
*(*discreteBayesTree)[M(1)]->conditional()->asDiscreteConditional();
|
||||
*(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
|
||||
double m00_prob = decisionTree(m00);
|
||||
|
||||
auto discreteConditional = isam[M(1)]->conditional()->asDiscreteConditional();
|
||||
auto discreteConditional = isam[M(1)]->conditional()->asDiscrete();
|
||||
|
||||
// Test if the probability values are as expected with regression tests.
|
||||
DiscreteValues assignment;
|
||||
|
@ -535,7 +534,7 @@ TEST(HybridGaussianISAM, NonTrivial) {
|
|||
|
||||
// The final discrete graph should not be empty since we have eliminated
|
||||
// all continuous variables.
|
||||
auto discreteTree = inc[M(3)]->conditional()->asDiscreteConditional();
|
||||
auto discreteTree = inc[M(3)]->conditional()->asDiscrete();
|
||||
EXPECT_LONGS_EQUAL(3, discreteTree->size());
|
||||
|
||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||
|
|
|
@ -124,8 +124,7 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
|
|||
isam.update(graph1, initial);
|
||||
HybridGaussianISAM bayesTree = isam.bayesTree();
|
||||
|
||||
auto discreteConditional_m0 =
|
||||
bayesTree[M(0)]->conditional()->asDiscreteConditional();
|
||||
auto discreteConditional_m0 = bayesTree[M(0)]->conditional()->asDiscrete();
|
||||
EXPECT(discreteConditional_m0->keys() == KeyVector({M(0)}));
|
||||
|
||||
/********************************************************/
|
||||
|
@ -187,11 +186,11 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
|
|||
DiscreteValues m00;
|
||||
m00[M(0)] = 0, m00[M(1)] = 0;
|
||||
DiscreteConditional decisionTree =
|
||||
*(*discreteBayesTree)[M(1)]->conditional()->asDiscreteConditional();
|
||||
*(*discreteBayesTree)[M(1)]->conditional()->asDiscrete();
|
||||
double m00_prob = decisionTree(m00);
|
||||
|
||||
auto discreteConditional =
|
||||
bayesTree[M(1)]->conditional()->asDiscreteConditional();
|
||||
bayesTree[M(1)]->conditional()->asDiscrete();
|
||||
|
||||
// Test if the probability values are as expected with regression tests.
|
||||
DiscreteValues assignment;
|
||||
|
@ -558,7 +557,7 @@ TEST(HybridNonlinearISAM, NonTrivial) {
|
|||
|
||||
// The final discrete graph should not be empty since we have eliminated
|
||||
// all continuous variables.
|
||||
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscreteConditional();
|
||||
auto discreteTree = bayesTree[M(3)]->conditional()->asDiscrete();
|
||||
EXPECT_LONGS_EQUAL(3, discreteTree->size());
|
||||
|
||||
// Test if the optimal discrete mode assignment is (1, 1, 1).
|
||||
|
|
Loading…
Reference in New Issue