Merge pull request #1713 from borglab/model-selection-bayestree
						commit
						cdcc64407e
					
				| 
						 | 
					@ -160,8 +160,8 @@ void GaussianMixture::print(const std::string &s,
 | 
				
			||||||
  for (auto &dk : discreteKeys()) {
 | 
					  for (auto &dk : discreteKeys()) {
 | 
				
			||||||
    std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
 | 
					    std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  std::cout << "\n";
 | 
					  std::cout << std::endl
 | 
				
			||||||
  std::cout << " logNormalizationConstant: " << logConstant_ << "\n"
 | 
					            << " logNormalizationConstant: " << logConstant_ << std::endl
 | 
				
			||||||
            << std::endl;
 | 
					            << std::endl;
 | 
				
			||||||
  conditionals_.print(
 | 
					  conditionals_.print(
 | 
				
			||||||
      "", [&](Key k) { return formatter(k); },
 | 
					      "", [&](Key k) { return formatter(k); },
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -40,17 +40,17 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ************************************************************************* */
 | 
					/* ************************************************************************* */
 | 
				
			||||||
HybridValues HybridBayesTree::optimize() const {
 | 
					HybridValues HybridBayesTree::optimize() const {
 | 
				
			||||||
  DiscreteBayesNet dbn;
 | 
					  DiscreteFactorGraph discrete_fg;
 | 
				
			||||||
  DiscreteValues mpe;
 | 
					  DiscreteValues mpe;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  auto root = roots_.at(0);
 | 
					  auto root = roots_.at(0);
 | 
				
			||||||
  // Access the clique and get the underlying hybrid conditional
 | 
					  // Access the clique and get the underlying hybrid conditional
 | 
				
			||||||
  HybridConditional::shared_ptr root_conditional = root->conditional();
 | 
					  HybridConditional::shared_ptr root_conditional = root->conditional();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // The root should be discrete only, we compute the MPE
 | 
					  //  The root should be discrete only, we compute the MPE
 | 
				
			||||||
  if (root_conditional->isDiscrete()) {
 | 
					  if (root_conditional->isDiscrete()) {
 | 
				
			||||||
    dbn.push_back(root_conditional->asDiscrete());
 | 
					    discrete_fg.push_back(root_conditional->asDiscrete());
 | 
				
			||||||
    mpe = DiscreteFactorGraph(dbn).optimize();
 | 
					    mpe = discrete_fg.optimize();
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    throw std::runtime_error(
 | 
					    throw std::runtime_error(
 | 
				
			||||||
        "HybridBayesTree root is not discrete-only. Please check elimination "
 | 
					        "HybridBayesTree root is not discrete-only. Please check elimination "
 | 
				
			||||||
| 
						 | 
					@ -136,8 +136,7 @@ struct HybridAssignmentData {
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* *************************************************************************
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 */
 | 
					 | 
				
			||||||
GaussianBayesTree HybridBayesTree::choose(
 | 
					GaussianBayesTree HybridBayesTree::choose(
 | 
				
			||||||
    const DiscreteValues& assignment) const {
 | 
					    const DiscreteValues& assignment) const {
 | 
				
			||||||
  GaussianBayesTree gbt;
 | 
					  GaussianBayesTree gbt;
 | 
				
			||||||
| 
						 | 
					@ -157,8 +156,12 @@ GaussianBayesTree HybridBayesTree::choose(
 | 
				
			||||||
  return gbt;
 | 
					  return gbt;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* *************************************************************************
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 */
 | 
					double HybridBayesTree::error(const HybridValues& values) const {
 | 
				
			||||||
 | 
					  return HybridGaussianFactorGraph(*this).error(values);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
 | 
					VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
 | 
				
			||||||
  GaussianBayesTree gbt = this->choose(assignment);
 | 
					  GaussianBayesTree gbt = this->choose(assignment);
 | 
				
			||||||
  // If empty GaussianBayesTree, means a clique is pruned hence invalid
 | 
					  // If empty GaussianBayesTree, means a clique is pruned hence invalid
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -84,6 +84,9 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
 | 
				
			||||||
   */
 | 
					   */
 | 
				
			||||||
  GaussianBayesTree choose(const DiscreteValues& assignment) const;
 | 
					  GaussianBayesTree choose(const DiscreteValues& assignment) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Error for all conditionals. */
 | 
				
			||||||
 | 
					  double error(const HybridValues& values) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * @brief Optimize the hybrid Bayes tree by computing the MPE for the current
 | 
					   * @brief Optimize the hybrid Bayes tree by computing the MPE for the current
 | 
				
			||||||
   * set of discrete variables and using it to compute the best continuous
 | 
					   * set of discrete variables and using it to compute the best continuous
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -59,6 +59,10 @@ class GTSAM_EXPORT HybridFactorGraph : public FactorGraph<Factor> {
 | 
				
			||||||
  template <class DERIVEDFACTOR>
 | 
					  template <class DERIVEDFACTOR>
 | 
				
			||||||
  HybridFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
 | 
					  HybridFactorGraph(const FactorGraph<DERIVEDFACTOR>& graph) : Base(graph) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Construct from container of factors (shared_ptr or plain objects) */
 | 
				
			||||||
 | 
					  template <class CONTAINER>
 | 
				
			||||||
 | 
					  explicit HybridFactorGraph(const CONTAINER& factors) : Base(factors) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// @}
 | 
					  /// @}
 | 
				
			||||||
  /// @name Extra methods to inspect discrete/continuous keys.
 | 
					  /// @name Extra methods to inspect discrete/continuous keys.
 | 
				
			||||||
  /// @{
 | 
					  /// @{
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -353,6 +353,8 @@ static std::shared_ptr<Factor> createGaussianMixtureFactor(
 | 
				
			||||||
    if (factor) {
 | 
					    if (factor) {
 | 
				
			||||||
      auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
 | 
					      auto hf = std::dynamic_pointer_cast<HessianFactor>(factor);
 | 
				
			||||||
      if (!hf) throw std::runtime_error("Expected HessianFactor!");
 | 
					      if (!hf) throw std::runtime_error("Expected HessianFactor!");
 | 
				
			||||||
 | 
					      // Add 2.0 term since the constant term will be premultiplied by 0.5
 | 
				
			||||||
 | 
					      // as per the Hessian definition
 | 
				
			||||||
      hf->constantTerm() += 2.0 * conditional->logNormalizationConstant();
 | 
					      hf->constantTerm() += 2.0 * conditional->logNormalizationConstant();
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    return factor;
 | 
					    return factor;
 | 
				
			||||||
| 
						 | 
					@ -586,4 +588,24 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::probPrime(
 | 
				
			||||||
  return prob_tree;
 | 
					  return prob_tree;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					GaussianFactorGraph HybridGaussianFactorGraph::operator()(
 | 
				
			||||||
 | 
					    const DiscreteValues &assignment) const {
 | 
				
			||||||
 | 
					  GaussianFactorGraph gfg;
 | 
				
			||||||
 | 
					  for (auto &&f : *this) {
 | 
				
			||||||
 | 
					    if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(f)) {
 | 
				
			||||||
 | 
					      gfg.push_back(gf);
 | 
				
			||||||
 | 
					    } else if (auto gc = std::dynamic_pointer_cast<GaussianConditional>(f)) {
 | 
				
			||||||
 | 
					      gfg.push_back(gf);
 | 
				
			||||||
 | 
					    } else if (auto gmf = std::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
 | 
				
			||||||
 | 
					      gfg.push_back((*gmf)(assignment));
 | 
				
			||||||
 | 
					    } else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
 | 
				
			||||||
 | 
					      gfg.push_back((*gm)(assignment));
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      continue;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return gfg;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace gtsam
 | 
					}  // namespace gtsam
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -126,6 +126,11 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
 | 
				
			||||||
  /// @brief Default constructor.
 | 
					  /// @brief Default constructor.
 | 
				
			||||||
  HybridGaussianFactorGraph() = default;
 | 
					  HybridGaussianFactorGraph() = default;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Construct from container of factors (shared_ptr or plain objects) */
 | 
				
			||||||
 | 
					  template <class CONTAINER>
 | 
				
			||||||
 | 
					  explicit HybridGaussianFactorGraph(const CONTAINER& factors)
 | 
				
			||||||
 | 
					      : Base(factors) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /**
 | 
					  /**
 | 
				
			||||||
   * Implicit copy/downcast constructor to override explicit template container
 | 
					   * Implicit copy/downcast constructor to override explicit template container
 | 
				
			||||||
   * constructor. In BayesTree this is used for:
 | 
					   * constructor. In BayesTree this is used for:
 | 
				
			||||||
| 
						 | 
					@ -213,6 +218,10 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
 | 
				
			||||||
  GaussianFactorGraphTree assembleGraphTree() const;
 | 
					  GaussianFactorGraphTree assembleGraphTree() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /// @}
 | 
					  /// @}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Get the GaussianFactorGraph at a given discrete assignment.
 | 
				
			||||||
 | 
					  GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}  // namespace gtsam
 | 
					}  // namespace gtsam
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -92,7 +92,10 @@ class GaussianMixture : gtsam::HybridFactor {
 | 
				
			||||||
                  const std::vector<gtsam::GaussianConditional::shared_ptr>&
 | 
					                  const std::vector<gtsam::GaussianConditional::shared_ptr>&
 | 
				
			||||||
                      conditionalsList);
 | 
					                      conditionalsList);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  gtsam::GaussianMixtureFactor* likelihood(const gtsam::VectorValues &frontals) const;
 | 
					  gtsam::GaussianMixtureFactor* likelihood(
 | 
				
			||||||
 | 
					      const gtsam::VectorValues& frontals) const;
 | 
				
			||||||
 | 
					  double logProbability(const gtsam::HybridValues& values) const;
 | 
				
			||||||
 | 
					  double evaluate(const gtsam::HybridValues& values) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  void print(string s = "GaussianMixture\n",
 | 
					  void print(string s = "GaussianMixture\n",
 | 
				
			||||||
             const gtsam::KeyFormatter& keyFormatter =
 | 
					             const gtsam::KeyFormatter& keyFormatter =
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -490,6 +490,58 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ****************************************************************************/
 | 
				
			||||||
 | 
					// Select a particular continuous factor graph given a discrete assignment
 | 
				
			||||||
 | 
					TEST(HybridGaussianFactorGraph, DiscreteSelection) {
 | 
				
			||||||
 | 
					  Switching s(3);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  HybridGaussianFactorGraph graph = s.linearizedFactorGraph;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  DiscreteValues dv00{{M(0), 0}, {M(1), 0}};
 | 
				
			||||||
 | 
					  GaussianFactorGraph continuous_00 = graph(dv00);
 | 
				
			||||||
 | 
					  GaussianFactorGraph expected_00;
 | 
				
			||||||
 | 
					  expected_00.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					  expected_00.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-1)));
 | 
				
			||||||
 | 
					  expected_00.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-1)));
 | 
				
			||||||
 | 
					  expected_00.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					  expected_00.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expected_00, continuous_00));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  DiscreteValues dv01{{M(0), 0}, {M(1), 1}};
 | 
				
			||||||
 | 
					  GaussianFactorGraph continuous_01 = graph(dv01);
 | 
				
			||||||
 | 
					  GaussianFactorGraph expected_01;
 | 
				
			||||||
 | 
					  expected_01.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					  expected_01.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-1)));
 | 
				
			||||||
 | 
					  expected_01.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-0)));
 | 
				
			||||||
 | 
					  expected_01.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					  expected_01.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expected_01, continuous_01));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  DiscreteValues dv10{{M(0), 1}, {M(1), 0}};
 | 
				
			||||||
 | 
					  GaussianFactorGraph continuous_10 = graph(dv10);
 | 
				
			||||||
 | 
					  GaussianFactorGraph expected_10;
 | 
				
			||||||
 | 
					  expected_10.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					  expected_10.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-0)));
 | 
				
			||||||
 | 
					  expected_10.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-1)));
 | 
				
			||||||
 | 
					  expected_10.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					  expected_10.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expected_10, continuous_10));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  DiscreteValues dv11{{M(0), 1}, {M(1), 1}};
 | 
				
			||||||
 | 
					  GaussianFactorGraph continuous_11 = graph(dv11);
 | 
				
			||||||
 | 
					  GaussianFactorGraph expected_11;
 | 
				
			||||||
 | 
					  expected_11.push_back(JacobianFactor(X(0), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					  expected_11.push_back(JacobianFactor(X(0), -I_1x1, X(1), I_1x1, Vector1(-0)));
 | 
				
			||||||
 | 
					  expected_11.push_back(JacobianFactor(X(1), -I_1x1, X(2), I_1x1, Vector1(-0)));
 | 
				
			||||||
 | 
					  expected_11.push_back(JacobianFactor(X(1), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					  expected_11.push_back(JacobianFactor(X(2), I_1x1 * 10, Vector1(-10)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expected_11, continuous_11));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
/* ************************************************************************* */
 | 
					/* ************************************************************************* */
 | 
				
			||||||
TEST(HybridGaussianFactorGraph, optimize) {
 | 
					TEST(HybridGaussianFactorGraph, optimize) {
 | 
				
			||||||
  HybridGaussianFactorGraph hfg;
 | 
					  HybridGaussianFactorGraph hfg;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -121,7 +121,7 @@ namespace gtsam {
 | 
				
			||||||
      const auto mean = solve({});  // solve for mean.
 | 
					      const auto mean = solve({});  // solve for mean.
 | 
				
			||||||
      mean.print("  mean", formatter);
 | 
					      mean.print("  mean", formatter);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    cout << "  logNormalizationConstant: " << logNormalizationConstant() << std::endl;
 | 
					    cout << "  logNormalizationConstant: " << logNormalizationConstant() << endl;
 | 
				
			||||||
    if (model_)
 | 
					    if (model_)
 | 
				
			||||||
      model_->print("  Noise model: ");
 | 
					      model_->print("  Noise model: ");
 | 
				
			||||||
    else
 | 
					    else
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -511,7 +511,7 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
 | 
				
			||||||
  GaussianConditional(size_t key, gtsam::Vector d, gtsam::Matrix R, size_t name1, gtsam::Matrix S,
 | 
					  GaussianConditional(size_t key, gtsam::Vector d, gtsam::Matrix R, size_t name1, gtsam::Matrix S,
 | 
				
			||||||
                      size_t name2, gtsam::Matrix T,
 | 
					                      size_t name2, gtsam::Matrix T,
 | 
				
			||||||
                      const gtsam::noiseModel::Diagonal* sigmas);
 | 
					                      const gtsam::noiseModel::Diagonal* sigmas);
 | 
				
			||||||
  GaussianConditional(const vector<std::pair<gtsam::Key, gtsam::Matrix>> terms,
 | 
					  GaussianConditional(const std::vector<std::pair<gtsam::Key, gtsam::Matrix>> terms,
 | 
				
			||||||
                      size_t nrFrontals, gtsam::Vector d,
 | 
					                      size_t nrFrontals, gtsam::Vector d,
 | 
				
			||||||
                      const gtsam::noiseModel::Diagonal* sigmas);
 | 
					                      const gtsam::noiseModel::Diagonal* sigmas);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -773,4 +773,4 @@ class KalmanFilter {
 | 
				
			||||||
      gtsam::Vector z, gtsam::Matrix Q);
 | 
					      gtsam::Vector z, gtsam::Matrix Q);
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -63,6 +63,6 @@ A RegularJacobianFactor that uses some badly documented reduction on the Jacobia
 | 
				
			||||||
 | 
					
 | 
				
			||||||
A RegularJacobianFactor that eliminates a point using sequential elimination.
 | 
					A RegularJacobianFactor that eliminates a point using sequential elimination.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### JacobianFactorQR
 | 
					### JacobianFactorSVD
 | 
				
			||||||
 | 
					
 | 
				
			||||||
A RegularJacobianFactor that uses the "Nullspace Trick" by Mourikis et al. See the documentation in the file, which *is* well documented.
 | 
					A RegularJacobianFactor that uses the "Nullspace Trick" by Mourikis et al. See the documentation in the file, which *is* well documented.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue