discreteConditionals returns DiscreteConditional

release/4.3a0
Varun Agrawal 2023-07-09 20:24:24 -04:00
parent f6b1872b13
commit 2940e69a73
2 changed files with 9 additions and 13 deletions

View File

@ -38,21 +38,17 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
}
/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> discreteProbs;
DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const {
// The canonical decision tree factor which will get
// the discrete conditionals added to it.
DecisionTreeFactor discreteProbsFactor;
DiscreteConditional discreteProbs;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscrete());
discreteProbsFactor = discreteProbsFactor * f;
discreteProbs = discreteProbs * (*conditional->asDiscrete());
}
}
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
return std::make_shared<DiscreteConditional>(discreteProbs);
}
/* ************************************************************************* */
@ -146,8 +142,8 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor &prunedDiscreteProbs) {
//TODO(Varun) Should prune the joint conditional, maybe during elimination?
// Loop with index since we need it later.
// TODO(Varun) Should prune the joint conditional, maybe during elimination?
// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
@ -179,7 +175,7 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
DecisionTreeFactor::shared_ptr discreteConditionals =
DiscreteConditional::shared_ptr discreteConditionals =
this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);

View File

@ -139,9 +139,9 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
* @return DiscreteConditional::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;
DiscreteConditional::shared_ptr discreteConditionals() const;
/**
* @brief Sample from an incomplete BayesNet, given missing variables.