normalize model selection term
							parent
							
								
									7b56c96b43
								
							
						
					
					
						commit
						e549a9b41f
					
				|  | @ -239,7 +239,7 @@ HybridValues HybridBayesNet::optimize() const { | ||||||
|       hence L(X;M,Z)P(X|M) is the unnormalized probabilty of |       hence L(X;M,Z)P(X|M) is the unnormalized probabilty of | ||||||
|       the joint Gaussian distribution. |       the joint Gaussian distribution. | ||||||
| 
 | 
 | ||||||
|       This can be computed by multiplying all the exponentiaed errors |       This can be computed by multiplying all the exponentiated errors | ||||||
|       of each of the conditionals, which we do below in hybrid case. |       of each of the conditionals, which we do below in hybrid case. | ||||||
|       */ |       */ | ||||||
|       if (conditional->isContinuous()) { |       if (conditional->isContinuous()) { | ||||||
|  | @ -288,7 +288,7 @@ HybridValues HybridBayesNet::optimize() const { | ||||||
|         double sum = 0.0; |         double sum = 0.0; | ||||||
|         auto addConstant = [&gm, &sum](const double &error) { |         auto addConstant = [&gm, &sum](const double &error) { | ||||||
|           double e = error + gm->logNormalizationConstant(); |           double e = error + gm->logNormalizationConstant(); | ||||||
|           sum += e; |           sum += std::abs(e); | ||||||
|           return e; |           return e; | ||||||
|         }; |         }; | ||||||
|         error = error.apply(addConstant); |         error = error.apply(addConstant); | ||||||
|  | @ -302,12 +302,17 @@ HybridValues HybridBayesNet::optimize() const { | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   double min_log = error.min(); | ||||||
|   AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>( |   AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>( | ||||||
|       error, [](const double &error) { return std::exp(-error); }); |       error, [&min_log](const double &x) { return std::exp(-(x - min_log)); }); | ||||||
|  |   model_selection = model_selection + exp(-min_log); | ||||||
| 
 | 
 | ||||||
|  |   // Only add model_selection if we have discrete keys
 | ||||||
|  |   if (discreteKeySet.size() > 0) { | ||||||
|     discrete_fg.push_back(DecisionTreeFactor( |     discrete_fg.push_back(DecisionTreeFactor( | ||||||
|         DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()), |         DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()), | ||||||
|         model_selection)); |         model_selection)); | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   // Solve for the MPE
 |   // Solve for the MPE
 | ||||||
|   DiscreteValues mpe = discrete_fg.optimize(); |   DiscreteValues mpe = discrete_fg.optimize(); | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue