add checks
parent
5d2d879462
commit
4c5b842c73
|
@ -20,7 +20,6 @@
|
|||
#include <gtsam/discrete/DiscreteConditional.h>
|
||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <memory>
|
||||
|
@ -53,7 +52,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
|
|||
// Multiply into one big conditional. NOTE: possibly quite expensive.
|
||||
DiscreteConditional joint;
|
||||
for (auto &&conditional : marginal) {
|
||||
joint = joint * (*conditional);
|
||||
// The last discrete conditional may be a DiscreteTableConditional
|
||||
if (auto dtc =
|
||||
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
|
||||
DiscreteConditional dc(dtc->nrFrontals(),
|
||||
dtc->table().toDecisionTreeFactor());
|
||||
joint = joint * dc;
|
||||
} else {
|
||||
joint = joint * (*conditional);
|
||||
}
|
||||
}
|
||||
|
||||
// Prune the joint. NOTE: again, possibly quite expensive.
|
||||
|
@ -127,7 +134,14 @@ HybridValues HybridBayesNet::optimize() const {
|
|||
|
||||
for (auto &&conditional : *this) {
|
||||
if (conditional->isDiscrete()) {
|
||||
discrete_fg.push_back(conditional->asDiscrete());
|
||||
if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) {
|
||||
// The number of keys should be small so should not
|
||||
// be expensive to convert to DiscreteConditional.
|
||||
discrete_fg.push_back(DiscreteConditional(
|
||||
dtc->nrFrontals(), dtc->table().toDecisionTreeFactor()));
|
||||
} else {
|
||||
discrete_fg.push_back(conditional->asDiscrete());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -201,7 +201,8 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
|
||||
/* ************************************************************************* */
|
||||
void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
|
||||
auto discreteProbs =
|
||||
this->roots_.at(0)->conditional()->asDiscrete<DiscreteTableConditional>();
|
||||
|
||||
DiscreteConditional::shared_ptr prunedDiscreteProbs =
|
||||
discreteProbs->prune(maxNrLeaves);
|
||||
|
|
Loading…
Reference in New Issue