fix HybridBayesTree::optimize to account for pruned nodes

release/4.3a0
Varun Agrawal 2022-11-09 20:04:21 -05:00
parent 98d3186615
commit 7ae4e57d66
1 changed files with 24 additions and 7 deletions

View File

@ -73,6 +73,8 @@ struct HybridAssignmentData {
GaussianBayesTree::sharedNode parentClique_; GaussianBayesTree::sharedNode parentClique_;
// The gaussian bayes tree that will be recursively created. // The gaussian bayes tree that will be recursively created.
GaussianBayesTree* gaussianbayesTree_; GaussianBayesTree* gaussianbayesTree_;
// Flag indicating if all the nodes are valid. Used in optimize().
bool valid_;
/** /**
* @brief Construct a new Hybrid Assignment Data object. * @brief Construct a new Hybrid Assignment Data object.
@ -83,10 +85,13 @@ struct HybridAssignmentData {
*/ */
HybridAssignmentData(const DiscreteValues& assignment, HybridAssignmentData(const DiscreteValues& assignment,
const GaussianBayesTree::sharedNode& parentClique, const GaussianBayesTree::sharedNode& parentClique,
GaussianBayesTree* gbt) GaussianBayesTree* gbt, bool valid = true)
: assignment_(assignment), : assignment_(assignment),
parentClique_(parentClique), parentClique_(parentClique),
gaussianbayesTree_(gbt) {} gaussianbayesTree_(gbt),
valid_(valid) {}
bool isValid() const { return valid_; }
/** /**
* @brief A function used during tree traversal that operates on each node * @brief A function used during tree traversal that operates on each node
@ -101,6 +106,7 @@ struct HybridAssignmentData {
HybridAssignmentData& parentData) { HybridAssignmentData& parentData) {
// Extract the gaussian conditional from the Hybrid clique // Extract the gaussian conditional from the Hybrid clique
HybridConditional::shared_ptr hybrid_conditional = node->conditional(); HybridConditional::shared_ptr hybrid_conditional = node->conditional();
GaussianConditional::shared_ptr conditional; GaussianConditional::shared_ptr conditional;
if (hybrid_conditional->isHybrid()) { if (hybrid_conditional->isHybrid()) {
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_); conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
@ -111,15 +117,21 @@ struct HybridAssignmentData {
conditional = boost::make_shared<GaussianConditional>(); conditional = boost::make_shared<GaussianConditional>();
} }
// Create the GaussianClique for the current node GaussianBayesTree::sharedNode clique;
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional); if (conditional) {
// Add the current clique to the GaussianBayesTree. // Create the GaussianClique for the current node
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_); clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
// Add the current clique to the GaussianBayesTree.
parentData.gaussianbayesTree_->addClique(clique,
parentData.parentClique_);
} else {
parentData.valid_ = false;
}
// Create new HybridAssignmentData where the current node is the parent // Create new HybridAssignmentData where the current node is the parent
// This will be passed down to the children nodes // This will be passed down to the children nodes
HybridAssignmentData data(parentData.assignment_, clique, HybridAssignmentData data(parentData.assignment_, clique,
parentData.gaussianbayesTree_); parentData.gaussianbayesTree_, parentData.valid_);
return data; return data;
} }
}; };
@ -138,6 +150,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
visitorPost); visitorPost);
} }
if (!rootData.isValid()) {
return VectorValues();
}
VectorValues result = gbt.optimize(); VectorValues result = gbt.optimize();
// Return the optimized bayes net result. // Return the optimized bayes net result.
@ -151,6 +166,8 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves); DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
decisionTree->root_ = prunedDecisionTree.root_; decisionTree->root_ = prunedDecisionTree.root_;
// this->print();
// decisionTree->print("", DefaultKeyFormatter);
/// Helper struct for pruning the hybrid bayes tree. /// Helper struct for pruning the hybrid bayes tree.
struct HybridPrunerData { struct HybridPrunerData {