/*
 * testSudoku.cpp
 * @brief Encodes Sudoku puzzles as a discrete CSP and tests the CSP solver on them
 * @date Jan 29, 2012
 * @author Frank Dellaert
 */

#include <CppUnitLite/TestHarness.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam_unstable/discrete/CSP.h>

#include <stdarg.h>

#include <cmath>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <utility>

using namespace std;
using namespace gtsam;

// Debug-print toggle. NOTE(review): not referenced anywhere in this file; the
// tests below use commented-out print calls instead, so this may be removable.
#define PRINT false
 | 
						|
 | 
						|
/// A class that encodes Sudoku's as a CSP problem
 | 
						|
class Sudoku : public CSP {
 | 
						|
  size_t n_;  ///< Side of Sudoku, e.g. 4 or 9
 | 
						|
 | 
						|
  /// Mapping from base i,j coordinates to discrete keys:
 | 
						|
  using IJ = std::pair<size_t, size_t>;
 | 
						|
  std::map<IJ, DiscreteKey> dkeys_;
 | 
						|
 | 
						|
 public:
 | 
						|
  /// return DiscreteKey for cell(i,j)
 | 
						|
  const DiscreteKey& dkey(size_t i, size_t j) const {
 | 
						|
    return dkeys_.at(IJ(i, j));
 | 
						|
  }
 | 
						|
 | 
						|
  /// return Key for cell(i,j)
 | 
						|
  Key key(size_t i, size_t j) const { return dkey(i, j).first; }
 | 
						|
 | 
						|
  /// Constructor
 | 
						|
  Sudoku(size_t n, ...) : n_(n) {
 | 
						|
    // Create variables, ordering, and unary constraints
 | 
						|
    va_list ap;
 | 
						|
    va_start(ap, n);
 | 
						|
    for (size_t i = 0; i < n; ++i) {
 | 
						|
      for (size_t j = 0; j < n; ++j) {
 | 
						|
        // create the key
 | 
						|
        IJ ij(i, j);
 | 
						|
        Symbol key('1' + i, j + 1);
 | 
						|
        dkeys_[ij] = DiscreteKey(key, n);
 | 
						|
        // get the unary constraint, if any
 | 
						|
        int value = va_arg(ap, int);
 | 
						|
        if (value != 0) addSingleValue(dkeys_[ij], value - 1);
 | 
						|
      }
 | 
						|
      // cout << endl;
 | 
						|
    }
 | 
						|
    va_end(ap);
 | 
						|
 | 
						|
    // add row constraints
 | 
						|
    for (size_t i = 0; i < n; i++) {
 | 
						|
      DiscreteKeys dkeys;
 | 
						|
      for (size_t j = 0; j < n; j++) dkeys.push_back(dkey(i, j));
 | 
						|
      addAllDiff(dkeys);
 | 
						|
    }
 | 
						|
 | 
						|
    // add col constraints
 | 
						|
    for (size_t j = 0; j < n; j++) {
 | 
						|
      DiscreteKeys dkeys;
 | 
						|
      for (size_t i = 0; i < n; i++) dkeys.push_back(dkey(i, j));
 | 
						|
      addAllDiff(dkeys);
 | 
						|
    }
 | 
						|
 | 
						|
    // add box constraints
 | 
						|
    size_t N = (size_t)sqrt(double(n)), i0 = 0;
 | 
						|
    for (size_t I = 0; I < N; I++) {
 | 
						|
      size_t j0 = 0;
 | 
						|
      for (size_t J = 0; J < N; J++) {
 | 
						|
        // Box I,J
 | 
						|
        DiscreteKeys dkeys;
 | 
						|
        for (size_t i = i0; i < i0 + N; i++)
 | 
						|
          for (size_t j = j0; j < j0 + N; j++) dkeys.push_back(dkey(i, j));
 | 
						|
        addAllDiff(dkeys);
 | 
						|
        j0 += N;
 | 
						|
      }
 | 
						|
      i0 += N;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  /// Print readable form of assignment
 | 
						|
  void printAssignment(const DiscreteValues& assignment) const {
 | 
						|
    for (size_t i = 0; i < n_; i++) {
 | 
						|
      for (size_t j = 0; j < n_; j++) {
 | 
						|
        Key k = key(i, j);
 | 
						|
        cout << 1 + assignment.at(k) << " ";
 | 
						|
      }
 | 
						|
      cout << endl;
 | 
						|
    }
 | 
						|
  }
 | 
						|
 | 
						|
  /// solve and print solution
 | 
						|
  void printSolution() const {
 | 
						|
    auto MPE = optimize();
 | 
						|
    printAssignment(MPE);
 | 
						|
  }
 | 
						|
 | 
						|
  // Print domain
 | 
						|
  void printDomains(const Domains& domains) {
 | 
						|
    for (size_t i = 0; i < n_; i++) {
 | 
						|
      for (size_t j = 0; j < n_; j++) {
 | 
						|
        Key k = key(i, j);
 | 
						|
        cout << domains.at(k).base1Str();
 | 
						|
        cout << "\t";
 | 
						|
      }  // i
 | 
						|
      cout << endl;
 | 
						|
    }  // j
 | 
						|
  }
 | 
						|
};
 | 
						|
 | 
						|
/* ************************************************************************* */
 | 
						|
TEST(Sudoku, small) {
 | 
						|
  Sudoku csp(4,           //
 | 
						|
             1, 0, 0, 4,  //
 | 
						|
             0, 0, 0, 0,  //
 | 
						|
             4, 0, 2, 0,  //
 | 
						|
             0, 1, 0, 0);
 | 
						|
 | 
						|
  // optimize and check
 | 
						|
  auto solution = csp.optimize();
 | 
						|
  DiscreteValues expected;
 | 
						|
  expected = {{csp.key(0, 0), 0}, {csp.key(0, 1), 1},
 | 
						|
              {csp.key(0, 2), 2}, {csp.key(0, 3), 3},  //
 | 
						|
              {csp.key(1, 0), 2}, {csp.key(1, 1), 3},
 | 
						|
              {csp.key(1, 2), 0}, {csp.key(1, 3), 1},  //
 | 
						|
              {csp.key(2, 0), 3}, {csp.key(2, 1), 2},
 | 
						|
              {csp.key(2, 2), 1}, {csp.key(2, 3), 0},  //
 | 
						|
              {csp.key(3, 0), 1}, {csp.key(3, 1), 0},
 | 
						|
              {csp.key(3, 2), 3}, {csp.key(3, 3), 2}};
 | 
						|
  EXPECT(assert_equal(expected, solution));
 | 
						|
  // csp.printAssignment(solution);
 | 
						|
 | 
						|
  // Do BP (AC1)
 | 
						|
  auto domains = csp.runArcConsistency(4, 3);
 | 
						|
  // csp.printDomains(domains);
 | 
						|
  Domain domain44 = domains.at(Symbol('4', 4));
 | 
						|
  EXPECT_LONGS_EQUAL(1, domain44.nrValues());
 | 
						|
 | 
						|
  // Test Creation of a new, simpler CSP
 | 
						|
  CSP new_csp = csp.partiallyApply(domains);
 | 
						|
  // Should only be 16 new Domains
 | 
						|
  EXPECT_LONGS_EQUAL(16, new_csp.size());
 | 
						|
 | 
						|
  // Check that solution
 | 
						|
  auto new_solution = new_csp.optimize();
 | 
						|
  // csp.printAssignment(new_solution);
 | 
						|
  EXPECT(assert_equal(expected, new_solution));
 | 
						|
}
 | 
						|
 | 
						|
/* ************************************************************************* */
 | 
						|
TEST(Sudoku, easy) {
 | 
						|
  Sudoku csp(9,                          //
 | 
						|
             0, 0, 5, 0, 9, 0, 0, 0, 1,  //
 | 
						|
             0, 0, 0, 0, 0, 2, 0, 7, 3,  //
 | 
						|
             7, 6, 0, 0, 0, 8, 2, 0, 0,  //
 | 
						|
 | 
						|
             0, 1, 2, 0, 0, 9, 0, 0, 4,  //
 | 
						|
             0, 0, 0, 2, 0, 3, 0, 0, 0,  //
 | 
						|
             3, 0, 0, 1, 0, 0, 9, 6, 0,  //
 | 
						|
 | 
						|
             0, 0, 1, 9, 0, 0, 0, 5, 8,  //
 | 
						|
             9, 7, 0, 5, 0, 0, 0, 0, 0,  //
 | 
						|
             5, 0, 0, 0, 3, 0, 7, 0, 0);
 | 
						|
 | 
						|
  // csp.printSolution(); // don't do it
 | 
						|
 | 
						|
  // Do BP (AC1)
 | 
						|
  auto domains = csp.runArcConsistency(9, 10);
 | 
						|
  // csp.printDomains(domains);
 | 
						|
  Key key99 = Symbol('9', 9);
 | 
						|
  Domain domain99 = domains.at(key99);
 | 
						|
  EXPECT_LONGS_EQUAL(1, domain99.nrValues());
 | 
						|
 | 
						|
  // Test Creation of a new, simpler CSP
 | 
						|
  CSP new_csp = csp.partiallyApply(domains);
 | 
						|
  // 81 new Domains, and still 26 all-diff constraints
 | 
						|
  EXPECT_LONGS_EQUAL(81 + 26, new_csp.size());
 | 
						|
 | 
						|
  // csp.printSolution(); // still don't do it ! :-(
 | 
						|
}
 | 
						|
 | 
						|
/* ************************************************************************* */
 | 
						|
TEST(Sudoku, extreme) {
 | 
						|
  Sudoku csp(9,                             //
 | 
						|
             0, 0, 9, 7, 4, 8, 0, 0, 0, 7,  //
 | 
						|
             0, 0, 0, 0, 0, 0, 0, 0, 0, 2,  //
 | 
						|
             0, 1, 0, 9, 0, 0, 0, 0, 0, 7,  //
 | 
						|
             0, 0, 0, 2, 4, 0, 0, 6, 4, 0,  //
 | 
						|
             1, 0, 5, 9, 0, 0, 9, 8, 0, 0,  //
 | 
						|
             0, 3, 0, 0, 0, 0, 0, 8, 0, 3,  //
 | 
						|
             0, 2, 0, 0, 0, 0, 0, 0, 0, 0,  //
 | 
						|
             0, 6, 0, 0, 0, 2, 7, 5, 9, 0, 0);
 | 
						|
 | 
						|
  // Do BP
 | 
						|
  csp.runArcConsistency(9, 10);
 | 
						|
 | 
						|
#ifdef METIS
 | 
						|
  VariableIndexOrdered index(csp);
 | 
						|
  index.print("index");
 | 
						|
  ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt");
 | 
						|
  index.outputMetisFormat(os);
 | 
						|
#endif
 | 
						|
 | 
						|
  // Do BP (AC1)
 | 
						|
  auto domains = csp.runArcConsistency(9, 10);
 | 
						|
  // csp.printDomains(domains);
 | 
						|
  Key key99 = Symbol('9', 9);
 | 
						|
  Domain domain99 = domains.at(key99);
 | 
						|
  EXPECT_LONGS_EQUAL(2, domain99.nrValues());
 | 
						|
 | 
						|
  // Test Creation of a new, simpler CSP
 | 
						|
  CSP new_csp = csp.partiallyApply(domains);
 | 
						|
  // 81 new Domains, and still 20 all-diff constraints
 | 
						|
  EXPECT_LONGS_EQUAL(81 + 20, new_csp.size());
 | 
						|
 | 
						|
  // csp.printSolution(); // still don't do it ! :-(
 | 
						|
}
 | 
						|
 | 
						|
/* ************************************************************************* */
 | 
						|
TEST(Sudoku, AJC_3star_Feb8_2012) {
 | 
						|
  Sudoku csp(9,                          //
 | 
						|
             9, 5, 0, 0, 0, 6, 0, 0, 0,  //
 | 
						|
             0, 8, 4, 0, 7, 0, 0, 0, 0,  //
 | 
						|
             6, 2, 0, 5, 0, 0, 4, 0, 0,  //
 | 
						|
 | 
						|
             0, 0, 0, 2, 9, 0, 6, 0, 0,  //
 | 
						|
             0, 9, 0, 0, 0, 0, 0, 2, 0,  //
 | 
						|
             0, 0, 2, 0, 6, 3, 0, 0, 0,  //
 | 
						|
 | 
						|
             0, 0, 9, 0, 0, 7, 0, 6, 8,  //
 | 
						|
             0, 0, 0, 0, 3, 0, 2, 9, 0,  //
 | 
						|
             0, 0, 0, 1, 0, 0, 0, 3, 7);
 | 
						|
 | 
						|
  // Do BP (AC1)
 | 
						|
  auto domains = csp.runArcConsistency(9, 10);
 | 
						|
  // csp.printDomains(domains);
 | 
						|
  Key key99 = Symbol('9', 9);
 | 
						|
  Domain domain99 = domains.at(key99);
 | 
						|
  EXPECT_LONGS_EQUAL(1, domain99.nrValues());
 | 
						|
 | 
						|
  // Test Creation of a new, simpler CSP
 | 
						|
  CSP new_csp = csp.partiallyApply(domains);
 | 
						|
  // Just the 81 new Domains
 | 
						|
  EXPECT_LONGS_EQUAL(81, new_csp.size());
 | 
						|
 | 
						|
  // Check that solution
 | 
						|
  auto solution = new_csp.optimize();
 | 
						|
  // csp.printAssignment(solution);
 | 
						|
  EXPECT_LONGS_EQUAL(6, solution.at(key99));
 | 
						|
}
 | 
						|
 | 
						|
/* ************************************************************************* */
 | 
						|
int main() {
 | 
						|
  TestResult tr;
 | 
						|
  return TestRegistry::runAllTests(tr);
 | 
						|
}
 | 
						|
/* ************************************************************************* */
 |