add checks
parent
5d2d879462
commit
4c5b842c73
|
@ -20,7 +20,6 @@
|
||||||
#include <gtsam/discrete/DiscreteConditional.h>
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -53,7 +52,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
||||||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||||
DiscreteConditional joint;
|
DiscreteConditional joint;
|
||||||
for (auto &&conditional : marginal) {
|
for (auto &&conditional : marginal) {
|
||||||
joint = joint * (*conditional);
|
// The last discrete conditional may be a DiscreteTableConditional
|
||||||
|
if (auto dtc =
|
||||||
|
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
|
||||||
|
DiscreteConditional dc(dtc->nrFrontals(),
|
||||||
|
dtc->table().toDecisionTreeFactor());
|
||||||
|
joint = joint * dc;
|
||||||
|
} else {
|
||||||
|
joint = joint * (*conditional);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prune the joint. NOTE: again, possibly quite expensive.
|
// Prune the joint. NOTE: again, possibly quite expensive.
|
||||||
|
@ -127,7 +134,14 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
|
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (conditional->isDiscrete()) {
|
if (conditional->isDiscrete()) {
|
||||||
discrete_fg.push_back(conditional->asDiscrete());
|
if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) {
|
||||||
|
// The number of keys should be small so should not
|
||||||
|
// be expensive to convert to DiscreteConditional.
|
||||||
|
discrete_fg.push_back(DiscreteConditional(
|
||||||
|
dtc->nrFrontals(), dtc->table().toDecisionTreeFactor()));
|
||||||
|
} else {
|
||||||
|
discrete_fg.push_back(conditional->asDiscrete());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -201,7 +201,8 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
|
auto discreteProbs =
|
||||||
|
this->roots_.at(0)->conditional()->asDiscrete<DiscreteTableConditional>();
|
||||||
|
|
||||||
DiscreteConditional::shared_ptr prunedDiscreteProbs =
|
DiscreteConditional::shared_ptr prunedDiscreteProbs =
|
||||||
discreteProbs->prune(maxNrLeaves);
|
discreteProbs->prune(maxNrLeaves);
|
||||||
|
|
Loading…
Reference in New Issue