add model_selection method to HybridBayesNet

release/4.3a0
Varun Agrawal 2024-01-03 16:32:21 -05:00
parent 82e0c0dae1
commit 8a61c49bb3
2 changed files with 37 additions and 26 deletions

View File

@ -283,35 +283,20 @@ GaussianBayesNetValTree HybridBayesNet::assembleTree() const {
} }
/* ************************************************************************* */ /* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const { AlgebraicDecisionTree<Key> HybridBayesNet::model_selection() const {
// Collect all the discrete factors to compute MPE
DiscreteFactorGraph discrete_fg;
/* /*
Perform the integration of L(X;M,Z)P(X|M) To perform model selection, we need:
which is the model selection term. q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
By Bayes' rule, P(X|M,Z) L(X;M,Z)P(X|M), If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
the joint Gaussian distribution. = exp(log(q) - log(k)) = exp(-error - log(k))
= exp(-(error + log(k))),
where error is computed at the corresponding MAP point, gbn.error(mu).
This can be computed by multiplying all the exponentiated errors So we compute (error + log(k)) and exponentiate later
of each of the conditionals, which we do below in hybrid case. */
*/
/*
To perform model selection, we need:
q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
thus, q * sqrt((2*pi)^n*det(Sigma)) = q/k = exp(log(q/k))
= exp(log(q) - log(k)) = exp(-error - log(k))
= exp(-(error + log(k))),
where error is computed at the corresponding MAP point, gbn.error(mu).
So we compute (error + log(k)) and exponentiate later
*/
std::set<DiscreteKey> discreteKeySet;
GaussianBayesNetValTree bnTree = assembleTree(); GaussianBayesNetValTree bnTree = assembleTree();
GaussianBayesNetValTree bn_error = bnTree.apply( GaussianBayesNetValTree bn_error = bnTree.apply(
@ -356,6 +341,19 @@ HybridValues HybridBayesNet::optimize() const {
[&max_log](const double &x) { return std::exp(x - max_log); }); [&max_log](const double &x) { return std::exp(x - max_log); });
model_selection = model_selection.normalize(model_selection.sum()); model_selection = model_selection.normalize(model_selection.sum());
return model_selection;
}
/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Collect all the discrete factors to compute MPE
DiscreteFactorGraph discrete_fg;
// Compute model selection term
AlgebraicDecisionTree<Key> model_selection_term = model_selection();
// Get the set of all discrete keys involved in model selection
std::set<DiscreteKey> discreteKeySet;
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
if (conditional->isDiscrete()) { if (conditional->isDiscrete()) {
discrete_fg.push_back(conditional->asDiscrete()); discrete_fg.push_back(conditional->asDiscrete());
@ -380,7 +378,7 @@ HybridValues HybridBayesNet::optimize() const {
if (discreteKeySet.size() > 0) { if (discreteKeySet.size() > 0) {
discrete_fg.push_back(DecisionTreeFactor( discrete_fg.push_back(DecisionTreeFactor(
DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()), DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
model_selection)); model_selection_term));
} }
// Solve for the MPE // Solve for the MPE

View File

@ -120,6 +120,19 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
GaussianBayesNetValTree assembleTree() const; GaussianBayesNetValTree assembleTree() const;
/*
Perform the integration of L(X;M,Z)P(X|M)
which is the model selection term.
By Bayes' rule, P(X|M,Z) L(X;M,Z)P(X|M),
hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
the joint Gaussian distribution.
This can be computed by multiplying all the exponentiated errors
of each of the conditionals.
*/
AlgebraicDecisionTree<Key> model_selection() const;
/** /**
* @brief Solve the HybridBayesNet by first computing the MPE of all the * @brief Solve the HybridBayesNet by first computing the MPE of all the
* discrete variables and then optimizing the continuous variables based on * discrete variables and then optimizing the continuous variables based on