fix HybridBayesTree::optimize to account for pruned nodes
parent
98d3186615
commit
7ae4e57d66
|
|
@ -73,6 +73,8 @@ struct HybridAssignmentData {
|
||||||
GaussianBayesTree::sharedNode parentClique_;
|
GaussianBayesTree::sharedNode parentClique_;
|
||||||
// The gaussian bayes tree that will be recursively created.
|
// The gaussian bayes tree that will be recursively created.
|
||||||
GaussianBayesTree* gaussianbayesTree_;
|
GaussianBayesTree* gaussianbayesTree_;
|
||||||
|
// Flag indicating if all the nodes are valid. Used in optimize().
|
||||||
|
bool valid_;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Construct a new Hybrid Assignment Data object.
|
* @brief Construct a new Hybrid Assignment Data object.
|
||||||
|
|
@ -83,10 +85,13 @@ struct HybridAssignmentData {
|
||||||
*/
|
*/
|
||||||
HybridAssignmentData(const DiscreteValues& assignment,
|
HybridAssignmentData(const DiscreteValues& assignment,
|
||||||
const GaussianBayesTree::sharedNode& parentClique,
|
const GaussianBayesTree::sharedNode& parentClique,
|
||||||
GaussianBayesTree* gbt)
|
GaussianBayesTree* gbt, bool valid = true)
|
||||||
: assignment_(assignment),
|
: assignment_(assignment),
|
||||||
parentClique_(parentClique),
|
parentClique_(parentClique),
|
||||||
gaussianbayesTree_(gbt) {}
|
gaussianbayesTree_(gbt),
|
||||||
|
valid_(valid) {}
|
||||||
|
|
||||||
|
bool isValid() const { return valid_; }
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A function used during tree traversal that operates on each node
|
* @brief A function used during tree traversal that operates on each node
|
||||||
|
|
@ -101,6 +106,7 @@ struct HybridAssignmentData {
|
||||||
HybridAssignmentData& parentData) {
|
HybridAssignmentData& parentData) {
|
||||||
// Extract the gaussian conditional from the Hybrid clique
|
// Extract the gaussian conditional from the Hybrid clique
|
||||||
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
|
HybridConditional::shared_ptr hybrid_conditional = node->conditional();
|
||||||
|
|
||||||
GaussianConditional::shared_ptr conditional;
|
GaussianConditional::shared_ptr conditional;
|
||||||
if (hybrid_conditional->isHybrid()) {
|
if (hybrid_conditional->isHybrid()) {
|
||||||
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
|
conditional = (*hybrid_conditional->asMixture())(parentData.assignment_);
|
||||||
|
|
@ -111,15 +117,21 @@ struct HybridAssignmentData {
|
||||||
conditional = boost::make_shared<GaussianConditional>();
|
conditional = boost::make_shared<GaussianConditional>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the GaussianClique for the current node
|
GaussianBayesTree::sharedNode clique;
|
||||||
auto clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
|
if (conditional) {
|
||||||
// Add the current clique to the GaussianBayesTree.
|
// Create the GaussianClique for the current node
|
||||||
parentData.gaussianbayesTree_->addClique(clique, parentData.parentClique_);
|
clique = boost::make_shared<GaussianBayesTree::Node>(conditional);
|
||||||
|
// Add the current clique to the GaussianBayesTree.
|
||||||
|
parentData.gaussianbayesTree_->addClique(clique,
|
||||||
|
parentData.parentClique_);
|
||||||
|
} else {
|
||||||
|
parentData.valid_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
// Create new HybridAssignmentData where the current node is the parent
|
// Create new HybridAssignmentData where the current node is the parent
|
||||||
// This will be passed down to the children nodes
|
// This will be passed down to the children nodes
|
||||||
HybridAssignmentData data(parentData.assignment_, clique,
|
HybridAssignmentData data(parentData.assignment_, clique,
|
||||||
parentData.gaussianbayesTree_);
|
parentData.gaussianbayesTree_, parentData.valid_);
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -138,6 +150,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
|
||||||
visitorPost);
|
visitorPost);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!rootData.isValid()) {
|
||||||
|
return VectorValues();
|
||||||
|
}
|
||||||
VectorValues result = gbt.optimize();
|
VectorValues result = gbt.optimize();
|
||||||
|
|
||||||
// Return the optimized bayes net result.
|
// Return the optimized bayes net result.
|
||||||
|
|
@ -151,6 +166,8 @@ void HybridBayesTree::prune(const size_t maxNrLeaves) {
|
||||||
|
|
||||||
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
DecisionTreeFactor prunedDecisionTree = decisionTree->prune(maxNrLeaves);
|
||||||
decisionTree->root_ = prunedDecisionTree.root_;
|
decisionTree->root_ = prunedDecisionTree.root_;
|
||||||
|
// this->print();
|
||||||
|
// decisionTree->print("", DefaultKeyFormatter);
|
||||||
|
|
||||||
/// Helper struct for pruning the hybrid bayes tree.
|
/// Helper struct for pruning the hybrid bayes tree.
|
||||||
struct HybridPrunerData {
|
struct HybridPrunerData {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue