| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | /*
 | 
					
						
							|  |  |  |  * AllDiff.cpp | 
					
						
							|  |  |  |  * @brief General "all-different" constraint | 
					
						
							|  |  |  |  * @date Feb 6, 2012 | 
					
						
							|  |  |  |  * @author Frank Dellaert | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-18 23:17:49 +08:00
										 |  |  | #include <gtsam/base/Testable.h>
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | #include <gtsam_unstable/discrete/AllDiff.h>
 | 
					
						
							|  |  |  | #include <gtsam_unstable/discrete/Domain.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-13 05:43:29 +08:00
										 |  |  | #include <optional>
 | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace gtsam { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) { | 
					
						
							|  |  |  |   for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey); | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { | 
					
						
							|  |  |  |   std::cout << s << "AllDiff on "; | 
					
						
							|  |  |  |   for (Key dkey : keys_) std::cout << formatter(dkey) << " "; | 
					
						
							|  |  |  |   std::cout << std::endl; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | /* ************************************************************************* */ | 
					
						
							| 
									
										
										
										
											2021-12-14 02:46:53 +08:00
										 |  |  | double AllDiff::operator()(const DiscreteValues& values) const { | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   std::set<size_t> taken;  // record values taken by keys
 | 
					
						
							|  |  |  |   for (Key dkey : keys_) { | 
					
						
							|  |  |  |     size_t value = values.at(dkey);      // get the value for that key
 | 
					
						
							|  |  |  |     if (taken.count(value)) return 0.0;  // check if value alreday taken
 | 
					
						
							|  |  |  |     taken.insert(value);  // if not, record it as taken and keep checking
 | 
					
						
							| 
									
										
										
										
											2021-11-18 23:17:49 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   return 1.0; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2021-11-18 23:17:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | DecisionTreeFactor AllDiff::toDecisionTreeFactor() const { | 
					
						
							|  |  |  |   // We will do this by converting the allDif into many BinaryAllDiff
 | 
					
						
							|  |  |  |   // constraints
 | 
					
						
							|  |  |  |   DecisionTreeFactor converted; | 
					
						
							|  |  |  |   size_t nrKeys = keys_.size(); | 
					
						
							|  |  |  |   for (size_t i1 = 0; i1 < nrKeys; i1++) | 
					
						
							|  |  |  |     for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { | 
					
						
							|  |  |  |       BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2)); | 
					
						
							|  |  |  |       converted = converted * binary12.toDecisionTreeFactor(); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   return converted; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { | 
					
						
							|  |  |  |   // TODO: can we do this more efficiently?
 | 
					
						
							|  |  |  |   return toDecisionTreeFactor() * f; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | /* ************************************************************************* */ | 
					
						
							| 
									
										
										
										
											2021-11-21 04:52:12 +08:00
										 |  |  | bool AllDiff::ensureArcConsistency(Key j, Domains* domains) const { | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   Domain& Dj = domains->at(j); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   // Though strictly not part of allDiff, we check for
 | 
					
						
							| 
									
										
										
										
											2021-11-21 04:52:12 +08:00
										 |  |  |   // a value in domains->at(j) that does not occur in any other connected domain.
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   // If found, we make this a singleton...
 | 
					
						
							|  |  |  |   // TODO: make a new constraint where this really is true
 | 
					
						
							| 
									
										
										
										
											2023-01-13 05:43:29 +08:00
										 |  |  |   std::optional<Domain> maybeChanged = Dj.checkAllDiff(keys_, *domains); | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   if (maybeChanged) { | 
					
						
							|  |  |  |     Dj = *maybeChanged; | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2021-11-18 23:17:49 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   // Check all other domains for singletons and erase corresponding values.
 | 
					
						
							|  |  |  |   // This is the same as arc-consistency on the equivalent binary constraints
 | 
					
						
							|  |  |  |   bool changed = false; | 
					
						
							|  |  |  |   for (Key k : keys_) | 
					
						
							|  |  |  |     if (k != j) { | 
					
						
							|  |  |  |       const Domain& Dk = domains->at(k); | 
					
						
							|  |  |  |       if (Dk.isSingleton()) {  // check if singleton
 | 
					
						
							|  |  |  |         size_t value = Dk.firstValue(); | 
					
						
							|  |  |  |         if (Dj.contains(value)) { | 
					
						
							|  |  |  |           Dj.erase(value);  // erase value if true
 | 
					
						
							|  |  |  |           changed = true; | 
					
						
							| 
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 |  |  |         } | 
					
						
							|  |  |  |       } | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |   return changed; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | /* ************************************************************************* */ | 
					
						
							| 
									
										
										
										
											2021-12-14 02:46:53 +08:00
										 |  |  | Constraint::shared_ptr AllDiff::partiallyApply(const DiscreteValues& values) const { | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   DiscreteKeys newKeys; | 
					
						
							|  |  |  |   // loop over keys and add them only if they do not appear in values
 | 
					
						
							|  |  |  |   for (Key k : keys_) | 
					
						
							|  |  |  |     if (values.find(k) == values.end()) { | 
					
						
							|  |  |  |       newKeys.push_back(DiscreteKey(k, cardinalities_.at(k))); | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-01-18 06:05:12 +08:00
										 |  |  |   return std::make_shared<AllDiff>(newKeys); | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | } | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | Constraint::shared_ptr AllDiff::partiallyApply( | 
					
						
							| 
									
										
										
										
											2021-11-21 04:52:12 +08:00
										 |  |  |     const Domains& domains) const { | 
					
						
							| 
									
										
										
										
											2021-12-14 02:46:53 +08:00
										 |  |  |   DiscreteValues known; | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   for (Key k : keys_) { | 
					
						
							| 
									
										
										
										
											2021-11-21 04:52:12 +08:00
										 |  |  |     const Domain& Dk = domains.at(k); | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |     if (Dk.isSingleton()) known[k] = Dk.firstValue(); | 
					
						
							| 
									
										
										
										
											2012-10-02 22:40:07 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  |   return partiallyApply(known); | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-19 04:08:01 +08:00
										 |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | }  // namespace gtsam
 |