add checks

release/4.3a0
Varun Agrawal 2025-01-01 21:50:26 -05:00
parent 5d2d879462
commit 4c5b842c73
2 changed files with 19 additions and 4 deletions

View File

@ -20,7 +20,6 @@
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridValues.h>
#include <memory>
@ -53,8 +52,16 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) const {
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (auto &&conditional : marginal) {
// The last discrete conditional may be a DiscreteTableConditional
if (auto dtc =
std::dynamic_pointer_cast<DiscreteTableConditional>(conditional)) {
DiscreteConditional dc(dtc->nrFrontals(),
dtc->table().toDecisionTreeFactor());
joint = joint * dc;
} else {
joint = joint * (*conditional);
}
}
// Prune the joint. NOTE: again, possibly quite expensive.
const DiscreteConditional::shared_ptr pruned = joint.prune(maxNrLeaves);
@ -127,9 +134,16 @@ HybridValues HybridBayesNet::optimize() const {
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
if (auto dtc = conditional->asDiscrete<DiscreteTableConditional>()) {
// The number of keys should be small so should not
// be expensive to convert to DiscreteConditional.
discrete_fg.push_back(DiscreteConditional(
dtc->nrFrontals(), dtc->table().toDecisionTreeFactor()));
} else {
discrete_fg.push_back(conditional->asDiscrete());
}
}
}
// Solve for the MPE
DiscreteValues mpe = discrete_fg.optimize();

View File

@ -201,7 +201,8 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
/* ************************************************************************* */
void HybridBayesTree::prune(const size_t maxNrLeaves) {
auto discreteProbs = this->roots_.at(0)->conditional()->asDiscrete();
auto discreteProbs =
this->roots_.at(0)->conditional()->asDiscrete<DiscreteTableConditional>();
DiscreteConditional::shared_ptr prunedDiscreteProbs =
discreteProbs->prune(maxNrLeaves);