fix HybridBayesTree::optimize to account for pruned nodes
parent
98d3186615
commit
7ae4e57d66
|
|
@ -73,6 +73,8 @@ struct HybridAssignmentData {
|
|||
GaussianBayesTree::sharedNode parentClique_;
|
||||
// The gaussian bayes tree that will be recursively created.
|
||||
GaussianBayesTree* gaussianbayesTree_;
|
||||
// Flag indicating if all the nodes are valid. Used in optimize().
|
||||
bool valid_;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Hybrid Assignment Data object.
|
||||
|
|
@ -83,10 +85,13 @@ struct HybridAssignmentData {
|
|||
*/
|
||||
HybridAssignmentData(const DiscreteValues& assignment,
|
||||
const GaussianBayesTree::sharedNode& parentClique,
|
||||
GaussianBayesTree* gbt)
|
||||
GaussianBayesTree* gbt, bool valid = true)
|
||||
: assignment_(assignment),
|
||||
parentClique_(parentClique),
|
||||
gaussianbayesTree_(gbt) {}
|
||||
gaussianbayesTree_(gbt),
|
||||
valid_(valid) {}
|
||||
|
||||
bool isValid() const { return valid_; }
|
||||
|
||||
/**
|
||||
* @brief A function used during tree traversal that operates on each node
|
||||
|
|
@ -101,6 +106,7 @@ struct HybridAssignmentData {
|
|||
HybridAssignmentData& parentData) {
|
||||
// Extract the gaussian conditional from the Hybrid clique
|
||||
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
|
||||
|
||||
GaussianConditional::shared_ptr conditional;
|
||||
if (hybrid_conditional->isHybrid()) {
|
||||
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
|
||||
|
|
@ -111,15 +117,21 @@ struct HybridAssignmentData {
|
|||
conditional = boost::make_shared<GaussianConditional>();
|
||||
}
|
||||
|
||||
// Create the GaussianClique for the current node
|
||||
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
|
||||
// Add the current clique to the GaussianBayesTree.
|
||||
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
|
||||
GaussianBayesTree::sharedNode clique;
|
||||
if (conditional) {
|
||||
// Create the GaussianClique for the current node
|
||||
clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
|
||||
// Add the current clique to the GaussianBayesTree.
|
||||
parentData.gaussianbayesTree_->addClique(clique,
|
||||
parentData.parentClique_);
|
||||
} else {
|
||||
parentData.valid_ = false;
|
||||
}
|
||||
|
||||
// Create new HybridAssignmentData where the current node is the parent
|
||||
// This will be passed down to the children nodes
|
||||
HybridAssignmentData data(parentData.assignment_, clique,
|
||||
parentData.gaussianbayesTree_);
|
||||
parentData.gaussianbayesTree_, parentData.valid_);
|
||||
return data;
|
||||
}
|
||||
};
|
||||
|
|
@ -138,6 +150,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
|||
visitorPost);
|
||||
}
|
||||
|
||||
if (!rootData.isValid()) {
|
||||
return VectorValues();
|
||||
}
|
||||
VectorValues result = gbt.optimize();
|
||||
|
||||
// Return the optimized bayes net result.
|
||||
|
|
@ -151,6 +166,8 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
|||
|
||||
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
||||
decisionTree->root_ = prunedDecisionTree.root_;
|
||||
// this->print();
|
||||
// decisionTree->print("", DefaultKeyFormatter);
|
||||
|
||||
/// Helper struct for pruning the hybrid bayes tree.
|
||||
struct HybridPrunerData {
|
||||
|
|
|
|||
Loading…
Reference in New Issue