cleaner model selection computation
							parent
							
								
									1e298be3b3
								
							
						
					
					
						commit
						b4f07a0162
					
				| 
						 | 
					@ -342,9 +342,11 @@ HybridValues HybridBayesNet::optimize() const {
 | 
				
			||||||
    for (auto &&f : *this) {
 | 
					    for (auto &&f : *this) {
 | 
				
			||||||
      if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
 | 
					      if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
 | 
				
			||||||
        error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
 | 
					        error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
 | 
					      } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
 | 
				
			||||||
        if (auto gm = hc->asMixture()) {
 | 
					        if (auto gm = hc->asMixture()) {
 | 
				
			||||||
          error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
 | 
					          error += gm->error(HybridValues(mu, DiscreteValues(assignment)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        } else if (auto g = hc->asGaussian()) {
 | 
					        } else if (auto g = hc->asGaussian()) {
 | 
				
			||||||
          error += g->error(mu);
 | 
					          error += g->error(mu);
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
| 
						 | 
					@ -356,11 +358,9 @@ HybridValues HybridBayesNet::optimize() const {
 | 
				
			||||||
  AlgebraicDecisionTree<Key> errorTree =
 | 
					  AlgebraicDecisionTree<Key> errorTree =
 | 
				
			||||||
      DecisionTree<Key, double>(labels, errors);
 | 
					      DecisionTree<Key, double>(labels, errors);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Compute model selection term
 | 
					  // Compute model selection term (with help from ADT methods)
 | 
				
			||||||
  AlgebraicDecisionTree<Key> model_selection_term = errorTree.apply(
 | 
					  AlgebraicDecisionTree<Key> model_selection_term =
 | 
				
			||||||
      [&log_norm_constants](const Assignment<Key> assignment, double err) {
 | 
					      (errorTree + log_norm_constants) * -1;
 | 
				
			||||||
        return -(err + log_norm_constants(assignment));
 | 
					 | 
				
			||||||
      });
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // std::cout << "model selection term" << std::endl;
 | 
					  // std::cout << "model selection term" << std::endl;
 | 
				
			||||||
  // model_selection_term.print("", DefaultKeyFormatter);
 | 
					  // model_selection_term.print("", DefaultKeyFormatter);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue