Merge branch 'hybrid-custom-discrete' into discrete-table-conditional
						commit
						782f39a0e2
					
				|  | @ -47,6 +47,15 @@ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, | |||
|                                          const DecisionTreeFactor& f) | ||||
|     : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| DiscreteConditional::DiscreteConditional(size_t nrFrontals, | ||||
|                                          const DecisionTreeFactor& f, | ||||
|                                          const Ordering& orderedKeys) | ||||
|     : BaseFactor(f), BaseConditional(nrFrontals) { | ||||
|   keys_.clear(); | ||||
|   keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************** */ | ||||
| DiscreteConditional::DiscreteConditional(size_t nrFrontals, | ||||
|                                          const DiscreteKeys& keys, | ||||
|  |  | |||
|  | @ -56,6 +56,17 @@ class GTSAM_EXPORT DiscreteConditional | |||
|   /// Construct from factor, taking the first `nFrontals` keys as frontals.
 | ||||
|   DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Construct from DecisionTreeFactor, | ||||
|    * taking the first `nrFrontals` from `orderedKeys`. | ||||
|    * | ||||
|    * @param nrFrontals The number of frontal variables. | ||||
|    * @param f The DecisionTreeFactor to construct from. | ||||
|    * @param orderedKeys Ordered list of keys involved in the conditional. | ||||
|    */ | ||||
|   DiscreteConditional(size_t nrFrontals, const DecisionTreeFactor& f, | ||||
|                       const Ordering& orderedKeys); | ||||
| 
 | ||||
|   /**
 | ||||
|    * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first | ||||
|    * `nFrontals` keys as frontals, in the order given. | ||||
|  |  | |||
|  | @ -252,6 +252,15 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { | |||
| DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { | ||||
|   DiscreteKeys dkeys = discreteKeys(); | ||||
| 
 | ||||
|   // If no keys, then return empty DecisionTreeFactor
 | ||||
|   if (dkeys.size() == 0) { | ||||
|     AlgebraicDecisionTree<Key> tree; | ||||
|     if (sparse_table_.size() != 0) { | ||||
|       tree = AlgebraicDecisionTree<Key>(sparse_table_.coeff(0)); | ||||
|     } | ||||
|     return DecisionTreeFactor(dkeys, tree); | ||||
|   } | ||||
| 
 | ||||
|   std::vector<double> table; | ||||
|   for (auto i = 0; i < sparse_table_.size(); i++) { | ||||
|     table.push_back(sparse_table_.coeff(i)); | ||||
|  |  | |||
|  | @ -360,12 +360,16 @@ discreteElimination(const HybridGaussianFactorGraph &factors, | |||
|   // All the discrete variables should form a single clique,
 | ||||
|   // so we can sum out on all the variables as frontals.
 | ||||
|   // This should give an empty separator.
 | ||||
|   Ordering orderedKeys(product.keys()); | ||||
|   TableFactor::shared_ptr sum = product.sum(orderedKeys); | ||||
|   TableFactor::shared_ptr sum = product.sum(frontalKeys); | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttoc_(EliminateDiscreteSum); | ||||
| #endif | ||||
| 
 | ||||
|   // Ordering keys for the conditional so that frontalKeys are really in front
 | ||||
|   Ordering orderedKeys; | ||||
|   orderedKeys.insert(orderedKeys.end(), frontalKeys.begin(), frontalKeys.end()); | ||||
|   orderedKeys.insert(orderedKeys.end(), sum->keys().begin(), sum->keys().end()); | ||||
| 
 | ||||
| #if GTSAM_HYBRID_TIMING | ||||
|   gttic_(EliminateDiscreteFormDiscreteConditional); | ||||
| #endif | ||||
|  |  | |||
|  | @ -169,6 +169,7 @@ TEST(GaussianMixture, GaussianMixtureModel2) { | |||
|     EXPECT_DOUBLES_EQUAL(expected, posterior2(m1Assignment), 1e-8); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue