fix HybridBayesTree::optimize to account for pruned nodes

release/4.3a0
Varun Agrawal 2022-11-09 20:04:21 -05:00
parent 98d3186615
commit 7ae4e57d66
1 changed files with 24 additions and 7 deletions

View File

@ -73,6 +73,8 @@ struct HybridAssignmentData {
GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_;
// Flag indicating if all the nodes are valid. Used in optimize().
bool valid_;
/**
* @brief Construct a new Hybrid Assignment Data object.
@ -83,10 +85,13 @@ struct HybridAssignmentData {
*/
HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt)
GaussianBayesTree* gbt, bool valid = true)
: assignment_(assignment),
parentClique_(parentClique),
gaussianbayesTree_(gbt) {}
gaussianbayesTree_(gbt),
valid_(valid) {}
bool isValid() const { return valid_; }
/**
* @brief A function used during tree traversal that operates on each node
@ -101,6 +106,7 @@ struct HybridAssignmentData {
HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
@ -111,15 +117,21 @@ struct HybridAssignmentData {
conditional = boost::make_shared<GaussianConditional>();
}
// Create the GaussianClique for the current node
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
GaussianBayesTree::sharedNode clique;
if (conditional) {
// Create the GaussianClique for the current node
clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique,
parentData.parentClique_);
} else {
parentData.valid_ = false;
}
// Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_);
parentData.gaussianbayesTree_, parentData.valid_);
return data;
}
};
@ -138,6 +150,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
visitorPost);
}
if (!rootData.isValid()) {
return VectorValues();
}
VectorValues result = gbt.optimize();
// Return the optimized bayes net result.
@ -151,6 +166,8 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_;
// this->print();
// decisionTree->print("", DefaultKeyFormatter);
/// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData {