| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | /*
 | 
					
						
							|  |  |  |  * testCSP.cpp | 
					
						
							|  |  |  |  * @brief develop code for CSP solver | 
					
						
							|  |  |  |  * @date Feb 5, 2012 | 
					
						
							|  |  |  |  * @author Frank Dellaert | 
					
						
							|  |  |  |  */ | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2012-05-15 08:47:19 +08:00
										 |  |  | #include <gtsam_unstable/discrete/CSP.h>
 | 
					
						
							|  |  |  | #include <gtsam_unstable/discrete/Domain.h>
 | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | #include <CppUnitLite/TestHarness.h>
 | 
					
						
							|  |  |  | #include <iostream>
 | 
					
						
							|  |  |  | #include <fstream>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | using namespace std; | 
					
						
							|  |  |  | using namespace gtsam; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | TEST_UNSAFE( BinaryAllDif, allInOne) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  | 	// Create keys and ordering
 | 
					
						
							|  |  |  | 	size_t nrColors = 2; | 
					
						
							|  |  |  | //	DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", nrColors);
 | 
					
						
							|  |  |  | 	DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Check construction and conversion
 | 
					
						
							|  |  |  | 	BinaryAllDiff c1(ID, UT); | 
					
						
							|  |  |  | 	DecisionTreeFactor f1(ID & UT, "0 1 1 0"); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f1,(DecisionTreeFactor)c1)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Check construction and conversion
 | 
					
						
							|  |  |  | 	BinaryAllDiff c2(UT, AZ); | 
					
						
							|  |  |  | 	DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f2,(DecisionTreeFactor)c2)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	DecisionTreeFactor f3 = f1*f2; | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f3,c1*f2)); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f3,c2*f1)); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | TEST_UNSAFE( CSP, allInOne) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  | 	// Create keys and ordering
 | 
					
						
							|  |  |  | 	size_t nrColors = 2; | 
					
						
							|  |  |  | 	DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Create the CSP
 | 
					
						
							|  |  |  | 	CSP csp; | 
					
						
							|  |  |  | 	csp.addAllDiff(ID,UT); | 
					
						
							|  |  |  | 	csp.addAllDiff(UT,AZ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Check an invalid combination, with ID==UT==AZ all same color
 | 
					
						
							|  |  |  | 	DiscreteFactor::Values invalid; | 
					
						
							|  |  |  | 	invalid[ID.first] = 0; | 
					
						
							|  |  |  | 	invalid[UT.first] = 0; | 
					
						
							|  |  |  | 	invalid[AZ.first] = 0; | 
					
						
							| 
									
										
										
										
											2012-05-25 22:51:03 +08:00
										 |  |  | 	EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Check a valid combination
 | 
					
						
							|  |  |  | 	DiscreteFactor::Values valid; | 
					
						
							|  |  |  | 	valid[ID.first] = 0; | 
					
						
							|  |  |  | 	valid[UT.first] = 1; | 
					
						
							|  |  |  | 	valid[AZ.first] = 0; | 
					
						
							| 
									
										
										
										
											2012-05-25 22:51:03 +08:00
										 |  |  | 	EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Just for fun, create the product and check it
 | 
					
						
							| 
									
										
										
										
											2012-05-25 22:51:03 +08:00
										 |  |  | 	DecisionTreeFactor product = csp.product(); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 	// product.dot("product");
 | 
					
						
							| 
									
										
										
										
											2012-05-18 02:08:34 +08:00
										 |  |  | 	DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(expectedProduct,product)); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Solve
 | 
					
						
							|  |  |  | 	CSP::sharedValues mpe = csp.optimalAssignment(); | 
					
						
							|  |  |  | 	CSP::Values expected; | 
					
						
							|  |  |  | 	insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(expected,*mpe)); | 
					
						
							| 
									
										
										
										
											2012-05-25 22:51:03 +08:00
										 |  |  | 	EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | TEST_UNSAFE( CSP, WesternUS) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  | 	// Create keys
 | 
					
						
							|  |  |  | 	size_t nrColors = 4; | 
					
						
							|  |  |  | 	DiscreteKey | 
					
						
							|  |  |  | 	// Create ordering according to example in ND-CSP.lyx
 | 
					
						
							|  |  |  | 	WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),NV(2, nrColors), | 
					
						
							|  |  |  | 	ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), | 
					
						
							|  |  |  | 	MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Create the CSP
 | 
					
						
							|  |  |  | 	CSP csp; | 
					
						
							|  |  |  | 	csp.addAllDiff(WA,ID); | 
					
						
							|  |  |  | 	csp.addAllDiff(WA,OR); | 
					
						
							|  |  |  | 	csp.addAllDiff(OR,ID); | 
					
						
							|  |  |  | 	csp.addAllDiff(OR,CA); | 
					
						
							|  |  |  | 	csp.addAllDiff(OR,NV); | 
					
						
							|  |  |  | 	csp.addAllDiff(CA,NV); | 
					
						
							|  |  |  | 	csp.addAllDiff(CA,AZ); | 
					
						
							|  |  |  | 	csp.addAllDiff(ID,MT); | 
					
						
							|  |  |  | 	csp.addAllDiff(ID,WY); | 
					
						
							|  |  |  | 	csp.addAllDiff(ID,UT); | 
					
						
							|  |  |  | 	csp.addAllDiff(ID,NV); | 
					
						
							|  |  |  | 	csp.addAllDiff(NV,UT); | 
					
						
							|  |  |  | 	csp.addAllDiff(NV,AZ); | 
					
						
							|  |  |  | 	csp.addAllDiff(UT,WY); | 
					
						
							|  |  |  | 	csp.addAllDiff(UT,CO); | 
					
						
							|  |  |  | 	csp.addAllDiff(UT,NM); | 
					
						
							|  |  |  | 	csp.addAllDiff(UT,AZ); | 
					
						
							|  |  |  | 	csp.addAllDiff(AZ,CO); | 
					
						
							|  |  |  | 	csp.addAllDiff(AZ,NM); | 
					
						
							|  |  |  | 	csp.addAllDiff(MT,WY); | 
					
						
							|  |  |  | 	csp.addAllDiff(WY,CO); | 
					
						
							|  |  |  | 	csp.addAllDiff(CO,NM); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Solve
 | 
					
						
							|  |  |  | 	CSP::sharedValues mpe = csp.optimalAssignment(); | 
					
						
							|  |  |  | 	// GTSAM_PRINT(*mpe);
 | 
					
						
							|  |  |  | 	CSP::Values expected; | 
					
						
							|  |  |  | 	insert(expected) | 
					
						
							|  |  |  | 	(WA.first,1)(CA.first,1)(NV.first,3)(OR.first,0) | 
					
						
							|  |  |  | 	(MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2) | 
					
						
							|  |  |  | 	(ID.first,2)(UT.first,1)(AZ.first,0); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(expected,*mpe)); | 
					
						
							| 
									
										
										
										
											2012-05-25 22:51:03 +08:00
										 |  |  | 	EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Write out the dual graph for hmetis
 | 
					
						
							|  |  |  | #ifdef DUAL
 | 
					
						
							|  |  |  | 	VariableIndex index(csp); | 
					
						
							|  |  |  | 	index.print("index"); | 
					
						
							|  |  |  | 	ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/US-West-dual.txt"); | 
					
						
							|  |  |  |   index.outputMetisFormat(os); | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | TEST_UNSAFE( CSP, AllDiff) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  | 	// Create keys and ordering
 | 
					
						
							|  |  |  | 	size_t nrColors = 3; | 
					
						
							|  |  |  | 	DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Create the CSP
 | 
					
						
							|  |  |  | 	CSP csp; | 
					
						
							|  |  |  | 	vector<DiscreteKey> dkeys; | 
					
						
							|  |  |  | 	dkeys += ID,UT,AZ; | 
					
						
							|  |  |  | 	csp.addAllDiff(dkeys); | 
					
						
							|  |  |  | 	csp.addSingleValue(AZ,2); | 
					
						
							| 
									
										
										
										
											2012-05-25 22:51:03 +08:00
										 |  |  | //	GTSAM_PRINT(csp);
 | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Check construction and conversion
 | 
					
						
							|  |  |  | 	SingleValue s(AZ,2); | 
					
						
							|  |  |  | 	DecisionTreeFactor f1(AZ,"0 0 1"); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f1,(DecisionTreeFactor)s)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Check construction and conversion
 | 
					
						
							|  |  |  | 	AllDiff alldiff(dkeys); | 
					
						
							|  |  |  | 	DecisionTreeFactor actual = (DecisionTreeFactor)alldiff; | 
					
						
							|  |  |  | //	GTSAM_PRINT(actual);
 | 
					
						
							|  |  |  | //	actual.dot("actual");
 | 
					
						
							|  |  |  | 	DecisionTreeFactor f2(ID & AZ & UT, | 
					
						
							|  |  |  | 			"0 0 0  0 0 1  0 1 0   0 0 1  0 0 0  1 0 0   0 1 0  1 0 0  0 0 0"); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f2,actual)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Check an invalid combination, with ID==UT==AZ all same color
 | 
					
						
							|  |  |  | 	DiscreteFactor::Values invalid; | 
					
						
							|  |  |  | 	invalid[ID.first] = 0; | 
					
						
							|  |  |  | 	invalid[UT.first] = 1; | 
					
						
							|  |  |  | 	invalid[AZ.first] = 0; | 
					
						
							| 
									
										
										
										
											2012-05-25 22:51:03 +08:00
										 |  |  | 	EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Check a valid combination
 | 
					
						
							|  |  |  | 	DiscreteFactor::Values valid; | 
					
						
							|  |  |  | 	valid[ID.first] = 0; | 
					
						
							|  |  |  | 	valid[UT.first] = 1; | 
					
						
							|  |  |  | 	valid[AZ.first] = 2; | 
					
						
							| 
									
										
										
										
											2012-05-25 22:51:03 +08:00
										 |  |  | 	EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Solve
 | 
					
						
							|  |  |  | 	CSP::sharedValues mpe = csp.optimalAssignment(); | 
					
						
							|  |  |  | 	CSP::Values expected; | 
					
						
							|  |  |  | 	insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(expected,*mpe)); | 
					
						
							| 
									
										
										
										
											2012-05-25 22:51:03 +08:00
										 |  |  | 	EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); | 
					
						
							| 
									
										
										
										
											2012-04-16 06:35:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 	// Arc-consistency
 | 
					
						
							|  |  |  | 	vector<Domain> domains; | 
					
						
							|  |  |  | 	domains += Domain(ID), Domain(AZ), Domain(UT); | 
					
						
							|  |  |  | 	SingleValue singleValue(AZ,2); | 
					
						
							|  |  |  | 	EXPECT(singleValue.ensureArcConsistency(1,domains)); | 
					
						
							|  |  |  | 	EXPECT(alldiff.ensureArcConsistency(0,domains)); | 
					
						
							|  |  |  | 	EXPECT(!alldiff.ensureArcConsistency(1,domains)); | 
					
						
							|  |  |  | 	EXPECT(alldiff.ensureArcConsistency(2,domains)); | 
					
						
							|  |  |  | 	LONGS_EQUAL(2,domains[0].nrValues()); | 
					
						
							|  |  |  | 	LONGS_EQUAL(1,domains[1].nrValues()); | 
					
						
							|  |  |  | 	LONGS_EQUAL(2,domains[2].nrValues()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Parial application, version 1
 | 
					
						
							|  |  |  | 	DiscreteFactor::Values known; | 
					
						
							|  |  |  | 	known[AZ.first] = 2; | 
					
						
							|  |  |  | 	DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known); | 
					
						
							|  |  |  | 	DecisionTreeFactor f3(ID & UT, "0 1 1  1 0 1  1 1 0"); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f3,reduced1->operator DecisionTreeFactor())); | 
					
						
							|  |  |  | 	DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(known); | 
					
						
							|  |  |  | 	DecisionTreeFactor f4(AZ, "0 0 1"); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f4,reduced2->operator DecisionTreeFactor())); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// Parial application, version 2
 | 
					
						
							|  |  |  | 	DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f3,reduced3->operator DecisionTreeFactor())); | 
					
						
							|  |  |  | 	DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains); | 
					
						
							|  |  |  | 	EXPECT(assert_equal(f4,reduced4->operator DecisionTreeFactor())); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 	// full arc-consistency test
 | 
					
						
							|  |  |  | 	csp.runArcConsistency(nrColors); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | int main() { | 
					
						
							|  |  |  | 	TestResult tr; | 
					
						
							|  |  |  | 	return TestRegistry::runAllTests(tr); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | /* ************************************************************************* */ | 
					
						
							|  |  |  | 
 |