nicer HybridBayesNet::optimize with normalized errors
							parent
							
								
									39f7ac20a1
								
							
						
					
					
						commit
						c374a26b45
					
				| 
						 | 
					@ -223,22 +223,38 @@ HybridValues HybridBayesNet::optimize() const {
 | 
				
			||||||
  DiscreteFactorGraph discrete_fg;
 | 
					  DiscreteFactorGraph discrete_fg;
 | 
				
			||||||
  VectorValues continuousValues;
 | 
					  VectorValues continuousValues;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Error values for each hybrid factor
 | 
				
			||||||
 | 
					  AlgebraicDecisionTree<Key> error(0.0);
 | 
				
			||||||
 | 
					  std::set<DiscreteKey> discreteKeySet;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  for (auto &&conditional : *this) {
 | 
					  for (auto &&conditional : *this) {
 | 
				
			||||||
    if (conditional->isDiscrete()) {
 | 
					    if (conditional->isDiscrete()) {
 | 
				
			||||||
      discrete_fg.push_back(conditional->asDiscrete());
 | 
					      discrete_fg.push_back(conditional->asDiscrete());
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      /*
 | 
					      /*
 | 
				
			||||||
      Perform the integration of L(X;M, Z)P(X|M) which is the model selection
 | 
					      Perform the integration of L(X;M,Z)P(X|M)
 | 
				
			||||||
      term.
 | 
					      which is the model selection term.
 | 
				
			||||||
      TODO(Varun) Write better comments detailing the whole process.
 | 
					
 | 
				
			||||||
 | 
					      By Bayes' rule, P(X|M) ∝ L(X;M,Z)P(X|M),
 | 
				
			||||||
 | 
					      hence L(X;M,Z)P(X|M) is the unnormalized probabilty of
 | 
				
			||||||
 | 
					      the joint Gaussian distribution.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      This can be computed by multiplying all the exponentiaed errors
 | 
				
			||||||
 | 
					      of each of the conditionals, which we do below in hybrid case.
 | 
				
			||||||
      */
 | 
					      */
 | 
				
			||||||
      if (conditional->isContinuous()) {
 | 
					      if (conditional->isContinuous()) {
 | 
				
			||||||
 | 
					        /*
 | 
				
			||||||
 | 
					        If we are here, it means there are no discrete variables in
 | 
				
			||||||
 | 
					        the Bayes net (due to strong elimination ordering).
 | 
				
			||||||
 | 
					        This is a continuous-only problem hence model selection doesn't matter.
 | 
				
			||||||
 | 
					        */
 | 
				
			||||||
        auto gc = conditional->asGaussian();
 | 
					        auto gc = conditional->asGaussian();
 | 
				
			||||||
        for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
 | 
					        for (GaussianConditional::const_iterator frontal = gc->beginFrontals();
 | 
				
			||||||
             frontal != gc->endFrontals(); ++frontal) {
 | 
					             frontal != gc->endFrontals(); ++frontal) {
 | 
				
			||||||
          continuousValues.insert_or_assign(*frontal,
 | 
					          continuousValues.insert_or_assign(*frontal,
 | 
				
			||||||
                                            Vector::Zero(gc->getDim(frontal)));
 | 
					                                            Vector::Zero(gc->getDim(frontal)));
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      } else if (conditional->isHybrid()) {
 | 
					      } else if (conditional->isHybrid()) {
 | 
				
			||||||
        auto gm = conditional->asMixture();
 | 
					        auto gm = conditional->asMixture();
 | 
				
			||||||
        gm->conditionals().apply(
 | 
					        gm->conditionals().apply(
 | 
				
			||||||
| 
						 | 
					@ -253,36 +269,47 @@ HybridValues HybridBayesNet::optimize() const {
 | 
				
			||||||
              return gc;
 | 
					              return gc;
 | 
				
			||||||
            });
 | 
					            });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        DecisionTree<Key, double> error = gm->error(continuousValues);
 | 
					        /*
 | 
				
			||||||
 | 
					        To perform model selection, we need:
 | 
				
			||||||
 | 
					          q(mu; M, Z) * sqrt((2*pi)^n*det(Sigma))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Functional to take error and compute the probability
 | 
					        If q(mu; M, Z) = exp(-error) & k = 1.0 / sqrt((2*pi)^n*det(Sigma))
 | 
				
			||||||
        auto integrate = [&gm](const double &error) {
 | 
					        thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q/k))
 | 
				
			||||||
          // q(mu; M, Z) = exp(-error)
 | 
					        = exp(log(q) - log(k)) = exp(-error - log(k))
 | 
				
			||||||
          // k = 1.0 / sqrt((2*pi)^n*det(Sigma))
 | 
					        = exp(-(error + log(k)))
 | 
				
			||||||
          // thus, q*sqrt(|2*pi*Sigma|) = q/k = exp(log(q) - log(k))
 | 
					
 | 
				
			||||||
          // = exp(-error - log(k))
 | 
					        So let's compute (error + log(k)) and exponentiate later
 | 
				
			||||||
          double prob = std::exp(-error - gm->logNormalizationConstant());
 | 
					        */
 | 
				
			||||||
          if (prob > 1e-12) {
 | 
					        error = error + gm->error(continuousValues);
 | 
				
			||||||
            return prob;
 | 
					
 | 
				
			||||||
          } else {
 | 
					        // Add the logNormalization constant to the error
 | 
				
			||||||
            return 1.0;
 | 
					        // Also compute the mean for normalization (for numerical stability)
 | 
				
			||||||
          }
 | 
					        double mean = 0.0;
 | 
				
			||||||
 | 
					        auto addConstant = [&gm, &mean](const double &error) {
 | 
				
			||||||
 | 
					          double e = error + gm->logNormalizationConstant();
 | 
				
			||||||
 | 
					          mean += e;
 | 
				
			||||||
 | 
					          return e;
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
        AlgebraicDecisionTree<Key> model_selection =
 | 
					        error = error.apply(addConstant);
 | 
				
			||||||
            DecisionTree<Key, double>(error, integrate);
 | 
					        // Normalize by the mean
 | 
				
			||||||
 | 
					        error = error.apply([&mean](double x) { return x / mean; });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        std::cout << "\n\nmodel selection";
 | 
					        // Include the discrete keys
 | 
				
			||||||
        model_selection.print("", DefaultKeyFormatter);
 | 
					        std::copy(gm->discreteKeys().begin(), gm->discreteKeys().end(),
 | 
				
			||||||
        discrete_fg.push_back(
 | 
					                  std::inserter(discreteKeySet, discreteKeySet.end()));
 | 
				
			||||||
            DecisionTreeFactor(gm->discreteKeys(), model_selection));
 | 
					 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  AlgebraicDecisionTree<Key> model_selection = DecisionTree<Key, double>(
 | 
				
			||||||
 | 
					      error, [](const double &error) { return std::exp(-error); });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  discrete_fg.push_back(DecisionTreeFactor(
 | 
				
			||||||
 | 
					      DiscreteKeys(discreteKeySet.begin(), discreteKeySet.end()),
 | 
				
			||||||
 | 
					      model_selection));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Solve for the MPE
 | 
					  // Solve for the MPE
 | 
				
			||||||
  discrete_fg.print();
 | 
					 | 
				
			||||||
  DiscreteValues mpe = discrete_fg.optimize();
 | 
					  DiscreteValues mpe = discrete_fg.optimize();
 | 
				
			||||||
  mpe.print("mpe");
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Given the MPE, compute the optimal continuous values.
 | 
					  // Given the MPE, compute the optimal continuous values.
 | 
				
			||||||
  return HybridValues(optimize(mpe), mpe);
 | 
					  return HybridValues(optimize(mpe), mpe);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue