gtsam/gtsam_unstable/discrete/tests/testSudoku.cpp

204 lines
5.7 KiB
C++
Raw Normal View History

2012-04-16 06:35:28 +08:00
/*
* testSudoku.cpp
* @brief Unit tests for a Sudoku solver built on the discrete CSP classes
* @date Jan 29, 2012
* @author Frank Dellaert
*/
#include <CppUnitLite/TestHarness.h>
#include <gtsam_unstable/discrete/CSP.h>

#include <boost/assign/std/map.hpp>
using boost::assign::insert;

#include <stdarg.h>
#include <iostream>
#include <sstream>
#include <stdexcept>
using namespace std;
using namespace gtsam;
#define PRINT false
2021-11-18 23:54:00 +08:00
class Sudoku : public CSP {
/// sudoku size
size_t n_;
2012-04-16 06:35:28 +08:00
/// discrete keys
typedef std::pair<size_t, size_t> IJ;
std::map<IJ, DiscreteKey> dkeys_;
2012-04-16 06:35:28 +08:00
2021-11-18 23:54:00 +08:00
public:
/// return DiscreteKey for cell(i,j)
const DiscreteKey& dkey(size_t i, size_t j) const {
return dkeys_.at(IJ(i, j));
}
/// return Key for cell(i,j)
2021-11-18 23:54:00 +08:00
Key key(size_t i, size_t j) const { return dkey(i, j).first; }
/// Constructor
2021-11-18 23:54:00 +08:00
Sudoku(size_t n, ...) : n_(n) {
// Create variables, ordering, and unary constraints
va_list ap;
va_start(ap, n);
2021-11-18 23:54:00 +08:00
Key k = 0;
for (size_t i = 0; i < n; ++i) {
for (size_t j = 0; j < n; ++j, ++k) {
// create the key
IJ ij(i, j);
dkeys_[ij] = DiscreteKey(k, n);
// get the unary constraint, if any
int value = va_arg(ap, int);
// cout << value << " ";
if (value != 0) addSingleValue(dkeys_[ij], value - 1);
}
2021-11-18 23:54:00 +08:00
// cout << endl;
}
va_end(ap);
// add row constraints
for (size_t i = 0; i < n; i++) {
DiscreteKeys dkeys;
2021-11-18 23:54:00 +08:00
for (size_t j = 0; j < n; j++) dkeys += dkey(i, j);
addAllDiff(dkeys);
}
// add col constraints
for (size_t j = 0; j < n; j++) {
DiscreteKeys dkeys;
2021-11-18 23:54:00 +08:00
for (size_t i = 0; i < n; i++) dkeys += dkey(i, j);
addAllDiff(dkeys);
}
// add box constraints
size_t N = (size_t)sqrt(double(n)), i0 = 0;
for (size_t I = 0; I < N; I++) {
size_t j0 = 0;
for (size_t J = 0; J < N; J++) {
// Box I,J
DiscreteKeys dkeys;
for (size_t i = i0; i < i0 + N; i++)
2021-11-18 23:54:00 +08:00
for (size_t j = j0; j < j0 + N; j++) dkeys += dkey(i, j);
addAllDiff(dkeys);
j0 += N;
}
i0 += N;
}
}
/// Print readable form of assignment
2021-11-21 05:15:05 +08:00
void printAssignment(const DiscreteFactor::Values& assignment) const {
for (size_t i = 0; i < n_; i++) {
for (size_t j = 0; j < n_; j++) {
Key k = key(i, j);
2021-11-21 05:15:05 +08:00
cout << 1 + assignment.at(k) << " ";
}
cout << endl;
}
}
/// solve and print solution
void printSolution() {
2021-11-21 05:15:05 +08:00
auto MPE = optimalAssignment();
printAssignment(MPE);
}
2012-04-16 06:35:28 +08:00
};
/* ************************************************************************* */
2021-11-18 23:54:00 +08:00
TEST_UNSAFE(Sudoku, small) {
  // 4x4 puzzle, one row per line; 0 marks an empty cell.
  Sudoku csp(4,           //
             1, 0, 0, 4,  //
             0, 0, 0, 0,  //
             4, 0, 2, 0,  //
             0, 1, 0, 0);

  // Prune domains with arc consistency before solving.
  csp.runArcConsistency(4, 10, PRINT);

  // Solve, then compare against the known solution.
  auto solution = csp.optimalAssignment();

  // Known solution written with puzzle digits (1-based); the CSP stores
  // 0-based values, hence the "- 1" below.
  const size_t answer[4][4] = {{1, 2, 3, 4},  //
                               {3, 4, 1, 2},  //
                               {4, 3, 2, 1},  //
                               {2, 1, 4, 3}};
  CSP::Values expected;
  for (size_t row = 0; row < 4; ++row)
    for (size_t col = 0; col < 4; ++col)
      expected.insert(
          CSP::Values::value_type(csp.key(row, col), answer[row][col] - 1));

  EXPECT(assert_equal(expected, solution));
  // csp.printAssignment(solution);
}
/* ************************************************************************* */
2021-11-18 23:54:00 +08:00
TEST_UNSAFE(Sudoku, easy) {
  // 9x9 puzzle, one row per line; 0 marks an empty cell.
  Sudoku puzzle(9,                          //
                0, 0, 5, 0, 9, 0, 0, 0, 1,  //
                0, 0, 0, 0, 0, 2, 0, 7, 3,  //
                7, 6, 0, 0, 0, 8, 2, 0, 0,  //
                0, 1, 2, 0, 0, 9, 0, 0, 4,  //
                0, 0, 0, 2, 0, 3, 0, 0, 0,  //
                3, 0, 0, 1, 0, 0, 9, 6, 0,  //
                0, 0, 1, 9, 0, 0, 0, 5, 8,  //
                9, 7, 0, 5, 0, 0, 0, 0, 0,  //
                5, 0, 0, 0, 3, 0, 7, 0, 0);

  // Only exercises arc consistency on this puzzle; nothing is asserted.
  puzzle.runArcConsistency(4, 10, PRINT);

  // puzzle.printSolution(); // don't do it
}
/* ************************************************************************* */
2021-11-18 23:54:00 +08:00
TEST_UNSAFE(Sudoku, extreme) {
  // 9x9 puzzle, one row per line; 0 marks an empty cell.
  // (The clues were previously wrapped ten-per-line, so the visual rows did
  // not match the grid; the flattened sequence of 81 values is unchanged.)
  Sudoku sudoku(9,                          //
                0, 0, 9, 7, 4, 8, 0, 0, 0,  //
                7, 0, 0, 0, 0, 0, 0, 0, 0,  //
                0, 2, 0, 1, 0, 9, 0, 0, 0,  //
                0, 0, 7, 0, 0, 0, 2, 4, 0,  //
                0, 6, 4, 0, 1, 0, 5, 9, 0,  //
                0, 9, 8, 0, 0, 0, 3, 0, 0,  //
                0, 0, 0, 8, 0, 3, 0, 2, 0,  //
                0, 0, 0, 0, 0, 0, 0, 0, 6,  //
                0, 0, 0, 2, 7, 5, 9, 0, 0);

  // Only exercises arc consistency on this puzzle; nothing is asserted.
  sudoku.runArcConsistency(9, 10, PRINT);

#ifdef METIS
  // NOTE(review): debugging aid compiled only when METIS is defined. It writes
  // to a hard-coded personal path, and uses ofstream although <fstream> is not
  // included in this file -- verify it still builds before enabling.
  VariableIndexOrdered index(sudoku);
  index.print("index");
  ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/extreme-dual.txt");
  index.outputMetisFormat(os);
#endif

  // sudoku.printSolution(); // don't do it
}
/* ************************************************************************* */
2021-11-18 23:54:00 +08:00
TEST_UNSAFE(Sudoku, AJC_3star_Feb8_2012) {
  // 9x9 puzzle, one row per line; 0 marks an empty cell.
  Sudoku puzzle(9,                          //
                9, 5, 0, 0, 0, 6, 0, 0, 0,  //
                0, 8, 4, 0, 7, 0, 0, 0, 0,  //
                6, 2, 0, 5, 0, 0, 4, 0, 0,  //
                0, 0, 0, 2, 9, 0, 6, 0, 0,  //
                0, 9, 0, 0, 0, 0, 0, 2, 0,  //
                0, 0, 2, 0, 6, 3, 0, 0, 0,  //
                0, 0, 9, 0, 0, 7, 0, 6, 8,  //
                0, 0, 0, 0, 3, 0, 2, 9, 0,  //
                0, 0, 0, 1, 0, 0, 0, 3, 7);

  // Only exercises arc consistency on this puzzle; nothing is asserted.
  puzzle.runArcConsistency(9, 10, PRINT);

  // puzzle.printSolution(); // don't do it
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
2012-04-16 06:35:28 +08:00
}
/* ************************************************************************* */