add checks

release/4.3a0
Varun Agrawal 2025-01-01 21:50:26 -05:00
parent 5d2d879462
commit 4c5b842c73
2 changed files with 19 additions and 4 deletions

View File

@ -20,7 +20,6 @@
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridValues.h>
#include <memory>
@ -53,7 +52,15 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (auto &&conditional : marginal) {
joint = joint * (*conditional);
// The last discrete conditional may be a DiscreteTableConditional
if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(),
dtc->table().toDecisionTreeFactor());
joint = joint * dc;
} else {
joint = joint * (*conditional);
}
}
// Prune the joint. NOTE: again, possibly quite expensive.
@ -127,7 +134,14 @@ HybridValues HybridBayesNet::optimize() const {
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_fg.push_back(conditional->asDiscrete());
if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) {
// The number of keys should be small so should not
// be expensive to convert to DiscreteConditional.
discrete_fg.push_back(DiscreteConditional(
dtc->nrFrontals(), dtc->table().toDecisionTreeFactor()));
} else {
discrete_fg.push_back(conditional->asDiscrete());
}
}
}

View File

@ -201,7 +201,8 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
auto discreteProbs =
this->roots_.at(0)->conditional()->asDiscrete<DiscreteTableConditional>();
DiscreteConditional::shared_ptr prunedDiscreteProbs =
discreteProbs->prune(maxNrLeaves);