diff --git a/gtsam/discrete/AlgebraicDecisionTree.h b/gtsam/discrete/AlgebraicDecisionTree.h new file mode 100644 index 000000000..7c182f387 --- /dev/null +++ b/gtsam/discrete/AlgebraicDecisionTree.h @@ -0,0 +1,134 @@ +/* + * @file AlgebraicDecisionTree.h + * @brief Algebraic Decision Trees + * @author Frank Dellaert + * @date Mar 14, 2011 + */ + +#pragma once + +#include + +namespace gtsam { + + /** + * Algebraic Decision Trees fix the range to double + * Just has some nice constructors and some syntactic sugar + * TODO: consider eliminating this class altogether? + */ + template + class AlgebraicDecisionTree: public DecisionTree { + + public: + + typedef DecisionTree Super; + + /** The Real ring with addition and multiplication */ + struct Ring { + static inline double zero() { + return 0.0; + } + static inline double one() { + return 1.0; + } + static inline double add(const double& a, const double& b) { + return a + b; + } + static inline double max(const double& a, const double& b) { + return std::max(a, b); + } + static inline double mul(const double& a, const double& b) { + return a * b; + } + static inline double div(const double& a, const double& b) { + return a / b; + } + static inline double id(const double& x) { + return x; + } + }; + + AlgebraicDecisionTree() : + Super(1.0) { + } + + AlgebraicDecisionTree(const Super& add) : + Super(add) { + } + + /** Create a new leaf function splitting on a variable */ + AlgebraicDecisionTree(const L& label, double y1, double y2) : + Super(label, y1, y2) { + } + + /** Create a new leaf function splitting on a variable */ + AlgebraicDecisionTree(const typename Super::LabelC& labelC, double y1, double y2) : + Super(labelC, y1, y2) { + } + + /** Create from keys and vector table */ + AlgebraicDecisionTree // + (const std::vector& labelCs, const std::vector& ys) { + this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + ys.end()); + } + + /** Create from keys and string table */ + AlgebraicDecisionTree // + (const std::vector& labelCs, const std::string& table) { + // Convert string to doubles + std::vector ys; + std::istringstream iss(table); + std::copy(std::istream_iterator(iss), + std::istream_iterator(), std::back_inserter(ys)); + + // now call recursive Create + this->root_ = Super::create(labelCs.begin(), labelCs.end(), ys.begin(), + ys.end()); + } + + /** Create a new function splitting on a variable */ + template + AlgebraicDecisionTree(Iterator begin, Iterator end, const L& label) : + Super(NULL) { + this->root_ = compose(begin, end, label); + } + + /** Convert */ + template + AlgebraicDecisionTree(const AlgebraicDecisionTree& other, + const std::map& map) { + this->root_ = this->template convert(other.root_, map, + Ring::id); + } + + /** sum */ + AlgebraicDecisionTree operator+(const AlgebraicDecisionTree& g) const { + return this->apply(g, &Ring::add); + } + + /** product */ + AlgebraicDecisionTree operator*(const AlgebraicDecisionTree& g) const { + return this->apply(g, &Ring::mul); + } + + /** division */ + AlgebraicDecisionTree operator/(const AlgebraicDecisionTree& g) const { + return this->apply(g, &Ring::div); + } + + /** sum out variable */ + AlgebraicDecisionTree sum(const L& label, size_t cardinality) const { + return this->combine(label, cardinality, &Ring::add); + } + + /** sum out variable */ + AlgebraicDecisionTree sum(const typename Super::LabelC& labelC) const { + return this->combine(labelC, &Ring::add); + } + + }; +// AlgebraicDecisionTree + +} +// namespace gtsam diff --git a/gtsam/discrete/AllDiff.cpp b/gtsam/discrete/AllDiff.cpp new file mode 100644 index 000000000..b06dbae14 --- /dev/null +++ b/gtsam/discrete/AllDiff.cpp @@ -0,0 +1,110 @@ +/* + * AllDiff.cpp + * @brief General "all-different" constraint + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +namespace gtsam { + + /* ************************************************************************* */ + AllDiff::AllDiff(const DiscreteKeys& dkeys) : + DiscreteFactor(dkeys.indices()) { + BOOST_FOREACH(const DiscreteKey& dkey, dkeys) + cardinalities_.insert(dkey); + } + + /* ************************************************************************* */ + void AllDiff::print(const std::string& s) const { + std::cout << s << ": AllDiff on "; + BOOST_FOREACH (Index dkey, keys_) + std::cout << dkey << " "; + std::cout << std::endl; + } + + /* ************************************************************************* */ + double AllDiff::operator()(const Values& values) const { + std::set < size_t > taken; // record values taken by keys + BOOST_FOREACH(Index dkey, keys_) { + size_t value = values.at(dkey); // get the value for that key + if (taken.count(value)) return 0.0;// check if value alreday taken + taken.insert(value);// if not, record it as taken and keep checking + } + return 1.0; + } + + /* ************************************************************************* */ + AllDiff::operator DecisionTreeFactor() const { + // We will do this by converting the allDif into many BinaryAllDiff constraints + DecisionTreeFactor converted; + size_t nrKeys = keys_.size(); + for (size_t i1 = 0; i1 < nrKeys; i1++) + for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) { + BinaryAllDiff binary12(discreteKey(i1),discreteKey(i2)); + converted = converted * binary12; + } + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return DecisionTreeFactor(*this) * f; + } + + /* ************************************************************************* */ + bool AllDiff::ensureArcConsistency(size_t j, std::vector& domains) const { + // Though strictly not part of allDiff, we check for + // a value in domains[j] that does not occur in any other connected domain. + // If found, we make this a singleton... + // TODO: make a new constraint where this really is true + Domain& Dj = domains[j]; + if (Dj.checkAllDiff(keys_, domains)) return true; + + // Check all other domains for singletons and erase corresponding values + // This is the same as arc-consistency on the equivalent binary constraints + bool changed = false; + BOOST_FOREACH(Index k, keys_) + if (k != j) { + const Domain& Dk = domains[k]; + if (Dk.isSingleton()) { // check if singleton + size_t value = Dk.firstValue(); + if (Dj.contains(value)) { + Dj.erase(value); // erase value if true + changed = true; + } + } + } + return changed; + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr AllDiff::partiallyApply(const Values& values) const { + DiscreteKeys newKeys; + // loop over keys and add them only if they do not appear in values + BOOST_FOREACH(Index k, keys_) + if (values.find(k) == values.end()) { + newKeys.push_back(DiscreteKey(k,cardinalities_.at(k))); + } + return boost::make_shared(newKeys); + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr AllDiff::partiallyApply( + const std::vector& domains) const { + DiscreteFactor::Values known; + BOOST_FOREACH(Index k, keys_) { + const Domain& Dk = domains[k]; + if (Dk.isSingleton()) + known[k] = Dk.firstValue(); + } + return partiallyApply(known); + } + + /* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/AllDiff.h b/gtsam/discrete/AllDiff.h new file mode 100644 index 000000000..57ed3f9d3 --- /dev/null +++ b/gtsam/discrete/AllDiff.h @@ -0,0 +1,64 @@ +/* + * AllDiff.h + * @brief General "all-different" constraint + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include + +namespace gtsam { + + /** + * General AllDiff constraint + * Returns 1 if values for all keys are different, 0 otherwise + * DiscreteFactors are all awkward in that they have to store two types of keys: + * for each variable we have a Index and an Index. In this factor, we + * keep the Indices locally, and the Indices are stored in IndexFactor. + */ + class AllDiff: public DiscreteFactor { + + std::map cardinalities_; + + DiscreteKey discreteKey(size_t i) const { + Index j = keys_[i]; + return DiscreteKey(j,cardinalities_.at(j)); + } + + public: + + /// Constructor + AllDiff(const DiscreteKeys& dkeys); + + // print + virtual void print(const std::string& s = "") const; + + /// Calculate value = expensive ! + virtual double operator()(const Values& values) const; + + /// Convert into a decisiontree, can be *very* expensive ! + virtual operator DecisionTreeFactor() const; + + /// Multiply into a decisiontree + virtual DecisionTreeFactor operator*(const DecisionTreeFactor& f) const; + + /* + * Ensure Arc-consistency + * Arc-consistency involves creating binaryAllDiff constraints + * In which case the combinatorial hyper-arc explosion disappears. + * @param j domain to be checked + * @param domains all other domains + */ + bool ensureArcConsistency(size_t j, std::vector& domains) const; + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply(const Values&) const; + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply(const std::vector&) const; + }; + +} // namespace gtsam diff --git a/gtsam/discrete/Assignment.h b/gtsam/discrete/Assignment.h new file mode 100644 index 000000000..0150f6ff9 --- /dev/null +++ b/gtsam/discrete/Assignment.h @@ -0,0 +1,36 @@ +/* + * @file Assignment.h + * @brief An assignment from labels to a discrete value index (size_t) + * @author Frank Dellaert + * @date Feb 5, 2012 + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + + /** + * An assignment from labels to value index (size_t). + * Assigns to each label a value. Implemented as a simple map. + * A discrete factor takes an Assignment and returns a value. + */ + template + class Assignment: public std::map { + public: + void print(const std::string& s = "Assignment: ") const { + std::cout << s << ": "; + BOOST_FOREACH(const typename Assignment::value_type& keyValue, *this) + std::cout << "(" << keyValue.first << ", " << keyValue.second << ")"; + std::cout << std::endl; + } + + bool equals(const Assignment& other, double tol = 1e-9) const { + return (*this == other); + } + }; + +} // namespace gtsam diff --git a/gtsam/discrete/BinaryAllDiff.h b/gtsam/discrete/BinaryAllDiff.h new file mode 100644 index 000000000..a97adb2e3 --- /dev/null +++ b/gtsam/discrete/BinaryAllDiff.h @@ -0,0 +1,87 @@ +/* + * BinaryAllDiff.h + * @brief Binary "all-different" constraint + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include + +namespace gtsam { + + /** + * Binary AllDiff constraint + * Returns 1 if values for two keys are different, 0 otherwise + * DiscreteFactors are all awkward in that they have to store two types of keys: + * for each variable we have a Index and an Index. In this factor, we + * keep the Indices locally, and the Indices are stored in IndexFactor. + */ + class BinaryAllDiff: public DiscreteFactor { + + size_t cardinality0_, cardinality1_; /// cardinality + + public: + + /// Constructor + BinaryAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) : + DiscreteFactor(key1.first, key2.first), + cardinality0_(key1.second), cardinality1_(key2.second) { + } + + // print + virtual void print(const std::string& s = "") const { + std::cout << s << ": BinaryAllDiff on " << keys_[0] << " and " << keys_[1] + << std::endl; + } + + /// Calculate value + virtual double operator()(const Values& values) const { + return (double) (values.at(keys_[0]) != values.at(keys_[1])); + } + + /// Convert into a decisiontree + virtual operator DecisionTreeFactor() const { + DiscreteKeys keys; + keys.push_back(DiscreteKey(keys_[0],cardinality0_)); + keys.push_back(DiscreteKey(keys_[1],cardinality1_)); + std::vector table; + for (size_t i1 = 0; i1 < cardinality0_; i1++) + for (size_t i2 = 0; i2 < cardinality1_; i2++) + table.push_back(i1 != i2); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /// Multiply into a decisiontree + virtual DecisionTreeFactor operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return DecisionTreeFactor(*this) * f; + } + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + /// + bool ensureArcConsistency(size_t j, std::vector& domains) const { +// throw std::runtime_error( +// "BinaryAllDiff::ensureArcConsistency not implemented"); + return false; + } + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply(const Values&) const { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply( + const std::vector&) const { + throw std::runtime_error("BinaryAllDiff::partiallyApply not implemented"); + } + }; + +} // namespace gtsam diff --git a/gtsam/discrete/CMakeLists.txt b/gtsam/discrete/CMakeLists.txt new file mode 100644 index 000000000..42917e59b --- /dev/null +++ b/gtsam/discrete/CMakeLists.txt @@ -0,0 +1,35 @@ +# Install headers +set(subdir discrete) +file(GLOB discrete_headers "*.h") +# FIXME: exclude headers +install(FILES ${discrete_headers} DESTINATION include/gtsam2/discrete) + +# Set up library dependencies +set (discrete_local_libs + discrete) + +set (discrete_full_libs + gtsam2-static) + +# Exclude tests that don't work +set (discrete_excluded_tests +"${CMAKE_CURRENT_SOURCE_DIR}/tests/testTypedDiscreteFactor.cpp" +"${CMAKE_CURRENT_SOURCE_DIR}/tests/testTypedDiscreteFactorGraph.cpp" +"${CMAKE_CURRENT_SOURCE_DIR}/tests/testPotentialTable.cpp") + +# Add all tests +gtsam_add_subdir_tests(discrete "${discrete_local_libs}" "${discrete_full_libs}" "${discrete_excluded_tests}") + +# add examples +foreach(example schedulingExample schedulingQuals12) + add_executable(${example} "examples/${example}.cpp") + add_dependencies(${example} gtsam2-static) + target_link_libraries(${example} ${Boost_LIBRARIES} gtsam2-static) + add_custom_target(${example}.run ${EXECUTABLE_OUTPUT_PATH}${example} ${ARGN}) +endforeach(example) + +# Build timing scripts +#if (GTSAM_BUILD_TIMING) +# gtsam_add_timing(discrete "${discrete_local_libs}") +#endif(GTSAM_BUILD_TIMING) + diff --git a/gtsam/discrete/CSP.cpp b/gtsam/discrete/CSP.cpp new file mode 100644 index 000000000..903daccf4 --- /dev/null +++ b/gtsam/discrete/CSP.cpp @@ -0,0 +1,94 @@ +/* + * CSP.cpp + * @brief Constraint Satisfaction Problem class + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include +#include + +namespace gtsam { + + /// Find the best total assignment - can be expensive + CSP::sharedValues CSP::optimalAssignment() const { + DiscreteSequentialSolver solver(*this); + DiscreteBayesNet::shared_ptr chordal = solver.eliminate(); + sharedValues mpe = optimize(*chordal); + return mpe; + } + + void CSP::runArcConsistency(size_t cardinality, size_t nrIterations, bool print) const { + // Create VariableIndex + VariableIndex index(*this); + // index.print(); + + size_t n = index.size(); + + // Initialize domains + std::vector < Domain > domains; + for (size_t j = 0; j < n; j++) + domains.push_back(Domain(DiscreteKey(j,cardinality))); + + // Create array of flags indicating a domain changed or not + std::vector changed(n); + + // iterate nrIterations over entire grid + for (size_t it = 0; it < nrIterations; it++) { + bool anyChange = false; + // iterate over all cells + for (size_t v = 0; v < n; v++) { + // keep track of which domains changed + changed[v] = false; + // loop over all factors/constraints for variable v + const VariableIndex::Factors& factors = index[v]; + BOOST_FOREACH(size_t f,factors) { + // if not already a singleton + if (!domains[v].isSingleton()) { + // get the constraint and call its ensureArcConsistency method + DiscreteFactor::shared_ptr factor = (*this)[f]; + changed[v] = factor->ensureArcConsistency(v,domains) || changed[v]; + } + } // f + if (changed[v]) anyChange = true; + } // v + if (!anyChange) break; + // TODO: Sudoku specific hack + if (print) { + if (cardinality == 9 && n == 81) { + for (size_t i = 0, v = 0; i < sqrt(n); i++) { + for (size_t j = 0; j < sqrt(n); j++, v++) { + if (changed[v]) cout << "*"; + domains[v].print(); + cout << "\t"; + } // i + cout << endl; + } // j + } else { + for (size_t v = 0; v < n; v++) { + if (changed[v]) cout << "*"; + domains[v].print(); + cout << "\t"; + } // v + } + cout << endl; + } // print + } // it + +#ifndef INPROGRESS + // Now create new problem with all singleton variables removed + // We do this by adding simplifying all factors using parial application + // TODO: create a new ordering as we go, to ensure a connected graph + // KeyOrdering ordering; + // vector dkeys; + BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors_) { + DiscreteFactor::shared_ptr reduced = factor->partiallyApply(domains); + if (print) reduced->print(); + } +#endif + } +} // gtsam + diff --git a/gtsam/discrete/CSP.h b/gtsam/discrete/CSP.h new file mode 100644 index 000000000..973d104fa --- /dev/null +++ b/gtsam/discrete/CSP.h @@ -0,0 +1,71 @@ +/* + * CSP.h + * @brief Constraint Satisfaction Problem class + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + + /** + * Constraint Satisfaction Problem class + * A specialization of a DiscreteFactorGraph. + * It knows about CSP-specific constraints and algorithms + */ + class CSP: public DiscreteFactorGraph { + + public: + /// Constructor + CSP() { + } + + /// Add a unary constraint, allowing only a single value + void addSingleValue(const DiscreteKey& dkey, size_t value) { + boost::shared_ptr factor(new SingleValue(dkey, value)); + push_back(factor); + } + + /// Add a binary AllDiff constraint + void addAllDiff(const DiscreteKey& key1, const DiscreteKey& key2) { + boost::shared_ptr factor( + new BinaryAllDiff(key1, key2)); + push_back(factor); + } + + /// Add a general AllDiff constraint + void addAllDiff(const DiscreteKeys& dkeys) { + boost::shared_ptr factor(new AllDiff(dkeys)); + push_back(factor); + } + + /// Find the best total assignment - can be expensive + sharedValues optimalAssignment() const; + + /* + * Perform loopy belief propagation + * True belief propagation would check for each value in domain + * whether any satisfying separator assignment can be found. + * This corresponds to hyper-arc consistency in CSP speak. + * This can be done by creating a mini-factor graph and search. + * For a nine-by-nine Sudoku, the search tree will be 8+6+6=20 levels deep. + * It will be very expensive to exclude values that way. + */ + // void applyBeliefPropagation(size_t nrIterations = 10) const; + /* + * Apply arc-consistency ~ Approximate loopy belief propagation + * We need to give the domains to a constraint, and it returns + * a domain whose values don't conflict in the arc-consistency way. + * TODO: should get cardinality from Indices + */ + void runArcConsistency(size_t cardinality, size_t nrIterations = 10, + bool print = false) const; + }; + +} // gtsam + diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h new file mode 100644 index 000000000..6a6519d83 --- /dev/null +++ b/gtsam/discrete/DecisionTree-inl.h @@ -0,0 +1,667 @@ +/* + * @file DecisionTree.h + * @brief Decision Tree for use in DiscreteFactors + * @author Frank Dellaert + * @author Can Erdogan + * @date Jan 30, 2012 + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace gtsam { + + using namespace boost::assign; + + /*********************************************************************************/ + // Node + /*********************************************************************************/ +#ifdef DT_DEBUG_MEMORY + template + int DecisionTree::Node::nrNodes = 0; +#endif + + /*********************************************************************************/ + // Leaf + /*********************************************************************************/ + template + class DecisionTree::Leaf: public DecisionTree::Node { + + /** constant stored in this leaf */ + Y constant_; + + public: + + /** Constructor from constant */ + Leaf(const Y& constant) : + constant_(constant) {} + + /** return the constant */ + const Y& constant() const { + return constant_; + } + + /// Leaf-Leaf equality + bool sameLeaf(const Leaf& q) const { + return constant_ == q.constant_; + } + + /// polymorphic equality: is q is a leaf, could be + bool sameLeaf(const Node& q) const { + return (q.isLeaf() && q.sameLeaf(*this)); + } + + /** equality up to tolerance */ + bool equals(const Node& q, double tol) const { + const Leaf* other = dynamic_cast (&q); + if (!other) return false; + return fabs(this->constant_ - other->constant_) < tol; + } + + /** print */ + void print(const std::string& s) const { + bool showZero = true; + if (showZero || constant_) std::cout << s << " Leaf " << constant_ << std::endl; + } + + /** to graphviz file */ + void dot(std::ostream& os, bool showZero) const { + if (showZero || constant_) os << "\"" << this->id() << "\" [label=\"" + << boost::format("%4.2g") % constant_ + << "\", shape=box, rank=sink, height=0.35, fixedsize=true]\n"; // width=0.55, + } + + /** evaluate */ + const Y& operator()(const Assignment& x) const { + return constant_; + } + + /** apply unary operator */ + NodePtr apply(const Unary& op) const { + NodePtr f(new Leaf(op(constant_))); + return f; + } + + // Apply binary operator "h = f op g" on Leaf node + // Note op is not assumed commutative so we need to keep track of order + // Simply calls apply on argument to call correct virtual method: + // fL.apply_f_op_g(gL) -> gL.apply_g_op_fL(fL) (below) + // fL.apply_f_op_g(gC) -> gC.apply_g_op_fL(fL) (Choice) + NodePtr apply_f_op_g(const Node& g, const Binary& op) const { + return g.apply_g_op_fL(*this, op); + } + + // Applying binary operator to two leaves results in a leaf + NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const { + NodePtr h(new Leaf(op(fL.constant_, constant_))); // fL op gL + return h; + } + + // If second argument is a Choice node, call it's apply with leaf as second + NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const { + return fC.apply_fC_op_gL(*this, op); // operand order back to normal + } + + /** choose a branch, create new memory ! */ + NodePtr choose(const L& label, size_t index) const { + return NodePtr(new Leaf(constant())); + } + + bool isLeaf() const { return true; } + + }; // Leaf + + /*********************************************************************************/ + // Choice + /*********************************************************************************/ + template + class DecisionTree::Choice: public DecisionTree::Node { + + /** the label of the variable on which we split */ + L label_; + + /** The children of this Choice node. */ + std::vector branches_; + + private: + /** incremental allSame */ + size_t allSame_; + + typedef boost::shared_ptr ChoicePtr; + + public: + + virtual ~Choice() { +#ifdef DT_DEBUG_MEMORY + std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id() << std::std::endl; +#endif + } + + /** If all branches of a choice node f are the same, just return a branch */ + static NodePtr Unique(const ChoicePtr& f) { +#ifndef DT_NO_PRUNING + if (f->allSame_) { + assert(f->branches().size() > 0); + NodePtr f0 = f->branches_[0]; + assert(f0->isLeaf()); + NodePtr newLeaf(new Leaf(boost::dynamic_pointer_cast(f0)->constant())); + return newLeaf; + } else +#endif + return f; + } + + bool isLeaf() const { return false; } + + /** Constructor, given choice label and mandatory expected branch count */ + Choice(const L& label, size_t count) : + label_(label), allSame_(true) { + branches_.reserve(count); + } + + /** + * Construct from applying binary op to two Choice nodes + */ + Choice(const Choice& f, const Choice& g, const Binary& op) : + allSame_(true) { + + // Choose what to do based on label + if (f.label() > g.label()) { + // f higher than g + label_ = f.label(); + size_t count = f.nrChoices(); + branches_.reserve(count); + for (size_t i = 0; i < count; i++) + push_back(f.branches_[i]->apply_f_op_g(g, op)); + } else if (g.label() > f.label()) { + // f lower than g + label_ = g.label(); + size_t count = g.nrChoices(); + branches_.reserve(count); + for (size_t i = 0; i < count; i++) + push_back(g.branches_[i]->apply_g_op_fC(f, op)); + } else { + // f same level as g + label_ = f.label(); + size_t count = f.nrChoices(); + branches_.reserve(count); + for (size_t i = 0; i < count; i++) + push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op)); + } + } + + const L& label() const { + return label_; + } + + size_t nrChoices() const { + return branches_.size(); + } + + const std::vector& branches() const { + return branches_; + } + + /** add a branch: TODO merge into constructor */ + void push_back(const NodePtr& node) { + // allSame_ is restricted to leaf nodes in a decision tree + if (allSame_ && !branches_.empty()) { + allSame_ = node->sameLeaf(*branches_.back()); + } + branches_.push_back(node); + } + + /** print (as a tree) */ + void print(const std::string& s) const { + std::cout << s << " Choice("; + // std::cout << this << ","; + std::cout << label_ << ") " << std::endl; + for (size_t i = 0; i < branches_.size(); i++) + branches_[i]->print((boost::format("%s %d") % s % i).str()); + } + + /** output to graphviz (as a a graph) */ + void dot(std::ostream& os, bool showZero) const { + os << "\"" << this->id() << "\" [shape=circle, label=\"" << label_ + << "\"]\n"; + for (size_t i = 0; i < branches_.size(); i++) { + NodePtr branch = branches_[i]; + + // Check if zero + if (!showZero) { + const Leaf* leaf = dynamic_cast (branch.get()); + if (leaf && !leaf->constant()) continue; + } + + os << "\"" << this->id() << "\" -> \"" << branch->id() << "\""; + if (i == 0) os << " [style=dashed]"; + if (i > 1) os << " [style=bold]"; + os << std::endl; + branch->dot(os, showZero); + } + } + + /// Choice-Leaf equality: always false + bool sameLeaf(const Leaf& q) const { + return false; + } + + /// polymorphic equality: if q is a leaf, could be... + bool sameLeaf(const Node& q) const { + return (q.isLeaf() && q.sameLeaf(*this)); + } + + /** equality up to tolerance */ + bool equals(const Node& q, double tol) const { + const Choice* other = dynamic_cast (&q); + if (!other) return false; + if (this->label_ != other->label_) return false; + if (branches_.size() != other->branches_.size()) return false; + // we don't care about shared pointers being equal here + for (size_t i = 0; i < branches_.size(); i++) + if (!(branches_[i]->equals(*(other->branches_[i]), tol))) return false; + return true; + } + + /** evaluate */ + const Y& operator()(const Assignment& x) const { +#ifndef NDEBUG + typename Assignment::const_iterator it = x.find(label_); + if (it == x.end()) { + std::cout << "Trying to find value for " << label_ << std::endl; + throw std::invalid_argument( + "DecisionTree::operator(): value undefined for a label"); + } +#endif + size_t index = x.at(label_); + NodePtr child = branches_[index]; + return (*child)(x); + } + + /** + * Construct from applying unary op to a Choice node + */ + Choice(const L& label, const Choice& f, const Unary& op) : + label_(label), allSame_(true) { + + branches_.reserve(f.branches_.size()); // reserve space + BOOST_FOREACH (const NodePtr& branch, f.branches_) + push_back(branch->apply(op)); + } + + /** apply unary operator */ + NodePtr apply(const Unary& op) const { + boost::shared_ptr r(new Choice(label_, *this, op)); + return Unique(r); + } + + // Apply binary operator "h = f op g" on Choice node + // Note op is not assumed commutative so we need to keep track of order + // Simply calls apply on argument to call correct virtual method: + // fC.apply_f_op_g(gL) -> gL.apply_g_op_fC(fC) -> (Leaf) + // fC.apply_f_op_g(gC) -> gC.apply_g_op_fC(fC) -> (below) + NodePtr apply_f_op_g(const Node& g, const Binary& op) const { + return g.apply_g_op_fC(*this, op); + } + + // If second argument of binary op is Leaf node, recurse on branches + NodePtr apply_g_op_fL(const Leaf& fL, const Binary& op) const { + boost::shared_ptr h(new Choice(label(), nrChoices())); + BOOST_FOREACH(NodePtr branch, branches_) + h->push_back(fL.apply_f_op_g(*branch, op)); + return Unique(h); + } + + // If second argument of binary op is Choice, call constructor + NodePtr apply_g_op_fC(const Choice& fC, const Binary& op) const { + boost::shared_ptr h(new Choice(fC, *this, op)); + return Unique(h); + } + + // If second argument of binary op is Leaf + template + NodePtr apply_fC_op_gL(const Leaf& gL, OP op) const { + boost::shared_ptr h(new Choice(label(), nrChoices())); + BOOST_FOREACH(const NodePtr& branch, branches_) + h->push_back(branch->apply_f_op_g(gL, op)); + return Unique(h); + } + + /** choose a branch, recursively */ + NodePtr choose(const L& label, size_t index) const { + if (label_ == label) + return branches_[index]; // choose branch + + // second case, not label of interest, just recurse + boost::shared_ptr r(new Choice(label_, branches_.size())); + BOOST_FOREACH(const NodePtr& branch, branches_) + r->push_back(branch->choose(label, index)); + return Unique(r); + } + + }; // Choice + + /*********************************************************************************/ + // DecisionTree + /*********************************************************************************/ + template + DecisionTree::DecisionTree() { + } + + template + DecisionTree::DecisionTree(const NodePtr& root) : + root_(root) { + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(const Y& y) { + root_ = NodePtr(new Leaf(y)); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(// + const L& label, const Y& y1, const Y& y2) { + boost::shared_ptr a(new Choice(label, 2)); + NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); + a->push_back(l1); + a->push_back(l2); + root_ = Choice::Unique(a); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(// + const LabelC& labelC, const Y& y1, const Y& y2) { + if (labelC.second != 2) throw std::invalid_argument( + "DecisionTree: binary constructor called with non-binary label"); + boost::shared_ptr a(new Choice(labelC.first, 2)); + NodePtr l1(new Leaf(y1)), l2(new Leaf(y2)); + a->push_back(l1); + a->push_back(l2); + root_ = Choice::Unique(a); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(const std::vector& labelCs, + const std::vector& ys) { + // call recursive Create + root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(const std::vector& labelCs, + const std::string& table) { + + // Convert std::string to doubles + std::vector ys; + std::istringstream iss(table); + copy(std::istream_iterator(iss), std::istream_iterator(), + back_inserter(ys)); + + // now call recursive Create + root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end()); + } + + /*********************************************************************************/ + template + template DecisionTree::DecisionTree( + Iterator begin, Iterator end, const L& label) { + root_ = compose(begin, end, label); + } + + /*********************************************************************************/ + template + DecisionTree::DecisionTree(const L& label, + const DecisionTree& f0, const DecisionTree& f1) { + std::vector functions; + functions += f0, f1; + root_ = compose(functions.begin(), functions.end(), label); + } + + /*********************************************************************************/ + template + template + DecisionTree::DecisionTree(const DecisionTree& other, + const std::map& map, boost::function op) { + root_ = convert(other.root_, map, op); + } + + /*********************************************************************************/ + // Called by two constructors above. + // Takes a label and a corresponding range of decision trees, and creates a new + // decision tree. However, the order of the labels needs to be respected, so we + // cannot just create a root Choice node on the label: if the label is not the + // highest label, we need to do a complicated and expensive recursive call. + template template + typename DecisionTree::NodePtr DecisionTree::compose( + Iterator begin, Iterator end, const L& label) const { + + // find highest label among branches + boost::optional highestLabel; + boost::optional nrChoices; + for (Iterator it = begin; it != end; it++) { + if (it->root_->isLeaf()) continue; + boost::shared_ptr c = boost::dynamic_pointer_cast (it->root_); + if (!highestLabel || c->label() > *highestLabel) { + highestLabel.reset(c->label()); + nrChoices.reset(c->nrChoices()); + } + } + + // if label is already in correct order, just put together a choice on label + if (!highestLabel || label > *highestLabel) { + boost::shared_ptr choiceOnLabel(new Choice(label, end - begin)); + for (Iterator it = begin; it != end; it++) + choiceOnLabel->push_back(it->root_); + return Choice::Unique(choiceOnLabel); + } + + // Set up a new choice on the highest label + boost::shared_ptr choiceOnHighestLabel(new Choice(*highestLabel, *nrChoices)); + // now, for all possible values of highestLabel + for (size_t index = 0; index < *nrChoices; index++) { + // make a new set of functions for composing by iterating over the given + // functions, and selecting the appropriate branch. + std::vector functions; + for (Iterator it = begin; it != end; it++) { + // by restricting the input functions to value i for labelBelow + DecisionTree chosen = it->choose(*highestLabel, index); + functions.push_back(chosen); + } + // We then recurse, for all values of the highest label + NodePtr fi = compose(functions.begin(), functions.end(), label); + choiceOnHighestLabel->push_back(fi); + } + return Choice::Unique(choiceOnHighestLabel); + } + + /*********************************************************************************/ + // "create" is a bit of a complicated thing, but very useful. + // It takes a range of labels and a corresponding range of values, + // and creates a decision tree, as follows: + // - if there is only one label, creates a choice node with values in leaves + // - otherwise, it evenly splits up the range of values and creates a tree for + // each sub-range, and assigns that tree to first label's choices + // Example: + // create([B A],[1 2 3 4]) would call + // create([A],[1 2]) + // create([A],[3 4]) + // and produce + // B=0 + // A=0: 1 + // A=1: 2 + // B=1 + // A=0: 3 + // A=1: 4 + // Note, through the magic of "compose", create([A B],[1 2 3 4]) will produce + // exactly the same tree as above: the highest label is always the root. + // However, it will be *way* faster if labels are given highest to lowest. + template + template + typename DecisionTree::NodePtr DecisionTree::create( + It begin, It end, ValueIt beginY, ValueIt endY) const { + + // get crucial counts + size_t nrChoices = begin->second; + size_t size = endY - beginY; + + // Find the next key to work on + It labelC = begin + 1; + if (labelC == end) { + // Base case: only one key left + // Create a simple choice node with values as leaves. + if (size != nrChoices) { + std::cout << "Trying to create DD on " << begin->first << std::endl; + std::cout << boost::format("DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl; + throw std::invalid_argument("DecisionTree::create invalid argument"); + } + boost::shared_ptr choice(new Choice(begin->first, endY - beginY)); + for (ValueIt y = beginY; y != endY; y++) + choice->push_back(NodePtr(new Leaf(*y))); + return Choice::Unique(choice); + } + + // Recursive case: perform "Shannon expansion" + // Creates one tree (i.e.,function) for each choice of current key + // by calling create recursively, and then puts them all together. + std::vector functions; + size_t split = size / nrChoices; + for (size_t i = 0; i < nrChoices; i++, beginY += split) { + NodePtr f = create(labelC, end, beginY, beginY + split); + functions += DecisionTree(f); + } + return compose(functions.begin(), functions.end(), begin->first); + } + + /*********************************************************************************/ + template + template + typename DecisionTree::NodePtr DecisionTree::convert( + const typename DecisionTree::NodePtr& f, const std::map& map, + boost::function op) { + + typedef DecisionTree MX; + typedef typename MX::Leaf MXLeaf; + typedef typename MX::Choice MXChoice; + typedef typename MX::NodePtr MXNodePtr; + typedef DecisionTree LY; + + // ugliness below because apparently we can't have templated virtual functions + // If leaf, apply unary conversion "op" and create a unique leaf + const MXLeaf* leaf = dynamic_cast (f.get()); + if (leaf) return NodePtr(new Leaf(op(leaf->constant()))); + + // Check if Choice + boost::shared_ptr choice = boost::dynamic_pointer_cast (f); + if (!choice) throw std::invalid_argument( + "DecisionTree::Convert: Invalid NodePtr"); + + // get new label + M oldLabel = choice->label(); + L newLabel = map.at(oldLabel); + + // put together via Shannon expansion otherwise not sorted. + std::vector functions; + BOOST_FOREACH(const MXNodePtr& branch, choice->branches()) { + LY converted(convert(branch, map, op)); + functions += converted; + } + return LY::compose(functions.begin(), functions.end(), newLabel); + } + + /*********************************************************************************/ + template + bool DecisionTree::equals(const DecisionTree& other, double tol) const { + return root_->equals(*other.root_, tol); + } + + template + void DecisionTree::print(const std::string& s) const { + root_->print(s); + } + + template + bool DecisionTree::operator==(const DecisionTree& other) const { + return root_->equals(*other.root_); + } + + template + const Y& DecisionTree::operator()(const Assignment& x) const { + return root_->operator ()(x); + } + + template + DecisionTree DecisionTree::apply(const Unary& op) const { + return DecisionTree(root_->apply(op)); + } + + /*********************************************************************************/ + template + DecisionTree DecisionTree::apply(const DecisionTree& g, + const Binary& op) const { + // apply the operaton on the root of both diagrams + NodePtr h = root_->apply_f_op_g(*g.root_, op); + // create a new class with the resulting root "h" + DecisionTree result(h); + return result; + } + + /*********************************************************************************/ + // The way this works: + // We have an ADT, picture it as a tree. + // At a certain depth, we have a branch on "label". + // The function "choose(label,index)" will return a tree of one less depth, + // where there is no more branch on "label": only the subtree under that + // branch point corresponding to the value "index" is left instead. + // The function below get all these smaller trees and "ops" them together. + template + DecisionTree DecisionTree::combine(const L& label, + size_t cardinality, const Binary& op) const { + DecisionTree result = choose(label, 0); + for (size_t index = 1; index < cardinality; index++) { + DecisionTree chosen = choose(label, index); + result = result.apply(chosen, op); + } + return result; + } + + /*********************************************************************************/ + template + void DecisionTree::dot(std::ostream& os, bool showZero) const { + os << "digraph G {\n"; + root_->dot(os, showZero); + os << " [ordering=out]}" << std::endl; + } + + template + void DecisionTree::dot(const std::string& name, bool showZero) const { + std::ofstream os((name + ".dot").c_str()); + dot(os, showZero); + system( + ("dot -Tpdf " + name + ".dot -o " + name + ".pdf >& /dev/null").c_str()); + } + +/*********************************************************************************/ + +} // namespace gtsam + + diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h new file mode 100644 index 000000000..75263475d --- /dev/null +++ b/gtsam/discrete/DecisionTree.h @@ -0,0 +1,218 @@ +/* + * @file DecisionTree.h + * @brief Decision Tree for use in DiscreteFactors + * @author Frank Dellaert + * @author Can Erdogan + * @date Jan 30, 2012 + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace gtsam { + + /** + * Algebraic Decision Trees + * L = label for variables + * Y = function range (any algebra), e.g., bool, int, double + */ + template + class DecisionTree { + + public: + + /** Handy typedefs for unary and binary function types */ + typedef boost::function Unary; + typedef boost::function Binary; + + /** A label annotated with cardinality */ + typedef std::pair LabelC; + + /** DD's consist of Leaf and Choice nodes, both subclasses of Node */ + class Leaf; + class Choice; + + /** ------------------------ Node base class --------------------------- */ + class Node { + public: + typedef boost::shared_ptr Ptr; + +#ifdef DT_DEBUG_MEMORY + static int nrNodes; +#endif + + // Constructor + Node() { +#ifdef DT_DEBUG_MEMORY + std::cout << ++nrNodes << " constructed " << id() << std::endl; std::cout.flush(); + +#endif + } + + // Destructor + virtual ~Node() { +#ifdef DT_DEBUG_MEMORY + std::cout << --nrNodes << " destructed " << id() << std::endl; std::cout.flush(); + +#endif + } + + // Unique ID for dot files + const void* id() const { return this; } + + // everything else is virtual, no documentation here as internal + virtual void print(const std::string& s = "") const = 0; + virtual void dot(std::ostream& os, bool showZero) const = 0; + virtual bool sameLeaf(const Leaf& q) const = 0; + virtual bool sameLeaf(const Node& q) const = 0; + virtual bool equals(const Node& other, double tol = 1e-9) const = 0; + virtual const Y& operator()(const Assignment& x) const = 0; + virtual Ptr apply(const Unary& op) const = 0; + virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0; + virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0; + virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0; + virtual Ptr choose(const L& label, size_t index) const = 0; + virtual bool isLeaf() const = 0; + }; + /** ------------------------ Node base class --------------------------- */ + + public: + + /** A function is a shared pointer to the root of an ADD */ + typedef typename Node::Ptr NodePtr; + + /* an AlgebraicDecisionTree just contains the root */ + NodePtr root_; + + protected: + + /** Internal recursive function to create from keys, cardinalities, and Y values */ + template + NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const; + + /** Convert to a different type */ + template NodePtr + convert(const typename DecisionTree::NodePtr& f, const std::map& map, boost::function op); + + /** Default constructor */ + DecisionTree(); + + public: + + /// @name Standard Constructors + /// @{ + + /** Create a constant */ + DecisionTree(const Y& y); + + /** Create a new leaf function splitting on a variable */ + DecisionTree(const L& label, const Y& y1, const Y& y2); + + /** Allow Label+Cardinality for convenience */ + DecisionTree(const LabelC& label, const Y& y1, const Y& y2); + + /** Create from keys and string table */ + DecisionTree(const std::vector& labelCs, const std::vector& ys); + + /** Create from keys and string table */ + DecisionTree(const std::vector& labelCs, const std::string& table); + + /** Create DecisionTree from others */ + template + DecisionTree(Iterator begin, Iterator end, const L& label); + + /** Create DecisionTree from others others (binary version) */ + DecisionTree(const L& label, // + const DecisionTree& f0, const DecisionTree& f1); + + /** Convert from a different type */ + template + DecisionTree(const DecisionTree& other, + const std::map& map, boost::function op); + + /// @} + /// @name Testable + /// @{ + + /** GTSAM-style print */ + void print(const std::string& s = "DecisionTree") const; + + // Testable + bool equals(const DecisionTree& other, double tol = 1e-9) const; + + /// @} + /// @name Standard Interface + /// @{ + + /** Make virtual */ + virtual ~DecisionTree() { + } + + /** equality */ + bool operator==(const DecisionTree& q) const; + + /** evaluate */ + const Y& operator()(const Assignment& x) const; + + /** apply Unary operation "op" to f */ + DecisionTree apply(const Unary& op) const; + + /** apply binary operation "op" to f and g */ + DecisionTree apply(const DecisionTree& g, const Binary& op) const; + + /** create a new function where value(label)==index */ + DecisionTree choose(const L& label, size_t index) const { + NodePtr newRoot = root_->choose(label, index); + return DecisionTree(newRoot); + } + + /** combine subtrees on key with binary operation "op" */ + DecisionTree combine(const L& label, size_t cardinality, const Binary& op) const; + + /** combine with LabelC for convenience */ + DecisionTree combine(const LabelC& labelC, const Binary& op) const { + return combine(labelC.first, labelC.second, op); + } + + /** output to graphviz format, stream version */ + void dot(std::ostream& os, bool showZero = true) const; + + /** output to graphviz format, open a file */ + void dot(const std::string& name, bool showZero = true) const; + + /// @name Advanced Interface + /// @{ + + // internal use only + DecisionTree(const NodePtr& root); + + // internal use only + template NodePtr + compose(Iterator begin, Iterator end, const L& label) const; + + /// @} + + }; // DecisionTree + + /** free versions of apply */ + + template + DecisionTree apply(const DecisionTree& f, + const typename DecisionTree::Unary& op) { + return f.apply(op); + } + + template + DecisionTree apply(const DecisionTree& f, + const DecisionTree& g, + const typename DecisionTree::Binary& op) { + return f.apply(g, op); + } + +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp new file mode 100644 index 000000000..f883a6026 --- /dev/null +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -0,0 +1,91 @@ +/* + * DecisionTreeFactor.cpp + * @brief: discrete factor + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include +#include + +#include + +using namespace std; + +namespace gtsam { + + /* ******************************************************************************** */ + DecisionTreeFactor::DecisionTreeFactor() { + } + + /* ******************************************************************************** */ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys, + const ADT& potentials) : + DiscreteFactor(keys.indices()), Potentials(keys, potentials) { + } + + /* *************************************************************************/ + DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) : + DiscreteFactor(c.keys()), Potentials(c) { + } + + /* ************************************************************************* */ + bool DecisionTreeFactor::equals(const This& other, double tol) const { + return IndexFactor::equals(other, tol) && Potentials::equals(other, tol); + } + + /* ************************************************************************* */ + void DecisionTreeFactor::print(const string& s) const { + cout << s << ":\n"; + IndexFactor::print("IndexFactor:"); + Potentials::print("Potentials:"); + } + + /* ************************************************************************* */ + DecisionTreeFactor DecisionTreeFactor::apply // + (const DecisionTreeFactor& f, ADT::Binary op) const { + map cs; // new cardinalities + // make unique key-cardinality map + BOOST_FOREACH(Index j, keys()) cs[j] = cardinality(j); + BOOST_FOREACH(Index j, f.keys()) cs[j] = f.cardinality(j); + // Convert map into keys + DiscreteKeys keys; + BOOST_FOREACH(const DiscreteKey& key, cs) + keys.push_back(key); + // apply operand + ADT result = ADT::apply(f, op); + // Make a new factor + return DecisionTreeFactor(keys, result); + } + + /* ************************************************************************* */ + DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine // + (size_t nrFrontals, ADT::Binary op) const { + + if (nrFrontals == 0 || nrFrontals > size()) throw invalid_argument( + (boost::format( + "DecisionTreeFactor::combine: invalid number of frontal keys %d, nr.keys=%d") + % nrFrontals % size()).str()); + + // sum over nrFrontals keys + size_t i; + ADT result(*this); + for (i = 0; i < nrFrontals; i++) { + Index j = keys()[i]; + result = result.combine(j, cardinality(j), op); + } + + // create new factor, note we start keys after nrFrontals + DiscreteKeys dkeys; + for (; i < keys().size(); i++) { + Index j = keys()[i]; + dkeys.push_back(DiscreteKey(j,cardinality(j))); + } + shared_ptr f(new DecisionTreeFactor(dkeys, result)); + return f; + } + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h new file mode 100644 index 000000000..b68730e34 --- /dev/null +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -0,0 +1,148 @@ +/* + * DecisionTreeFactor.h + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace gtsam { + + class DiscreteConditional; + + /** + * A discrete probabilistic factor + */ + class DecisionTreeFactor: public DiscreteFactor, public Potentials { + + public: + + // typedefs needed to play nice with gtsam + typedef DecisionTreeFactor This; + typedef DiscreteConditional ConditionalType; + typedef boost::shared_ptr shared_ptr; + + /// Index label and cardinality + typedef std::pair IndexC; + + public: + + /// @name Standard Constructors + /// @{ + + /** Default constructor for I/O */ + DecisionTreeFactor(); + + /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ + DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); + + /** Constructor from Indices and (string or doubles) */ + template + DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) : + DiscreteFactor(keys.indices()), Potentials(keys, table) { + } + + /** Construct from a DiscreteConditional type */ + DecisionTreeFactor(const DiscreteConditional& c); + + /// @} + /// @name Testable + /// @{ + + /// equality + bool equals(const DecisionTreeFactor& other, double tol = 1e-9) const; + + // print + void print(const std::string& s = "DecisionTreeFactor: ") const; + + /// @} + /// @name Standard Interface + /// @{ + + /// Value is just look up in AlgebraicDecisonTree + virtual double operator()(const Values& values) const { + return Potentials::operator()(values); + } + + /// multiply two factors + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const { + return apply(f, ADT::Ring::mul); + } + + /// divide by factor f (safely) + DecisionTreeFactor operator/(const DecisionTreeFactor& f) const { + return apply(f, safe_div); + } + + /// Convert into a decisiontree + virtual operator DecisionTreeFactor() const { + return *this; + } + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(size_t nrFrontals) const { + return combine(nrFrontals, ADT::Ring::add); + } + + /// Create new factor by maximizing over all values with the same separator values + shared_ptr max(size_t nrFrontals) const { + return combine(nrFrontals, ADT::Ring::max); + } + + /// @} + /// @name Advanced Interface + /// @{ + + /** + * Apply binary operator (*this) "op" f + * @param f the second argument for op + * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + */ + DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; + + /** + * Combine frontal variables using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @return shared pointer to newly created DecisionTreeFactor + */ + shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + /// + bool ensureArcConsistency(size_t j, std::vector& domains) const { +// throw std::runtime_error( +// "DecisionTreeFactor::ensureArcConsistency not implemented"); + return false; + } + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply(const Values&) const { + throw std::runtime_error("DecisionTreeFactor::partiallyApply not implemented"); + } + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply( + const std::vector&) const { + throw std::runtime_error("DecisionTreeFactor::partiallyApply not implemented"); + } + /// @} + }; +// DecisionTreeFactor + +}// namespace gtsam diff --git a/gtsam/discrete/DiscreteBayesNet.cpp b/gtsam/discrete/DiscreteBayesNet.cpp new file mode 100644 index 000000000..e233743eb --- /dev/null +++ b/gtsam/discrete/DiscreteBayesNet.cpp @@ -0,0 +1,48 @@ +/* + * DiscreteBayesNet.cpp + * + * @date Feb 15, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +namespace gtsam { + + // Explicitly instantiate so we don't have to include everywhere + template class BayesNet ; + + /* ************************************************************************* */ + void add_front(DiscreteBayesNet& bayesNet, const Signature& s) { + bayesNet.push_front(boost::make_shared(s)); + } + + /* ************************************************************************* */ + void add(DiscreteBayesNet& bayesNet, const Signature& s) { + bayesNet.push_back(boost::make_shared(s)); + } + + /* ************************************************************************* */ + DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn) { + // solve each node in turn in topological sort order (parents first) + DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); + BOOST_REVERSE_FOREACH (DiscreteConditional::shared_ptr conditional, bn) + conditional->solveInPlace(*result); + return result; + } + + /* ************************************************************************* */ + DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn) { + // sample each node in turn in topological sort order (parents first) + DiscreteFactor::sharedValues result(new DiscreteFactor::Values()); + BOOST_REVERSE_FOREACH(DiscreteConditional::shared_ptr conditional, bn) + conditional->sampleInPlace(*result); + return result; + } + +/* ************************************************************************* */ +} // namespace diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h new file mode 100644 index 000000000..c0bba84eb --- /dev/null +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -0,0 +1,33 @@ +/* + * DiscreteBayesNet.h + * + * @date Feb 15, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace gtsam { + + typedef BayesNet DiscreteBayesNet; + + /** Add a DiscreteCondtional */ + void add(DiscreteBayesNet&, const Signature& s); + + /** Add a DiscreteCondtional in front, when listing parents first*/ + void add_front(DiscreteBayesNet&, const Signature& s); + + /** Optimize function for back-substitution. */ + DiscreteFactor::sharedValues optimize(const DiscreteBayesNet& bn); + + /** Do ancestral sampling */ + DiscreteFactor::sharedValues sample(const DiscreteBayesNet& bn); + +} // namespace + diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp new file mode 100644 index 000000000..00e41ce63 --- /dev/null +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -0,0 +1,152 @@ +/* + * DiscreteConditional.cpp + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +using namespace std; + +namespace gtsam { + + /* ******************************************************************************** */ + DiscreteConditional::DiscreteConditional(const size_t nrFrontals, + const DecisionTreeFactor& f) : + IndexConditional(f.keys(), nrFrontals), Potentials( + f / (*f.sum(nrFrontals))) { + } + + /* ******************************************************************************** */ + DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal) : + IndexConditional(joint.keys(), joint.size() - marginal.size()), Potentials( + ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal) { + assert(nrFrontals() == 1); + if (ISDEBUG("DiscreteConditional::DiscreteConditional")) cout + << (firstFrontalKey()) << endl; + } + + /* ******************************************************************************** */ + DiscreteConditional::DiscreteConditional(const Signature& signature) : + IndexConditional(signature.indices(), 1), Potentials( + signature.discreteKeysParentsFirst(), signature.cpt()) { + } + + /* ******************************************************************************** */ + Potentials::ADT DiscreteConditional::choose( + const Values& parentsValues) const { + ADT pFS(*this); + BOOST_FOREACH(Index key, parents()) + try { + Index j = (key); + size_t value = parentsValues.at(j); + pFS = pFS.choose(j, value); + } catch (exception& e) { + throw runtime_error( + "DiscreteConditional::choose: parent value missing"); + }; + return pFS; + } + + /* ******************************************************************************** */ + void DiscreteConditional::solveInPlace(Values& values) const { + assert(nrFrontals() == 1); + Index j = (firstFrontalKey()); + size_t mpe = solve(values); // Solve for variable + values[j] = mpe; // store result in partial solution + } + + /* ******************************************************************************** */ + void DiscreteConditional::sampleInPlace(Values& values) const { + assert(nrFrontals() == 1); + Index j = (firstFrontalKey()); + size_t sampled = sample(values); // Sample variable + values[j] = sampled; // store result in partial solution + } + + /* ******************************************************************************** */ + size_t DiscreteConditional::solve(const Values& parentsValues) const { + + // TODO: is this really the fastest way? I think it is. + ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + + // Then, find the max over all remaining + // TODO, only works for one key now, seems horribly slow this way + size_t mpe = 0; + Values frontals; + double maxP = 0; + assert(nrFrontals() == 1); + Index j = (firstFrontalKey()); + for (size_t value = 0; value < cardinality(j); value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + mpe = value; + } + } + return mpe; + } + + /* ******************************************************************************** */ + size_t DiscreteConditional::sample(const Values& parentsValues) const { + + using boost::uniform_real; + static boost::mt19937 gen(2); // random number generator + + bool debug = ISDEBUG("DiscreteConditional::sample"); + + // Get the correct conditional density + ADT pFS = choose(parentsValues); // P(F|S=parentsValues) + if (debug) GTSAM_PRINT(pFS); + + // get cumulative distribution function (cdf) + // TODO, only works for one key now, seems horribly slow this way + assert(nrFrontals() == 1); + Index j = (firstFrontalKey()); + size_t nj = cardinality(j); + vector cdf(nj); + Values frontals; + double sum = 0; + for (size_t value = 0; value < nj; value++) { + frontals[j] = value; + double pValueS = pFS(frontals); // P(F=value|S=parentsValues) + sum += pValueS; // accumulate + if (debug) cout << sum << " "; + if (pValueS == 1) { + if (debug) cout << "--> " << value << endl; + return value; // shortcut exit + } + cdf[value] = sum; + } + + // inspired by http://www.boost.org/doc/libs/1_46_1/doc/html/boost_random/tutorial.html + uniform_real<> dist(0, cdf.back()); + boost::variate_generator > die(gen, dist); + size_t sampled = lower_bound(cdf.begin(), cdf.end(), die()) - cdf.begin(); + if (debug) cout << "-> " << sampled << endl; + + return sampled; + + return 0; + } + +/* ******************************************************************************** */ + +} // namespace diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h new file mode 100644 index 000000000..5a7881dd9 --- /dev/null +++ b/gtsam/discrete/DiscreteConditional.h @@ -0,0 +1,110 @@ +/* + * DiscreteConditional.h + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + + /** + * Discrete Conditional Density + * Derives from DecisionTreeFactor + */ + class DiscreteConditional: public IndexConditional, public Potentials { + + public: + // typedefs needed to play nice with gtsam + typedef DiscreteFactor FactorType; + typedef boost::shared_ptr shared_ptr; + typedef IndexConditional Base; + + /** A map from keys to values */ + typedef Assignment Values; + typedef boost::shared_ptr sharedValues; + + /// @name Standard Constructors + /// @{ + + /** default constructor needed for serialization */ + DiscreteConditional() { + } + + /** constructor from factor */ + DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + + /** Construct from signature */ + DiscreteConditional(const Signature& signature); + + /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + DiscreteConditional(const DecisionTreeFactor& joint, + const DecisionTreeFactor& marginal); + + /// @} + /// @name Testable + /// @{ + + /** GTSAM-style print */ + void print(const std::string& s = "Discrete Conditional: ") const { + std::cout << s << std::endl; + IndexConditional::print(s); + Potentials::print(s); + } + + /** GTSAM-style equals */ + bool equals(const DiscreteConditional& other, double tol = 1e-9) const { + return IndexConditional::equals(other, tol) + && Potentials::equals(other, tol); + } + + /// @} + /// @name Standard Interface + /// @{ + + /** Convert to a factor */ + DecisionTreeFactor::shared_ptr toFactor() const { + return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); + } + + /** Restrict to given parent values, returns AlgebraicDecisionDiagram */ + ADT choose(const Assignment& parentsValues) const; + + /** + * solve a conditional + * @param parentsAssignment Known values of the parents + * @return MPE value of the child (1 frontal variable). + */ + size_t solve(const Values& parentsValues) const; + + /** + * sample + * @param parentsAssignment Known values of the parents + * @return sample from conditional + */ + size_t sample(const Values& parentsValues) const; + + /// @} + /// @name Advanced Interface + /// @{ + + /// solve a conditional, in place + void solveInPlace(Values& parentsValues) const; + + /// sample in place, stores result in partial solution + void sampleInPlace(Values& parentsValues) const; + + /// @} + + }; +// DiscreteConditional + +}// gtsam + diff --git a/gtsam/discrete/DiscreteFactor.cpp b/gtsam/discrete/DiscreteFactor.cpp new file mode 100644 index 000000000..6ed7b70ce --- /dev/null +++ b/gtsam/discrete/DiscreteFactor.cpp @@ -0,0 +1,21 @@ +/* + * DiscreteFactor.cpp + * @brief: discrete factor + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include + +using namespace std; + +namespace gtsam { + + /* ******************************************************************************** */ + DiscreteFactor::DiscreteFactor() { + } + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h new file mode 100644 index 000000000..8b20b3f18 --- /dev/null +++ b/gtsam/discrete/DiscreteFactor.h @@ -0,0 +1,108 @@ +/* + * DiscreteFactor.h + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include + +namespace gtsam { + + class DecisionTreeFactor; + class DiscreteConditional; + class Domain; + + /** + * Base class for discrete probabilistic factors + * The most general one is the derived DecisionTreeFactor + */ + class DiscreteFactor: public IndexFactor { + + public: + + // typedefs needed to play nice with gtsam + typedef DiscreteFactor This; + typedef DiscreteConditional ConditionalType; + typedef boost::shared_ptr shared_ptr; + + /** A map from keys to values */ + typedef Assignment Values; + typedef boost::shared_ptr sharedValues; + + protected: + + /// Construct n-way factor + DiscreteFactor(const std::vector& js) : + IndexFactor(js) { + } + + /// Construct unary factor + DiscreteFactor(Index j) : + IndexFactor(j) { + } + + /// Construct binary factor + DiscreteFactor(Index j1, Index j2) : + IndexFactor(j1, j2) { + } + + /// construct from container + template + DiscreteFactor(KeyIterator beginKey, KeyIterator endKey) : + IndexFactor(beginKey, endKey) { + } + + public: + + /// @name Standard Constructors + /// @{ + + /// Default constructor for I/O + DiscreteFactor(); + + /// Virtual destructor + virtual ~DiscreteFactor() {} + + /// @} + /// @name Testable + /// @{ + + // print + virtual void print(const std::string& s = "DiscreteFactor") const { + IndexFactor::print(s); + } + + /// @} + /// @name Standard Interface + /// @{ + + /// Find value for given assignment of values to variables + virtual double operator()(const Values&) const = 0; + + /// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor + virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + + virtual operator DecisionTreeFactor() const = 0; + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + virtual bool ensureArcConsistency(size_t j, std::vector& domains) const = 0; + + /// Partially apply known values + virtual shared_ptr partiallyApply(const Values&) const = 0; + + + /// Partially apply known values, domain version + virtual shared_ptr partiallyApply(const std::vector&) const = 0; + /// @} + }; +// DiscreteFactor + +}// namespace gtsam diff --git a/gtsam/discrete/DiscreteFactorGraph.cpp b/gtsam/discrete/DiscreteFactorGraph.cpp new file mode 100644 index 000000000..3c9c42689 --- /dev/null +++ b/gtsam/discrete/DiscreteFactorGraph.cpp @@ -0,0 +1,82 @@ +/* + * DiscreteFactorGraph.cpp + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + */ + +//#define ENABLE_TIMING +#include +#include +#include +#include + +namespace gtsam { + +// Explicitly instantiate so we don't have to include everywhere +template class FactorGraph ; +template class EliminationTree ; + +/* ************************************************************************* */ +DiscreteFactorGraph::DiscreteFactorGraph() { +} + +/* ************************************************************************* */ +DiscreteFactorGraph::DiscreteFactorGraph( + const BayesNet& bayesNet) : + FactorGraph(bayesNet) { +} + +/* ************************************************************************* */ +FastSet DiscreteFactorGraph::keys() const { + FastSet keys; + BOOST_FOREACH(const sharedFactor& factor, *this) + if (factor) keys.insert(factor->begin(), factor->end()); + return keys; +} + +/* ************************************************************************* */ +DecisionTreeFactor DiscreteFactorGraph::product() const { + DecisionTreeFactor result; + BOOST_FOREACH(const sharedFactor& factor, *this) + if (factor) result = (*factor) * result; + return result; +} + +/* ************************************************************************* */ +double DiscreteFactorGraph::operator()( + const DiscreteFactor::Values &values) const { + double product = 1.0; + BOOST_FOREACH( const sharedFactor& factor, factors_ ) + product *= (*factor)(values); + return product; +} + +/* ************************************************************************* */ +pair // +EliminateDiscrete(const FactorGraph& factors, size_t num) { + + // PRODUCT: multiply all factors + tic(1, "product"); + DecisionTreeFactor product; + BOOST_FOREACH(const DiscreteFactor::shared_ptr& factor, factors) + product = (*factor) * product; + toc(1, "product"); + + // sum out frontals, this is the factor on the separator + tic(2, "sum"); + DecisionTreeFactor::shared_ptr sum = product.sum(num); + toc(2, "sum"); + + // now divide product/sum to get conditional + tic(3, "divide"); + DiscreteConditional::shared_ptr cond(new DiscreteConditional(product, *sum)); + toc(3, "divide"); + tictoc_finishedIteration(); + + return make_pair(cond, sum); +} + +/* ************************************************************************* */ +} // namespace + diff --git a/gtsam/discrete/DiscreteFactorGraph.h b/gtsam/discrete/DiscreteFactorGraph.h new file mode 100644 index 000000000..8afacc51b --- /dev/null +++ b/gtsam/discrete/DiscreteFactorGraph.h @@ -0,0 +1,87 @@ +/* + * DiscreteFactorGraph.h + * + * @date Feb 14, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace gtsam { + +class DiscreteFactorGraph: public FactorGraph { +public: + + /** A map from keys to values */ + typedef std::vector Indices; + typedef Assignment Values; + typedef boost::shared_ptr sharedValues; + + /** Construct empty factor graph */ + DiscreteFactorGraph(); + + /** Constructor from a factor graph of GaussianFactor or a derived type */ + template + DiscreteFactorGraph(const FactorGraph& fg) { + push_back(fg); + } + + /** construct from a BayesNet */ + DiscreteFactorGraph(const BayesNet& bayesNet); + + template + void add(const DiscreteKey& j, SOURCE table) { + DiscreteKeys keys; + keys.push_back(j); + push_back(boost::make_shared(keys, table)); + } + + template + void add(const DiscreteKey& j1, const DiscreteKey& j2, SOURCE table) { + DiscreteKeys keys; + keys.push_back(j1); + keys.push_back(j2); + push_back(boost::make_shared(keys, table)); + } + + /** add shared discreteFactor immediately from arguments */ + template + void add(const DiscreteKeys& keys, SOURCE table) { + push_back(boost::make_shared(keys, table)); + } + + /** Return the set of variables involved in the factors (set union) */ + FastSet keys() const; + + /** return product of all factors as a single factor */ + DecisionTreeFactor product() const; + + /** Evaluates the factor graph given values, returns the joint probability of the factor graph given specific instantiation of values*/ + double operator()(const DiscreteFactor::Values & values) const; + + /// print + void print(const std::string& s = "DiscreteFactorGraph") const { + std::cout << s << std::endl; + std::cout << "size: " << size() << std::endl; + for (size_t i = 0; i < factors_.size(); i++) { + std::stringstream ss; + ss << "factor " << i << ": "; + if (factors_[i] != NULL) factors_[i]->print(ss.str()); + } + } + +}; +// DiscreteFactorGraph + +/** Main elimination function for DiscreteFactorGraph */ +std::pair, DecisionTreeFactor::shared_ptr> +EliminateDiscrete(const FactorGraph& factors, + size_t nrFrontals = 1); + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteKey.cpp b/gtsam/discrete/DiscreteKey.cpp new file mode 100644 index 000000000..9ab4207bb --- /dev/null +++ b/gtsam/discrete/DiscreteKey.cpp @@ -0,0 +1,57 @@ +/* + * DiscreteKey.h + * @brief specialized key for discrete variables + * @author Frank Dellaert + * @date Feb 28, 2011 + */ + +#include +#include // for key names +#include // FOREACH +#include "DiscreteKey.h" + +namespace gtsam { + + using namespace std; + + bool OldDiscreteKey::equals(const OldDiscreteKey& other, double tol) const { + return (*this == other); + } + + void OldDiscreteKey::print(const string& s) const { + cout << s << *this; + } + + ostream& operator <<(ostream &os, const OldDiscreteKey &key) { + os << key.name_; + return os; + } + + DiscreteKeys::DiscreteKeys(const vector& cs) { + for (size_t i = 0; i < cs.size(); i++) { + string name = boost::str(boost::format("v%1%") % i); + push_back(DiscreteKey(i, cs[i])); + } + } + + vector DiscreteKeys::indices() const { + vector < Index > js; + BOOST_FOREACH(const DiscreteKey& key, *this) + js.push_back(key.first); + return js; + } + + map DiscreteKeys::cardinalities() const { + map cs; + cs.insert(begin(),end()); +// BOOST_FOREACH(const DiscreteKey& key, *this) +// cs.insert(key); + return cs; + } + + DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2) { + DiscreteKeys keys(key1); + return keys & key2; + } + +} diff --git a/gtsam/discrete/DiscreteKey.h b/gtsam/discrete/DiscreteKey.h new file mode 100644 index 000000000..23deaf9d8 --- /dev/null +++ b/gtsam/discrete/DiscreteKey.h @@ -0,0 +1,133 @@ +/* + * DiscreteKey.h + * @brief specialized key for discrete variables + * @author Frank Dellaert + * @date Feb 28, 2011 + */ + +#pragma once + +#include +#include + +#include +#include +#include + +namespace gtsam { + + typedef std::pair DiscreteKey; + + /** + * Key type for discrete conditionals + * Includes name and cardinality + */ + class OldDiscreteKey : std::pair { + + private: + + std::string name_; + + public: + + /** Default constructor */ + OldDiscreteKey() : + std::pair(0,0), name_("default") { + } + + /** Constructor, defaults to binary */ + OldDiscreteKey(Index j, const std::string& name, size_t cardinality = 2) : + std::pair(j,cardinality), name_(name) { + } + + virtual ~OldDiscreteKey() { + } + + // Testable + bool equals(const OldDiscreteKey& other, double tol = 1e-9) const; + void print(const std::string& s = "") const; + + operator Index() const { return first; } + + const std::string& name() const { + return name_; + } + + size_t cardinality() const { + return second; + } + + /** compare 2 keys by their name */ + bool operator <(const OldDiscreteKey& other) const { + return name_ < other.name_; + } + + /** equality */ + bool operator==(const OldDiscreteKey& other) const { + return (first == other.first) && (second == other.second) && (name_ == other.name_); + } + + bool operator!=(const OldDiscreteKey& other) const { + return !(*this == other); + } + + /** provide streaming */ + friend std::ostream& operator <<(std::ostream &os, const OldDiscreteKey &key); + + }; // OldDiscreteKey + + /// DiscreteKeys is a set of keys that can be assembled using the & operator + struct DiscreteKeys: public std::vector { + + /// Default constructor + DiscreteKeys() { + } + + /// Construct from a key + DiscreteKeys(const DiscreteKey& key) { + push_back(key); + } + + /// Construct from a vector of keys + DiscreteKeys(const std::vector& keys) : + std::vector(keys) { + } + + /// Construct from cardinalities with default names + DiscreteKeys(const std::vector& cs); + + /// Return a vector of indices + std::vector indices() const; + + /// Return a map from index to cardinality + std::map cardinalities() const; + + /// Add a key (non-const!) + DiscreteKeys& operator&(const DiscreteKey& key) { + push_back(key); + return *this; + } + }; // DiscreteKeys + + /// Create a list from two keys + DiscreteKeys operator&(const DiscreteKey& key1, const DiscreteKey& key2); + + /// traits class for DiscreteKey for use with DecisionTree/DecisionDiagram + template<> + struct label_traits { + /** get cardinality from type */ + static size_t cardinality(const OldDiscreteKey& key) { + return key.cardinality(); + } + /** compare 2 keys by their name */ + static bool higher(const OldDiscreteKey& a, const OldDiscreteKey& b) { + return a.name() < b.name(); + } + /** hash function */ + static size_t hash_value(const OldDiscreteKey& a) { + boost::hash hasher; + return hasher(a.name()); + } + }; + +} diff --git a/gtsam/discrete/DiscreteSequentialSolver.cpp b/gtsam/discrete/DiscreteSequentialSolver.cpp new file mode 100644 index 000000000..86cea7ffb --- /dev/null +++ b/gtsam/discrete/DiscreteSequentialSolver.cpp @@ -0,0 +1,47 @@ +/* + * DiscreteSequentialSolver.cpp + * + * @date Feb 16, 2011 + * @author Duy-Nguyen Ta + */ + +//#define ENABLE_TIMING +#include +#include +#include +#include + +namespace gtsam { + + template class GenericSequentialSolver ; + + /* ************************************************************************* */ + DiscreteFactor::sharedValues DiscreteSequentialSolver::optimize() const { + + static const bool debug = false; + + if (debug) this->factors_->print("DiscreteSequentialSolver, eliminating "); + if (debug) this->eliminationTree_->print( + "DiscreteSequentialSolver, elimination tree "); + + // Eliminate using the elimination tree + tic(1, "eliminate"); + DiscreteBayesNet::shared_ptr bayesNet = eliminate(); + toc(1, "eliminate"); + + if (debug) bayesNet->print("DiscreteSequentialSolver, Bayes net "); + + // Allocate the solution vector if it is not already allocated + + // Back-substitute + tic(2, "optimize"); + DiscreteFactor::sharedValues solution = gtsam::optimize(*bayesNet); + toc(2, "optimize"); + + if (debug) solution->print("DiscreteSequentialSolver, solution "); + + return solution; + } +/* ************************************************************************* */ + +} diff --git a/gtsam/discrete/DiscreteSequentialSolver.h b/gtsam/discrete/DiscreteSequentialSolver.h new file mode 100644 index 000000000..0be149646 --- /dev/null +++ b/gtsam/discrete/DiscreteSequentialSolver.h @@ -0,0 +1,97 @@ +/* + * DiscreteSequentialSolver.h + * + * @date Feb 16, 2011 + * @author Duy-Nguyen Ta + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + // The base class provides all of the needed functionality + + class DiscreteSequentialSolver: public GenericSequentialSolver { + + protected: + typedef GenericSequentialSolver Base; + typedef boost::shared_ptr shared_ptr; + + public: + + /** + * The problem we are trying to solve (SUM or MPE). + */ + typedef enum { + BEL, // Belief updating (or conditional updating) + MPE, // Most-Probable-Explanation + MAP + // Maximum A Posteriori hypothesis + } ProblemType; + + /** + * Construct the solver for a factor graph. This builds the elimination + * tree, which already does some of the work of elimination. + */ + DiscreteSequentialSolver(const FactorGraph& factorGraph) : + Base(factorGraph) { + } + + /** + * Construct the solver with a shared pointer to a factor graph and to a + * VariableIndex. The solver will store these pointers, so this constructor + * is the fastest. + */ + DiscreteSequentialSolver( + const FactorGraph::shared_ptr& factorGraph, + const VariableIndex::shared_ptr& variableIndex) : + Base(factorGraph, variableIndex) { + } + + const EliminationTree& eliminationTree() const { + return *eliminationTree_; + } + + /** + * Eliminate the factor graph sequentially. Uses a column elimination tree + * to recursively eliminate. + */ + BayesNet::shared_ptr eliminate() const { + return Base::eliminate(&EliminateDiscrete); + } + +#ifdef BROKEN + /** + * Compute the marginal joint over a set of variables, by integrating out + * all of the other variables. This function returns the result as a factor + * graph. + */ + DiscreteFactorGraph::shared_ptr jointFactorGraph( + const std::vector& js) const { + DiscreteFactorGraph::shared_ptr results(new DiscreteFactorGraph( + *Base::jointFactorGraph(js, &EliminateDiscrete))); + return results; + } + + /** + * Compute the marginal density over a variable, by integrating out + * all of the other variables. This function returns the result as a factor. + */ + DiscreteFactor::shared_ptr marginalFactor(Index j) const { + return Base::marginalFactor(j, &EliminateDiscrete); + } +#endif + + /** + * Compute the MPE solution of the DiscreteFactorGraph. This + * eliminates to create a BayesNet and then back-substitutes this BayesNet to + * obtain the solution. + */ + DiscreteFactor::sharedValues optimize() const; + + }; + +} // gtsam diff --git a/gtsam/discrete/Domain.cpp b/gtsam/discrete/Domain.cpp new file mode 100644 index 000000000..ff71fa39f --- /dev/null +++ b/gtsam/discrete/Domain.cpp @@ -0,0 +1,95 @@ +/* + * Domain.cpp + * @brief Domain restriction constraint + * @date Feb 13, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include + +namespace gtsam { + + using namespace std; + + /* ************************************************************************* */ + void Domain::print(const string& s) const { +// cout << s << ": Domain on " << keys_[0] << " (j=" << keys_[0] +// << ") with values"; +// BOOST_FOREACH (size_t v,values_) cout << " " << v; +// cout << endl; + BOOST_FOREACH (size_t v,values_) cout << v; + } + + /* ************************************************************************* */ + double Domain::operator()(const Values& values) const { + return contains(values.at(keys_[0])); + } + + /* ************************************************************************* */ + Domain::operator DecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0],cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; ++i1) + table.push_back(contains(i1)); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor Domain::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return DecisionTreeFactor(*this) * f; + } + + /* ************************************************************************* */ + bool Domain::ensureArcConsistency(size_t j, vector& domains) const { + if (j != keys_[0]) throw invalid_argument("Domain check on wrong domain"); + Domain& D = domains[j]; + BOOST_FOREACH(size_t value, values_) + if (!D.contains(value)) throw runtime_error("Unsatisfiable"); + D = *this; + return true; + } + + /* ************************************************************************* */ + bool Domain::checkAllDiff(const vector keys, vector& domains) { + Index j = keys_[0]; + // for all values in this domain + BOOST_FOREACH(size_t value, values_) { + // for all connected domains + BOOST_FOREACH(Index k, keys) + // if any domain contains the value we cannot make this domain singleton + if (k!=j && domains[k].contains(value)) + goto found; + values_.clear(); + values_.insert(value); + return true; // we changed it + found:; + } + return false; // we did not change it + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr Domain::partiallyApply( + const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && !contains(it->second)) throw runtime_error( + "Domain::partiallyApply: unsatisfiable"); + return boost::make_shared < Domain > (*this); + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr Domain::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !contains(*Dk.begin())) throw runtime_error( + "Domain::partiallyApply: unsatisfiable"); + return boost::make_shared < Domain > (Dk); + } + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/Domain.h b/gtsam/discrete/Domain.h new file mode 100644 index 000000000..f06e9a1da --- /dev/null +++ b/gtsam/discrete/Domain.h @@ -0,0 +1,107 @@ +/* + * Domain.h + * @brief Domain restriction constraint + * @date Feb 13, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include + +namespace gtsam { + + /** + * Domain restriction constraint + */ + class Domain: public DiscreteFactor { + + size_t cardinality_; /// Cardinality + std::set values_; /// allowed values + + public: + + typedef boost::shared_ptr shared_ptr; + + // Constructor on Discrete Key initializes an "all-allowed" domain + Domain(const DiscreteKey& dkey) : + DiscreteFactor(dkey.first), cardinality_(dkey.second) { + for (size_t v = 0; v < cardinality_; v++) + values_.insert(v); + } + + // Constructor on Discrete Key with single allowed value + // Consider SingleValue constraint + Domain(const DiscreteKey& dkey, size_t v) : + DiscreteFactor(dkey.first), cardinality_(dkey.second) { + values_.insert(v); + } + + /// Constructor + Domain(const Domain& other) : + DiscreteFactor(other.keys_[0]), values_(other.values_) { + } + + /// insert a value, non const :-( + void insert(size_t value) { + values_.insert(value); + } + + /// erase a value, non const :-( + void erase(size_t value) { + values_.erase(value); + } + + size_t nrValues() const { + return values_.size(); + } + + bool isSingleton() const { + return nrValues() == 1; + } + + size_t firstValue() const { + return *values_.begin(); + } + + // print + virtual void print(const std::string& s = "") const; + + bool contains(size_t value) const { + return values_.count(value)>0; + } + + /// Calculate value + virtual double operator()(const Values& values) const; + + /// Convert into a decisiontree + virtual operator DecisionTreeFactor() const; + + /// Multiply into a decisiontree + virtual DecisionTreeFactor operator*(const DecisionTreeFactor& f) const; + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + bool ensureArcConsistency(size_t j, std::vector& domains) const; + + /** + * Check for a value in domain that does not occur in any other connected domain. + * If found, we make this a singleton... Called in AllDiff::ensureArcConsistency + * @param keys connected domains through alldiff + */ + bool checkAllDiff(const std::vector keys, std::vector& domains); + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply( + const Values& values) const; + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply( + const std::vector& domains) const; + }; + +} // namespace gtsam diff --git a/gtsam/discrete/PotentialTable.cpp b/gtsam/discrete/PotentialTable.cpp new file mode 100644 index 000000000..d0388e052 --- /dev/null +++ b/gtsam/discrete/PotentialTable.cpp @@ -0,0 +1,162 @@ +/* + * Potentials.cpp + * + * @date Feb 21, 2011 + * @author Duy-Nguyen Ta + */ + +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +namespace gtsam { + + /* ************************************************************************* */ + void PotentialTable::Iterator::operator++() { + // note size_t is unsigned and i>=0 is always true, so strange-looking loop: + for (size_t i = size(); i--; ) { + if (++at(i) < cardinalities_[i]) + return; + else + at(i) = 0; + } + } + + /* ************************************************************************* */ + size_t PotentialTable::computeTableSize( + const std::vector& cardinalities) { + size_t tableSize = 1; + BOOST_FOREACH(const size_t c, cardinalities) + tableSize *= c; + return tableSize; + } + + /* ************************************************************************* */ + PotentialTable::PotentialTable(const std::vector& cs) : + cardinalities_(cs), table_(computeTableSize(cs)) { + generateKeyFactors(); + } + + /* ************************************************************************* */ + PotentialTable::PotentialTable(const std::vector& cardinalities, + const Table& table) : cardinalities_(cardinalities),table_(table) { + generateKeyFactors(); + } + + /* ************************************************************************* */ + PotentialTable::PotentialTable(const std::vector& cardinalities, + const std::string& tableString) : cardinalities_(cardinalities) { + parse(tableString); + generateKeyFactors(); + } + + /* ************************************************************************* */ + bool PotentialTable::equals(const PotentialTable& other, double tol) const { + //TODO: compare potentials in a more general sense with arbitrary order of keys??? + if ((cardinalities_ == other.cardinalities_) && (table_.size() + == other.table_.size()) && (keyFactors_ == other.keyFactors_)) { + for (size_t i = 0; i < table_.size(); i++) { + if (fabs(table_[i] - other.table_[i]) > tol) { + return false; + } + return true; + } + } + return false; + } + + /* ************************************************************************* */ + void PotentialTable::print(const std::string& s) const { + cout << s << endl; + for (size_t i = 0; i < cardinalities_.size(); i++) + cout << boost::format("[%d,%d]") % cardinalities_[i] % keyFactors_[i] << " "; + cout << endl; + Iterator assignment(cardinalities_); + for (size_t idx = 0; idx < table_.size(); ++idx, ++assignment) { + for (size_t k = 0; k < assignment.size(); k++) + cout << assignment[k] << "\t\t"; + cout << table_[idx] << endl; + } + } + + /* ************************************************************************* */ + const double& PotentialTable::operator()(const Assignment& var) const { + return table_[tableIndexFromAssignment(var)]; + } + + /* ************************************************************************* */ + const double& PotentialTable::operator[](const size_t index) const { + return table_.at(index); + } + + + /* ************************************************************************* */ + void PotentialTable::setPotential(const PotentialTable::Assignment& asg, const double potential) { + size_t idx = tableIndexFromAssignment(asg); + assert(idx (iss), istream_iterator (), + back_inserter(table_)); + +#ifndef NDEBUG + size_t expectedSize = computeTableSize(cardinalities_); + if (table_.size() != expectedSize) throw invalid_argument( + boost::str( + boost::format( + "String specification \"%s\" for table only contains %d doubles instead of %d") + % tableString % table_.size() % expectedSize)); +#endif + } + +} // namespace diff --git a/gtsam/discrete/PotentialTable.h b/gtsam/discrete/PotentialTable.h new file mode 100644 index 000000000..b7741ba1e --- /dev/null +++ b/gtsam/discrete/PotentialTable.h @@ -0,0 +1,95 @@ +/* + * Potentials.h + * + * @date Feb 21, 2011 + * @author Duy-Nguyen Ta + */ + +#ifndef POTENTIALS_H_ +#define POTENTIALS_H_ + +#include +#include +#include +#include +#include +#include + +namespace gtsam +{ +/** + * PotentialTable holds the real-valued potentials for Factors or Conditionals + */ +class PotentialTable { +public: + typedef std::vector Table; // container type for potentials f(x1,x2,..) + typedef std::vector Cardinalities; // just a typedef + typedef std::vector Assignment; // just a typedef + + /** + * An assignment that can be incemented + */ + struct Iterator: std::vector { + Cardinalities cardinalities_; + Iterator(const Cardinalities& cs):cardinalities_(cs) { + for(size_t i=0;i cardinalities_; // cardinalities of variables + Table table_; // Potential values of all instantiations of the variables, following the variables' order in vector Keys. + std::vector keyFactors_; // factors to multiply a key's assignment with, to access the potential table + + void generateKeyFactors(); + void parse(const std::string& tableString); + +public: + + /** compute table size from variable cardinalities */ + static size_t computeTableSize(const std::vector& cardinalities); + + /** construct an empty potential */ + PotentialTable() {} + + /** Dangerous empty n-ary potential. */ + PotentialTable(const std::vector& cardinalities); + + /** n-ary potential. */ + PotentialTable(const std::vector& cardinalities, + const Table& table); + + /** n-ary potential. */ + PotentialTable(const std::vector& cardinalities, + const std::string& tableString); + + /** return iterator to first element */ + Iterator begin() const { return Iterator(cardinalities_);} + + /** equality */ + bool equals(const PotentialTable& other, double tol = 1e-9) const; + + /** print */ + void print(const std::string& s = "Potential Table: ") const; + + /** return cardinality of a variable */ + size_t cardinality(size_t var) const { return cardinalities_[var]; } + size_t tableSize() const { return table_.size(); } + + /** accessors to potential values in the table given the assignment */ + const double& operator()(const Assignment& var) const; + const double& operator[](const size_t index) const; + + void setPotential(const Assignment& asg, const double potential); + void setPotential(const size_t tableIndex, const double potential); + + /** convert between assignment and where it is in the table */ + size_t tableIndexFromAssignment(const Assignment& var) const; + Assignment assignmentFromTableIndex(const size_t i) const; +}; + + +} // namespace + +#endif /* POTENTIALS_H_ */ diff --git a/gtsam/discrete/Potentials.cpp b/gtsam/discrete/Potentials.cpp new file mode 100644 index 000000000..f2d6760eb --- /dev/null +++ b/gtsam/discrete/Potentials.cpp @@ -0,0 +1,53 @@ +/* + * Potentials.cpp + * @date March 24, 2011 + * @author Frank Dellaert + */ + +#include +#include +#include + +using namespace std; + +namespace gtsam { + + // explicit instantiation + template class DecisionTree ; + template class AlgebraicDecisionTree ; + + /* ************************************************************************* */ + double Potentials::safe_div(const double& a, const double& b) { + // cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b)); + // The use for safe_div is when we divide the product factor by the sum factor. + // If the product or sum is zero, we accord zero probability to the event. + return (a == 0 || b == 0) ? 0 : (a / b); + } + + /* ******************************************************************************** */ + Potentials::Potentials() : + ADT(1.0) { + } + + /* ******************************************************************************** */ + Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree) : + ADT(decisionTree), cardinalities_(keys.cardinalities()) { + } + + /* ************************************************************************* */ + bool Potentials::equals(const Potentials& other, double tol) const { + return ADT::equals(other, tol); + } + + /* ************************************************************************* */ + void Potentials::print(const string&s) const { + cout << s << "\n Cardinalities: "; + BOOST_FOREACH(const DiscreteKey& key, cardinalities_) + cout << key.first << "=" << key.second << " "; + cout << endl; + ADT::print(" "); + } + + /* ************************************************************************* */ + +} // namespace gtsam diff --git a/gtsam/discrete/Potentials.h b/gtsam/discrete/Potentials.h new file mode 100644 index 000000000..cfcf58400 --- /dev/null +++ b/gtsam/discrete/Potentials.h @@ -0,0 +1,62 @@ +/* + * Potentials.h + * @date March 24, 2011 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace gtsam { + + /** + * A base class for both DiscreteFactor and DiscreteConditional + */ + class Potentials: public AlgebraicDecisionTree { + + public: + + typedef AlgebraicDecisionTree ADT; + + protected: + + /// Cardinality for each key, used in combine + std::map cardinalities_; + + /** Constructor from ColumnIndex, and ADT */ + Potentials(const ADT& potentials) : + ADT(potentials) { + } + + // Safe division for probabilities + static double safe_div(const double& a, const double& b); + + public: + + /** Default constructor for I/O */ + Potentials(); + + /** Constructor from Indices and ADT */ + Potentials(const DiscreteKeys& keys, const ADT& decisionTree); + + /** Constructor from Indices and (string or doubles) */ + template + Potentials(const DiscreteKeys& keys, SOURCE table) : + ADT(keys, table), cardinalities_(keys.cardinalities()) { + } + + // Testable + bool equals(const Potentials& other, double tol = 1e-9) const; + void print(const std::string& s = "Potentials: ") const; + + size_t cardinality(Index j) const { return cardinalities_.at(j);} + + }; // Potentials + +} // namespace gtsam diff --git a/gtsam/discrete/RefCounted.cpp b/gtsam/discrete/RefCounted.cpp new file mode 100644 index 000000000..6440b5028 --- /dev/null +++ b/gtsam/discrete/RefCounted.cpp @@ -0,0 +1,9 @@ +/* + * @file RefCounted.cpp + * @brief Simple reference-counted base class + * @author Frank Dellaert + * @date Mar 29, 2011 + */ + +#include + diff --git a/gtsam/discrete/RefCounted.h b/gtsam/discrete/RefCounted.h new file mode 100644 index 000000000..03d086ab6 --- /dev/null +++ b/gtsam/discrete/RefCounted.h @@ -0,0 +1,86 @@ +/* + * @file RefCounted.h + * @brief Simple reference-counted base class + * @author Frank Dellaert + * @date Mar 29, 2011 + */ + +#include + +// Forward Declarations +namespace gtsam { + struct RefCounted; +} + +namespace boost { + void intrusive_ptr_add_ref(const gtsam::RefCounted * p); + void intrusive_ptr_release(const gtsam::RefCounted * p); +} + +namespace gtsam { + + /** + * Simple reference counted class inspired by + * http://www.codeproject.com/KB/stl/boostsmartptr.aspx + */ + struct RefCounted { + private: + mutable long references_; + friend void ::boost::intrusive_ptr_add_ref(const RefCounted * p); + friend void ::boost::intrusive_ptr_release(const RefCounted * p); + public: + RefCounted() : + references_(0) { + } + virtual ~RefCounted() { + } + }; + +} // namespace gtsam + +// Intrusive Pointer free functions +#ifndef DEBUG_REFCOUNT + +namespace boost { + + // increment reference count of object *p + inline void intrusive_ptr_add_ref(const gtsam::RefCounted * p) { + ++(p->references_); + } + + // decrement reference count, and delete object when reference count reaches 0 + inline void intrusive_ptr_release(const gtsam::RefCounted * p) { + if (--(p->references_) == 0) + delete p; + } + +} // namespace boost + +#else + +#include + + namespace gtsam { + static long GlobalRefCount = 0; + } + + namespace boost { + inline void intrusive_ptr_add_ref(const gtsam::RefCounted * p) { + ++(p->references_); + gtsam::GlobalRefCount++; + std::cout << "add_ref " << p << " " << p->references_ << // + " " << gtsam::GlobalRefCount << std::endl; + } + + inline void intrusive_ptr_release(const gtsam::RefCounted * p) { + gtsam::GlobalRefCount--; + std::cout << "release " << p << " " << (p->references_ - 1) << // + " " << gtsam::GlobalRefCount << std::endl; + if (--(p->references_) == 0) + delete p; + } + + } // namespace boost + +#endif + diff --git a/gtsam/discrete/Scheduler.cpp b/gtsam/discrete/Scheduler.cpp new file mode 100644 index 000000000..574e276a9 --- /dev/null +++ b/gtsam/discrete/Scheduler.cpp @@ -0,0 +1,297 @@ +/* + * Scheduler.h + * @brief an example how inference can be used for scheduling qualifiers + * @date Mar 26, 2011 + * @author Frank Dellaert + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace gtsam { + + using namespace std; + + Scheduler::Scheduler(size_t maxNrStudents, const string& filename): + maxNrStudents_(maxNrStudents) + { + typedef boost::tokenizer > Tokenizer; + + // open file + ifstream is(filename.c_str()); + + string line; // buffer + + // process first line with faculty + if (getline(is, line, '\r')) { + Tokenizer tok(line); + Tokenizer::iterator it = tok.begin(); + for (++it; it != tok.end(); ++it) + addFaculty(*it); + } + + // for all remaining lines + size_t count = 0; + while (getline(is, line, '\r')) { + if (count++ > 100) throw runtime_error("reached 100 lines, exiting"); + Tokenizer tok(line); + Tokenizer::iterator it = tok.begin(); + addSlot(*it++); // add slot + // add availability + for (; it != tok.end(); ++it) + available_ += (it->empty()) ? "0 " : "1 "; + available_ += '\n'; + } + } // constructor + + /** addStudent has to be called after adding slots and faculty */ + void Scheduler::addStudent(const string& studentName, + const string& area1, const string& area2, + const string& area3, const string& advisor) { + assert(nrStudents() area) const { + return area ? students_[s].keys_[*area] : students_[s].key_; + } + + const string& Scheduler::studentName(size_t i) const { + assert(i slot) { + bool debug = ISDEBUG("Scheduler::buildGraph"); + + assert(iat(j); + cout << studentName(s) << " slot: " << slotName_[slot] << endl; + Index base = 3*s; + for (size_t area = 0; area < 3; area++) { + size_t faculty = assignment->at(base+area); + cout << setw(12) << studentArea(s,area) << ": " << facultyName_[faculty] + << endl; + } + cout << endl; + } + } + + /** Special print for single-student case */ + void Scheduler::printSpecial(sharedValues assignment) const { + Values::const_iterator it = assignment->begin(); + for (size_t area = 0; area < 3; area++, it++) { + size_t f = it->second; + cout << setw(12) << it->first << ": " << facultyName_[f] << endl; + } + cout << endl; + } + + /** Accumulate faculty stats */ + void Scheduler::accumulateStats(sharedValues assignment, vector< + size_t>& stats) const { + for (size_t s = 0; s < nrStudents(); s++) { + Index base = 3*s; + for (size_t area = 0; area < 3; area++) { + size_t f = assignment->at(base+area); + assert(frbegin(); + const Student & student = students_.front(); + cout << endl; + (*it)->print(student.name_); + } + + tic(3, "my_optimize"); + sharedValues mpe = optimize(*chordal); + toc(3, "my_optimize"); + return mpe; + } + + /** find the assignment of students to slots with most possible committees */ + Scheduler::sharedValues Scheduler::bestSchedule() const { + sharedValues best; + throw runtime_error("bestSchedule not implemented"); + return best; + } + + /** find the corresponding most desirable committee assignment */ + Scheduler::sharedValues Scheduler::bestAssignment( + sharedValues bestSchedule) const { + sharedValues best; + throw runtime_error("bestAssignment not implemented"); + return best; + } + +} // gtsam + + diff --git a/gtsam/discrete/Scheduler.h b/gtsam/discrete/Scheduler.h new file mode 100644 index 000000000..8b91bce61 --- /dev/null +++ b/gtsam/discrete/Scheduler.h @@ -0,0 +1,171 @@ +/* + * Scheduler.h + * @brief an example how inference can be used for scheduling qualifiers + * @date Mar 26, 2011 + * @author Frank Dellaert + */ + +#pragma once + +#include + +namespace gtsam { + + /** + * Scheduler class + * Creates one variable for each student, and three variables for each + * of the student's areas, for a total of 4*nrStudents variables. + * The "student" variable will determine when the student takes the qual. + * The "area" variables determine which faculty are on his/her committee. + */ + class Scheduler : public CSP { + + private: + + /** Internal data structure for students */ + struct Student { + std::string name_; + DiscreteKey key_; // key for student + std::vector keys_; // key for areas + std::vector areaName_; + std::vector advisor_; + Student(size_t nrFaculty, size_t advisorIndex) : + keys_(3), areaName_(3), advisor_(nrFaculty, 1.0) { + advisor_[advisorIndex] = 0.0; + } + void print() const { + using std::cout; + cout << name_ << ": "; + for (size_t area = 0; area < 3; area++) + cout << areaName_[area] << " "; + cout << std::endl; + } + }; + + /** Maximum number of students */ + size_t maxNrStudents_; + + /** discrete keys, indexed by student and area index */ + std::vector students_; + + /** faculty identifiers */ + std::map facultyIndex_; + std::vector facultyName_, slotName_, areaName_; + + /** area constraints */ + typedef std::map > FacultyInArea; + FacultyInArea facultyInArea_; + + /** nrTimeSlots * nrFaculty availability constraints */ + std::string available_; + + /** which slots are good */ + std::vector slotsAvailable_; + + public: + + /** + * Constructor + * WE need to know the number of students in advance for ordering keys. + * then add faculty, slots, areas, availability, students, in that order + */ + Scheduler(size_t maxNrStudents):maxNrStudents_(maxNrStudents) { + } + + void addFaculty(const std::string& facultyName) { + facultyIndex_[facultyName] = nrFaculty(); + facultyName_.push_back(facultyName); + } + + size_t nrFaculty() const { + return facultyName_.size(); + } + + /** boolean std::string of nrTimeSlots * nrFaculty */ + void setAvailability(const std::string& available) { + available_ = available; + } + + void addSlot(const std::string& slotName) { + slotName_.push_back(slotName); + } + + size_t nrTimeSlots() const { + return slotName_.size(); + } + + const std::string& slotName(size_t s) const { + return slotName_[s]; + } + + /** slots available, boolean */ + void setSlotsAvailable(const std::vector& slotsAvailable) { + slotsAvailable_ = slotsAvailable; + } + + void addArea(const std::string& facultyName, const std::string& areaName) { + areaName_.push_back(areaName); + std::vector& table = facultyInArea_[areaName]; // will create if needed + if (table.empty()) table.resize(nrFaculty(), 0); + table[facultyIndex_[facultyName]] = 1; + } + + /** + * Constructor that reads in faculty, slots, availibility. + * Still need to add areas and students after this + */ + Scheduler(size_t maxNrStudents, const std::string& filename); + + /** get key for student and area, 0 is time slot itself */ + const DiscreteKey& key(size_t s, boost::optional area = boost::none) const; + + /** addStudent has to be called after adding slots and faculty */ + void addStudent(const std::string& studentName, const std::string& area1, + const std::string& area2, const std::string& area3, + const std::string& advisor); + + /// current number of students + size_t nrStudents() const { + return students_.size(); + } + + const std::string& studentName(size_t i) const; + const DiscreteKey& studentKey(size_t i) const; + const std::string& studentArea(size_t i, size_t area) const; + + /** Add student-specific constraints to the graph */ + void addStudentSpecificConstraints(size_t i, boost::optional slot = boost::none); + + /** Main routine that builds factor graph */ + void buildGraph(size_t mutexBound = 7); + + /** print */ + void print(const std::string& s = "Scheduler") const; + + /** Print readable form of assignment */ + void printAssignment(sharedValues assignment) const; + + /** Special print for single-student case */ + void printSpecial(sharedValues assignment) const; + + /** Accumulate faculty stats */ + void accumulateStats(sharedValues assignment, + std::vector& stats) const; + + /** Eliminate, return a Bayes net */ + DiscreteBayesNet::shared_ptr eliminate() const; + + /** Find the best total assignment - can be expensive */ + sharedValues optimalAssignment() const; + + /** find the assignment of students to slots with most possible committees */ + sharedValues bestSchedule() const; + + /** find the corresponding most desirable committee assignment */ + sharedValues bestAssignment(sharedValues bestSchedule) const; + + }; // Scheduler + +} // gtsam + + diff --git a/gtsam/discrete/Signature.cpp b/gtsam/discrete/Signature.cpp new file mode 100644 index 000000000..4d808543a --- /dev/null +++ b/gtsam/discrete/Signature.cpp @@ -0,0 +1,217 @@ +/* + * Signature.cpp + * @brief: signatures for conditional densities + * @author: Frank dellaert + * @date Feb 27, 2011 + */ + +#include +#include + +#include "Signature.h" + +#ifdef BOOST_HAVE_PARSER +#include // for parsing +#include // for qi::_val +#endif + +namespace gtsam { + + using namespace std; + + +#ifdef BOOST_HAVE_PARSER + namespace qi = boost::spirit::qi; + + // parser for strings of form "99/1 80/20" etc... + namespace parser { + typedef string::const_iterator It; + using boost::phoenix::val; + using boost::phoenix::ref; + using boost::phoenix::push_back; + + // Special rows, true and false + Signature::Row createF() { + Signature::Row r(2); + r[0] = 1; + r[1] = 0; + return r; + } + Signature::Row createT() { + Signature::Row r(2); + r[0] = 0; + r[1] = 1; + return r; + } + Signature::Row T = createT(), F = createF(); + + // Special tables (inefficient, but do we care for user input?) + Signature::Table logic(bool ff, bool ft, bool tf, bool tt) { + Signature::Table t(4); + t[0] = ff ? T : F; + t[1] = ft ? T : F; + t[2] = tf ? T : F; + t[3] = tt ? T : F; + return t; + } + + struct Grammar { + qi::rule table, or_, and_, rows; + qi::rule true_, false_, row; + Grammar() { + table = or_ | and_ | rows; + or_ = qi::lit("OR")[qi::_val = logic(false, true, true, true)]; + and_ = qi::lit("AND")[qi::_val = logic(false, false, false, true)]; + rows = +(row | true_ | false_); // only loads first of the rows under boost 1.42 + row = qi::double_ >> +("/" >> qi::double_); + true_ = qi::lit("T")[qi::_val = T]; + false_ = qi::lit("F")[qi::_val = F]; + } + } grammar; + + // Create simpler parsing function to avoid the issue of only parsing a single row + bool parse_table(const string& spec, Signature::Table& table) { + // check for OR, AND on whole phrase + It f = spec.begin(), l = spec.end(); + if (qi::parse(f, l, + qi::lit("OR")[ref(table) = logic(false, true, true, true)]) || + qi::parse(f, l, + qi::lit("AND")[ref(table) = logic(false, false, false, true)])) + return true; + + // tokenize into separate rows + istringstream iss(spec); + string token; + while (iss >> token) { + Signature::Row values; + It tf = token.begin(), tl = token.end(); + bool r = qi::parse(tf, tl, + qi::double_[push_back(ref(values), qi::_1)] >> +("/" >> qi::double_[push_back(ref(values), qi::_1)]) | + qi::lit("T")[ref(values) = T] | + qi::lit("F")[ref(values) = F] ); + if (!r) + return false; + table.push_back(values); + } + + return true; + } + } // \namespace parser +#endif + + ostream& operator <<(ostream &os, const Signature::Row &row) { + os << row[0]; + for (size_t i = 1; i < row.size(); i++) + os << " " << row[i]; + return os; + } + + ostream& operator <<(ostream &os, const Signature::Table &table) { + for (size_t i = 0; i < table.size(); i++) + os << table[i] << endl; + return os; + } + + Signature::Signature(const DiscreteKey& key) : + key_(key) { + } + + DiscreteKeys Signature::discreteKeysParentsFirst() const { + DiscreteKeys keys; + BOOST_FOREACH(const DiscreteKey& key, parents_) + keys.push_back(key); + keys.push_back(key_); + return keys; + } + + vector Signature::indices() const { + vector js; + js.push_back(key_.first); + BOOST_FOREACH(const DiscreteKey& key, parents_) + js.push_back(key.first); + return js; + } + + vector Signature::cpt() const { + vector cpt; + if (table_) { + BOOST_FOREACH(const Row& row, *table_) + BOOST_FOREACH(const double& x, row) + cpt.push_back(x); + } + return cpt; + } + + Signature& Signature::operator,(const DiscreteKey& parent) { + parents_.push_back(parent); + return *this; + } + + static void normalize(Signature::Row& row) { + double sum = 0; + for (size_t i = 0; i < row.size(); i++) + sum += row[i]; + for (size_t i = 0; i < row.size(); i++) + row[i] /= sum; + } + + Signature& Signature::operator=(const string& spec) { + spec_.reset(spec); +#ifdef BOOST_HAVE_PARSER + Table table; + // NOTE: using simpler parse function to ensure boost back compatibility +// parser::It f = spec.begin(), l = spec.end(); + bool success = // +// qi::phrase_parse(f, l, parser::grammar.table, qi::space, table); // using full grammar + parser::parse_table(spec, table); + if (success) { + BOOST_FOREACH(Row& row, table) + normalize(row); + table_.reset(table); + } +#endif + return *this; + } + + Signature& Signature::operator=(const Table& t) { + Table table = t; + BOOST_FOREACH(Row& row, table) + normalize(row); + table_.reset(table); + return *this; + } + + ostream& operator <<(ostream &os, const Signature &s) { + os << s.key_.first; + if (s.parents_.empty()) { + os << " % "; + } else { + os << " | " << s.parents_[0].first; + for (size_t i = 1; i < s.parents_.size(); i++) + os << " && " << s.parents_[i].first; + os << " = "; + } + os << (s.spec_ ? *s.spec_ : "no spec") << endl; + if (s.table_) + os << (*s.table_); + else + os << "spec could not be parsed" << endl; + return os; + } + + Signature operator|(const DiscreteKey& key, const DiscreteKey& parent) { + Signature s(key); + return s, parent; + } + + Signature operator%(const DiscreteKey& key, const string& parent) { + Signature s(key); + return s = parent; + } + + Signature operator%(const DiscreteKey& key, const Signature::Table& parent) { + Signature s(key); + return s = parent; + } + +} // namespace gtsam diff --git a/gtsam/discrete/Signature.h b/gtsam/discrete/Signature.h new file mode 100644 index 000000000..25e66860e --- /dev/null +++ b/gtsam/discrete/Signature.h @@ -0,0 +1,129 @@ +/* + * Signature.h + * @brief: signatures for conditional densities + * @author: Frank dellaert + * @date Feb 27, 2011 + */ + +#pragma once +#include +#include +#include +#include + +#include // for checking whether we are using boost 1.40 +#if BOOST_VERSION >= 104200 +#define BOOST_HAVE_PARSER +#endif + +namespace gtsam { + + /** + * Signature for a discrete conditional density, used to construct conditionals. + * + * The format is (Key % string) for nodes with no parents, + * and (Key | Key, Key = string) for nodes with parents. + * + * The string specifies a conditional probability spec in the 00 01 10 11 order. + * For three-valued, it would be 00 01 02 10 11 12 20 21 22, etc... + * + * For example, given the following keys + * + * Key A("Asia"), S("Smoking"), T("Tuberculosis"), L("LungCancer"), + * B("Bronchitis"), E("Either"), X("XRay"), D("Dyspnoea"); + * + * These are all valid signatures (Asia network example): + * + * A % "99/1" + * S % "50/50" + * T|A = "99/1 95/5" + * L|S = "99/1 90/10" + * B|S = "70/30 40/60" + * E|T,L = "F F F 1" + * X|E = "95/5 2/98" + * D|E,B = "9/1 2/8 3/7 1/9" + */ + class Signature { + + public: + + /** Data type for the CPT */ + typedef std::vector Row; + typedef std::vector Table; + + private: + + /** the variable key */ + DiscreteKey key_; + + /** the parent keys */ + DiscreteKeys parents_; + + // the given CPT specification string + boost::optional spec_; + + // the CPT as parsed, if successful + boost::optional table_; + + public: + + /** Constructor from DiscreteKey */ + Signature(const DiscreteKey& key); + + /** the variable key */ + const DiscreteKey& key() const { + return key_; + } + + /** the parent keys */ + const DiscreteKeys& parents() const { + return parents_; + } + + /** All keys, with variable key last */ + DiscreteKeys discreteKeysParentsFirst() const; + + /** All key indices, with variable key first */ + std::vector indices() const; + + // the CPT as parsed, if successful + const boost::optional
& table() const { + return table_; + } + + // the CPT as a vector of doubles, with key's values most rapidly changing + std::vector cpt() const; + + /** Add a parent */ + Signature& operator,(const DiscreteKey& parent); + + /** Add the CPT spec - Fails in boost 1.40 */ + Signature& operator=(const std::string& spec); + + /** Add the CPT spec directly as a table */ + Signature& operator=(const Table& table); + + /** provide streaming */ + friend std::ostream& operator <<(std::ostream &os, const Signature &s); + }; + + /** + * Helper function to create Signature objects + * example: Signature s = D | E; + */ + Signature operator|(const DiscreteKey& key, const DiscreteKey& parent); + + /** + * Helper function to create Signature objects + * example: Signature s(D % "99/1"); + * Uses string parser, which requires BOOST 1.42 or higher + */ + Signature operator%(const DiscreteKey& key, const std::string& parent); + + /** + * Helper function to create Signature objects, using table construction directly + * example: Signature s(D % table); + */ + Signature operator%(const DiscreteKey& key, const Signature::Table& parent); + +} diff --git a/gtsam/discrete/SingleValue.cpp b/gtsam/discrete/SingleValue.cpp new file mode 100644 index 000000000..abd21a1a1 --- /dev/null +++ b/gtsam/discrete/SingleValue.cpp @@ -0,0 +1,78 @@ +/* + * SingleValue.cpp + * @brief domain constraint + * @date Feb 13, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include +#include + +namespace gtsam { + + using namespace std; + + /* ************************************************************************* */ + void SingleValue::print(const string& s) const { + cout << s << ": SingleValue on " << keys_[0] << " (j=" << keys_[0] + << ") with value " << value_ << endl; + } + + /* ************************************************************************* */ + double SingleValue::operator()(const Values& values) const { + return (double) (values.at(keys_[0]) == value_); + } + + /* ************************************************************************* */ + SingleValue::operator DecisionTreeFactor() const { + DiscreteKeys keys; + keys += DiscreteKey(keys_[0],cardinality_); + vector table; + for (size_t i1 = 0; i1 < cardinality_; i1++) + table.push_back(i1 == value_); + DecisionTreeFactor converted(keys, table); + return converted; + } + + /* ************************************************************************* */ + DecisionTreeFactor SingleValue::operator*(const DecisionTreeFactor& f) const { + // TODO: can we do this more efficiently? + return DecisionTreeFactor(*this) * f; + } + + /* ************************************************************************* */ + bool SingleValue::ensureArcConsistency(size_t j, + vector& domains) const { + if (j != keys_[0]) throw invalid_argument( + "SingleValue check on wrong domain"); + Domain& D = domains[j]; + if (D.isSingleton()) { + if (D.firstValue() != value_) throw runtime_error("Unsatisfiable"); + return false; + } + D = Domain(discreteKey(),value_); + return true; + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr SingleValue::partiallyApply(const Values& values) const { + Values::const_iterator it = values.find(keys_[0]); + if (it != values.end() && it->second != value_) throw runtime_error( + "SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared < SingleValue > (keys_[0], cardinality_, value_); + } + + /* ************************************************************************* */ + DiscreteFactor::shared_ptr SingleValue::partiallyApply( + const vector& domains) const { + const Domain& Dk = domains[keys_[0]]; + if (Dk.isSingleton() && !Dk.contains(value_)) throw runtime_error( + "SingleValue::partiallyApply: unsatisfiable"); + return boost::make_shared < SingleValue > (discreteKey(), value_); + } + +/* ************************************************************************* */ +} // namespace gtsam diff --git a/gtsam/discrete/SingleValue.h b/gtsam/discrete/SingleValue.h new file mode 100644 index 000000000..c1757f82d --- /dev/null +++ b/gtsam/discrete/SingleValue.h @@ -0,0 +1,72 @@ +/* + * SingleValue.h + * @brief domain constraint + * @date Feb 6, 2012 + * @author Frank Dellaert + */ + +#pragma once + +#include +#include + +namespace gtsam { + + /** + * SingleValue constraint + */ + class SingleValue: public DiscreteFactor { + + /// Number of values + size_t cardinality_; + + /// allowed value + size_t value_; + + DiscreteKey discreteKey() const { + return DiscreteKey(keys_[0],cardinality_); + } + + public: + + typedef boost::shared_ptr shared_ptr; + + /// Constructor + SingleValue(Index key, size_t n, size_t value) : + DiscreteFactor(key), cardinality_(n), value_(value) { + } + + /// Constructor + SingleValue(const DiscreteKey& dkey, size_t value) : + DiscreteFactor(dkey.first), cardinality_(dkey.second), value_(value) { + } + + // print + virtual void print(const std::string& s = "") const; + + /// Calculate value + virtual double operator()(const Values& values) const; + + /// Convert into a decisiontree + virtual operator DecisionTreeFactor() const; + + /// Multiply into a decisiontree + virtual DecisionTreeFactor operator*(const DecisionTreeFactor& f) const; + + /* + * Ensure Arc-consistency + * @param j domain to be checked + * @param domains all other domains + */ + bool ensureArcConsistency(size_t j, std::vector& domains) const; + + /// Partially apply known values + virtual DiscreteFactor::shared_ptr partiallyApply( + const Values& values) const; + + /// Partially apply known values, domain version + virtual DiscreteFactor::shared_ptr partiallyApply( + const std::vector& domains) const; + }; + +} // namespace gtsam diff --git a/gtsam/discrete/TypedDiscreteFactor.cpp b/gtsam/discrete/TypedDiscreteFactor.cpp new file mode 100644 index 000000000..daa745037 --- /dev/null +++ b/gtsam/discrete/TypedDiscreteFactor.cpp @@ -0,0 +1,117 @@ +/* + * @file TypedDiscreteFactor.cpp + * @brief + * @author Duy-Nguyen Ta + * @date Mar 5, 2011 + */ + +#include +#include +#include +#include + +using namespace std; + +namespace gtsam { + + /* ******************************************************************************** */ + TypedDiscreteFactor::TypedDiscreteFactor(const Indices& keys, + const string& table) : + Factor (keys.begin(), keys.end()), potentials_(keys, table) { + } + + /* ******************************************************************************** */ + TypedDiscreteFactor::TypedDiscreteFactor(const Indices& keys, + const vector& table) : + Factor (keys.begin(), keys.end()), potentials_(keys, table) { + //#define DEBUG_FACTORS +#ifdef DEBUG_FACTORS + static size_t count = 0; + string dotfile = (boost::format("Factor-%03d") % ++count).str(); + potentials_.dot(dotfile); + if (count == 57) potentials_.print("57"); +#endif + } + + /* ************************************************************************* */ + double TypedDiscreteFactor::operator()(const Values& values) const { + return potentials_(values); + } + + /* ************************************************************************* */ + void TypedDiscreteFactor::print(const string&s) const { + Factor::print(s); + potentials_.print(); + } + + /* ************************************************************************* */ + bool TypedDiscreteFactor::equals(const TypedDiscreteFactor& other, double tol) const { + return potentials_.equals(other.potentials_, tol); + } + + /* ******************************************************************************** */ + DiscreteFactor::shared_ptr TypedDiscreteFactor::toDiscreteFactor( + const KeyOrdering& ordering) const { + throw std::runtime_error("broken"); + //return boost::make_shared(keys(), ordering, potentials_); + } + +#ifdef OLD +DiscreteFactor TypedDiscreteFactor::toDiscreteFactor( + const KeyOrdering& ordering, const ProblemType problemType) const { + { + static bool debug = false; + + // instantiate vector keys and column index in order + DiscreteFactor::ColumnIndex orderColumnIndex; + vector keys; + BOOST_FOREACH(const KeyOrdering::value_type& ord, ordering) + { + if (debug) cout << "Key: " << ord.first; + + // find the key with ord.first in this factor + vector::const_iterator it = std::find(keys_.begin(), + keys_.end(), ord.first); + + // if found + if (it != keys_.end()) { + if (debug) cout << "it found: " << (*it) << ", index: " + << ord.second << endl; + + keys.push_back(ord.second); // push back the ordering index + orderColumnIndex[ord.second] = columnIndex_.at(ord.first.name()); + + if (debug) cout << "map " << ord.second << " with name: " + << ord.first.name() << " to " << columnIndex_.at( + ord.first.name()) << endl; + } + } + + DiscreteFactor f(keys, potentials_, orderColumnIndex, problemType); + return f; + } + + /* ******************************************************************************** */ + std::vector TypedDiscreteFactor::init(const Indices& keys) { + vector cardinalities; + for (size_t j = 0; j < keys.size(); j++) { + Index key = keys[j]; + keys_.push_back(key); + columnIndex_[key.name()] = j; + cardinalities.push_back(key.cardinality()); + } + return cardinalities; + } + + /* ******************************************************************************** */ + double TypedDiscreteFactor::potential(const TypedValues& values) const { + vector assignment(values.size()); + BOOST_FOREACH(const TypedValues::value_type& val, values) + if (columnIndex_.find(val.first) != columnIndex_.end()) assignment[columnIndex_.at( + val.first)] = val.second; + return potentials_(assignment); + } + +#endif + +} // namespace diff --git a/gtsam/discrete/TypedDiscreteFactor.h b/gtsam/discrete/TypedDiscreteFactor.h new file mode 100644 index 000000000..95f70898b --- /dev/null +++ b/gtsam/discrete/TypedDiscreteFactor.h @@ -0,0 +1,68 @@ +/* + * @file TypedDiscreteFactor.h + * @brief + * @author Duy-Nguyen Ta + * @date Mar 5, 2011 + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + + /** + * A factor on discrete variables with string keys + */ + class TypedDiscreteFactor: public Factor { + + typedef AlgebraicDecisionDiagram ADD; + + /** potentials of the factor */ + ADD potentials_; + + public: + + /** A map from keys to values */ + typedef ADD::Assignment Values; + + /** Constructor from keys and string table */ + TypedDiscreteFactor(const Indices& keys, const std::string& table); + + /** Constructor from keys and doubles */ + TypedDiscreteFactor(const Indices& keys, + const std::vector& table); + + /** Evaluate */ + double operator()(const Values& values) const; + + // Testable + bool equals(const TypedDiscreteFactor& other, double tol = 1e-9) const; + void print(const std::string& s = "DiscreteFactor: ") const; + + DiscreteFactor::shared_ptr toDiscreteFactor(const KeyOrdering& ordering) const; + +#ifdef OLD + /** map each variable name to its column index in the potential table */ + typedef std::map Index2IndexMap; + Index2IndexMap columnIndex_; + + /** Initialize keys, column index, and return cardinalities */ + std::vector init(const Indices& keys); + + public: + + /** Default constructor */ + TypedDiscreteFactor() {} + + /** Evaluate potential of a given assignment of values */ + double potential(const TypedValues& values) const; + +#endif + + }; // TypedDiscreteFactor + +} // namespace diff --git a/gtsam/discrete/TypedDiscreteFactorGraph.cpp b/gtsam/discrete/TypedDiscreteFactorGraph.cpp new file mode 100644 index 000000000..e0e18a885 --- /dev/null +++ b/gtsam/discrete/TypedDiscreteFactorGraph.cpp @@ -0,0 +1,68 @@ +/* + * @file TypedDiscreteFactorGraph.cpp + * @brief + * @author Duy-Nguyen Ta + * @date Mar 1, 2011 + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace std; + +namespace gtsam { + + /* ************************************************************************* */ + TypedDiscreteFactorGraph::TypedDiscreteFactorGraph() { + } + + /* ************************************************************************* */ + TypedDiscreteFactorGraph::TypedDiscreteFactorGraph(const string& filename) { + bool success = parseUAI(filename, *this); + if (!success) throw runtime_error( + "TypedDiscreteFactorGraph constructor from filename failed"); + } + + /* ************************************************************************* */ + void TypedDiscreteFactorGraph::add// + (const Indices& keys, const string& table) { + push_back(boost::shared_ptr// + (new TypedDiscreteFactor(keys, table))); + } + + /* ************************************************************************* */ + void TypedDiscreteFactorGraph::add// + (const Indices& keys, const vector& table) { + push_back(boost::shared_ptr// + (new TypedDiscreteFactor(keys, table))); + } + + /* ************************************************************************* */ + void TypedDiscreteFactorGraph::print(const string s) { + cout << s << endl; + cout << "Factors: " << endl; + BOOST_FOREACH(const sharedFactor factor, factors_) + factor->print(); + } + + /* ************************************************************************* */ + double TypedDiscreteFactorGraph::operator()( + const TypedDiscreteFactor::Values& values) const { + // Loop over all factors and multiply their probabilities + double p = 1.0; + BOOST_FOREACH(const sharedFactor& factor, *this) + p *= (*factor)(values); + return p; + } + +/* ************************************************************************* */ + +} diff --git a/gtsam/discrete/TypedDiscreteFactorGraph.h b/gtsam/discrete/TypedDiscreteFactorGraph.h new file mode 100644 index 000000000..93e19e9d2 --- /dev/null +++ b/gtsam/discrete/TypedDiscreteFactorGraph.h @@ -0,0 +1,50 @@ +/* + * @file TypedDiscreteFactorGraph.h + * @brief Factor graph with typed factors (with Index keys) + * @author Duy-Nguyen Ta + * @author Frank Dellaert + * @date Mar 1, 2011 + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + + /** + * Typed discrete factor graph, where keys are strings + */ + class TypedDiscreteFactorGraph: public FactorGraph { + + public: + + /** + * Default constructor + */ + TypedDiscreteFactorGraph(); + + /** + * Constructor from file + * For now assumes in .uai format from UAI'08 Probablistic Inference Evaluation + * See http://graphmod.ics.uci.edu/uai08/FileFormat + */ + TypedDiscreteFactorGraph(const std::string& filename); + + // Add factors without shared pointer ugliness + void add(const Indices& keys, const std::string& table); + void add(const Indices& keys, const std::vector& table); + + /** print */ + void print(const std::string s); + + /** Evaluate potential of a given assignment of values */ + double operator()(const TypedDiscreteFactor::Values& values) const; + + }; // TypedDiscreteFactorGraph + + +} // namespace diff --git a/gtsam/discrete/examples/Doodle.csv b/gtsam/discrete/examples/Doodle.csv new file mode 100644 index 000000000..1ce4ecebb --- /dev/null +++ b/gtsam/discrete/examples/Doodle.csv @@ -0,0 +1 @@ +,Ron Arkin,Andrea Thomaz,Ayanna Howard,Wayne Book,Mike Stilman,Charlie Kemp,Jun Ueda,Patricio Vela,Magnus Egerstedt,Harvey Lipkin,Frank Dellaert,Irfan Essa,Aaron Bobick,Jim Rehg,Henrik Christensen,Tucker Balch,Karen Feigh,N/A 1,N/A 2 Mon 9:00-10.30,,1,1,1,1,,1,1,,,1,,,,1,,,1,1 Mon 10:30-12:00,,1,1,1,1,,,1,1,,,,,,,,,1,1 Mon 1:30-3:00,,,1,,,1,1,1,1,1,1,,,,1,,,1,1 Mon 3:00-4:30,,,,1,,1,1,1,,1,1,1,,1,1,,,1,1 Tue 9:00-10.30,,,1,,,,,1,,1,1,,,1,1,,,1,1 Tue 10:30-12:00,,,1,1,1,,1,1,,1,1,,,1,,1,,1,1 Tue 1:30-3:00,,1,,1,1,,1,1,1,1,1,,,,,1,,1,1 Tue 3:00-4:30,,1,1,,,,,1,1,1,1,,,,1,1,,1,1 Wed 9:00-10.30,,,1,1,,,,,1,,1,,1,,1,,1,1,1 Wed 10:30-12:00,,,,1,1,,1,1,1,,,1,1,,1,1,1,1,1 Wed 1:30-3:00,,,,,1,1,,1,,1,1,1,1,,,1,,1,1 Wed 3:00-4:30,,,,,1,1,1,1,1,,1,1,,1,,1,,1,1 Thu 9:00-10.30,,,1,,,,,1,,1,,,1,1,,,,1,1 Thu 10:30-12:00,,,1,1,1,,1,1,,1,,,,1,,1,,1,1 Thu 1:30-3:00,,,1,1,1,,1,1,1,1,,,,1,,1,1,1,1 Thu 3:00-4:30,,,1,,,,,,1,1,,,,,1,1,1,1,1 Fri 9:00-10.30,,,1,1,1,1,1,1,,,1,,,,1,,,1,1 Fri 10:30-12:00,,,1,1,1,,1,1,,,,1,,,1,,,1,1 Fri 1:30-3:00,,,1,1,1,,,1,,1,1,1,1,,,,,1,1 Fri 3:00-4:30,,,,,,,,,,1,1,,1,,,,,1,1 \ No newline at end of file diff --git a/gtsam/discrete/examples/Doodle.xls b/gtsam/discrete/examples/Doodle.xls new file mode 100644 index 000000000..c607581e9 Binary files /dev/null and b/gtsam/discrete/examples/Doodle.xls differ diff --git a/gtsam/discrete/examples/Doodle2012.csv b/gtsam/discrete/examples/Doodle2012.csv new file mode 100644 index 000000000..54520b614 --- /dev/null +++ b/gtsam/discrete/examples/Doodle2012.csv @@ -0,0 +1 @@ +,Karen Feigh,Henrik Christensen,Panos Tsiotras,Ron Arkin,Andrea Thomaz,Magnus Egerstedt,Charles Isbell,Fumin Zhang,Mike Stilman,Jun Ueda,Aaron Bobick,Ayanna Howard,Patricio Vela,Charlie Kemp,Tucker Balch Mon 9:00 AM - 10:30 AM,,,1,1,1,1,1,,,,1,,,, Mon 10:30 AM - 12:00 PM,1,,,1,1,,1,1,1,,1,,1,1,1 Mon 1:30 PM - 3:00 PM,1,1,1,,,1,1,1,1,1,1,1,1,,1 Mon 3:00 PM - 4:30 PM,,,1,1,,,1,,1,,1,1,1,,1 Mon 4:30 PM - 6:00 PM,,1,1,,,,,1,,1,1,,1,, Tue 9:00 AM - 10:30 AM,,1,1,,1,1,1,,,,1,1,,, Tue 10:30 AM - 12:00 PM,1,1,1,1,1,,1,1,,1,1,,1,,1 Tue 1:30 PM - 3:00 PM,1,1,1,,1,1,,1,1,1,1,,,1, Tue 3:00 PM - 4:30 PM,,1,,,1,1,,1,,,,,,, Tue 4:30 PM - 6:00 PM,,,,,1,,,1,1,,1,,,, Wed 9:00 AM - 10:30 AM,1,1,1,,1,,1,,,,,1,,, Wed 10:30 AM - 12:00 PM,1,,,,1,1,1,1,1,1,,,1,1, Wed 1:30 PM - 3:00 PM,1,,1,,,1,,1,1,1,,,1,, Wed 3:00 PM - 4:30 PM,,,1,,,,,,1,,,,1,,1 Wed 4:30 PM - 6:00 PM,,,1,,,,,1,,,,,1,, Thu 9:00 AM - 10:30 AM,,1,1,,,1,,,,,,,,, Thu 10:30 AM - 12:00 PM,1,1,,,,1,,1,,1,,,,,1 Thu 1:30 PM - 3:00 PM,1,,,,,1,,1,,,,,,, Thu 3:00 PM - 4:30 PM,,1,1,,,1,1,1,,,,,,, Thu 4:30 PM - 6:00 PM,,1,1,,,,,1,,,,,,, Fri 9:00 AM - 10:30 AM,1,1,,,,,1,,,,,1,,, Fri 10:30 AM - 12:00 PM,1,1,,,,,,1,1,1,,,,,1 Fri 1:30 PM - 3:00 PM,1,,,,,1,,1,1,1,,,1,1,1 Fri 3:00 PM - 4:30 PM,1,,,,,,1,1,1,1,,,,,1 Fri 4:30 PM - 6:00 PM,,1,,,,,,1,,,,,1,, \ No newline at end of file diff --git a/gtsam/discrete/examples/Doodle2012.xls b/gtsam/discrete/examples/Doodle2012.xls new file mode 100644 index 000000000..981e2dc25 Binary files /dev/null and b/gtsam/discrete/examples/Doodle2012.xls differ diff --git a/gtsam/discrete/examples/intrusive.xlsx b/gtsam/discrete/examples/intrusive.xlsx new file mode 100644 index 000000000..53fd048e2 Binary files /dev/null and b/gtsam/discrete/examples/intrusive.xlsx differ diff --git a/gtsam/discrete/examples/schedulingExample.cpp b/gtsam/discrete/examples/schedulingExample.cpp new file mode 100644 index 000000000..7bb401d6a --- /dev/null +++ b/gtsam/discrete/examples/schedulingExample.cpp @@ -0,0 +1,344 @@ +/* + * schedulingExample.cpp + * @brief hard scheduling example + * @date March 25, 2011 + * @author Frank Dellaert + */ + +//#define ENABLE_TIMING +#define ADD_NO_CACHING +#define ADD_NO_PRUNING +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +using namespace boost::assign; +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +void addStudent(Scheduler& s, size_t i) { + switch (i) { + case 0: + s.addStudent("Michael N", "AI", "Autonomy", "Perception", "Tucker Balch"); + break; + case 1: + s.addStudent("Tucker H", "Controls", "AI", "Perception", "Jim Rehg"); + break; + case 2: + s.addStudent("Jake H", "Controls", "AI", "Perception", "Henrik Christensen"); + break; + case 3: + s.addStudent("Tobias K", "Controls", "AI", "Autonomy", "Mike Stilman"); + break; + case 4: + s.addStudent("Shu J", "Controls", "AI", "HRI", "N/A 1"); + break; + case 5: + s.addStudent("Akansel C", "AI", "Autonomy", "Mechanics", + "Henrik Christensen"); + break; + case 6: + s.addStudent("Tiffany C", "Controls", "N/A 1", "N/A 2", "Charlie Kemp"); + break; + } +} +/* ************************************************************************* */ +Scheduler largeExample(size_t nrStudents = 7) { + string path("/Users/dellaert/borg/gtsam2/gtsam2/discrete/examples/"); + Scheduler s(nrStudents, path + "Doodle.csv"); + + s.addArea("Harvey Lipkin", "Mechanics"); + s.addArea("Wayne Book", "Mechanics"); + s.addArea("Jun Ueda", "Mechanics"); + + // s.addArea("Wayne Book", "Controls"); + s.addArea("Patricio Vela", "Controls"); + s.addArea("Magnus Egerstedt", "Controls"); + s.addArea("Jun Ueda", "Controls"); + + // s.addArea("Frank Dellaert", "Perception"); + s.addArea("Jim Rehg", "Perception"); + s.addArea("Irfan Essa", "Perception"); + s.addArea("Aaron Bobick", "Perception"); + s.addArea("Henrik Christensen", "Perception"); + + s.addArea("Mike Stilman", "AI"); + s.addArea("Henrik Christensen", "AI"); + s.addArea("Frank Dellaert", "AI"); + s.addArea("Ayanna Howard", "AI"); + // s.addArea("Tucker Balch", "AI"); + + s.addArea("Ayanna Howard", "Autonomy"); + // s.addArea("Andrea Thomaz", "Autonomy"); + s.addArea("Charlie Kemp", "Autonomy"); + s.addArea("Tucker Balch", "Autonomy"); + s.addArea("Ron Arkin", "Autonomy"); + + s.addArea("Andrea Thomaz", "HRI"); + s.addArea("Karen Feigh", "HRI"); + s.addArea("Charlie Kemp", "HRI"); + + // Allow students not to take three areas + s.addArea("N/A 1", "N/A 1"); + s.addArea("N/A 2", "N/A 2"); + + // add students + for (size_t i = 0; i < nrStudents; i++) + addStudent(s, i); + + return s; +} + +/* ************************************************************************* */ +void runLargeExample() { + + Scheduler scheduler = largeExample(); + scheduler.print(); + + // BUILD THE GRAPH ! + size_t addMutex = 2; + scheduler.buildGraph(addMutex); + + // Do brute force product and output that to file + if (scheduler.nrStudents() == 1) { // otherwise too slow + DecisionTreeFactor product = scheduler.product(); + product.dot("scheduling-large", false); + } + + // Do exact inference + // SETDEBUG("timing-verbose", true); + SETDEBUG("DiscreteConditional::DiscreteConditional", true); + tic(2, "large"); + DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + toc(2, "large"); + tictoc_finishedIteration(); + tictoc_print(); + scheduler.printAssignment(MPE); +} + +/* ************************************************************************* */ +// Solve a series of relaxed problems for maximum flexibility solution +void solveStaged(size_t addMutex = 2) { + + // super-hack! just count... + bool debug = false; + SETDEBUG("DiscreteConditional::COUNT", true); + SETDEBUG("DiscreteConditional::DiscreteConditional", debug); // progress + + // make a vector with slot availability, initially all 1 + // Reads file to get count :-) + vector slotsAvailable(largeExample(0).nrTimeSlots(), 1.0); + + // now, find optimal value for each student, using relaxed mutex constraints + for (size_t s = 0; s < 7; s++) { + // add all students first time, then drop last one second time, etc... + Scheduler scheduler = largeExample(7 - s); + //scheduler.print(str(boost::format("Scheduler %d") % (7-s))); + + // only allow slots not yet taken + scheduler.setSlotsAvailable(slotsAvailable); + + // BUILD THE GRAPH ! + scheduler.buildGraph(addMutex); + + // Do EXACT INFERENCE + tic_("eliminate"); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + toc_("eliminate"); + + // find root node + DiscreteConditional::shared_ptr root = *(chordal->rbegin()); + if (debug) + root->print(""/*scheduler.studentName(s)*/); + + // solve root node only + Scheduler::Values values; + size_t bestSlot = root->solve(values); + + // get corresponding count + DiscreteKey dkey = scheduler.studentKey(6 - s); + values[dkey.first] = bestSlot; + size_t count = (*root)(values); + + // remove this slot from consideration + slotsAvailable[bestSlot] = 0.0; + cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(6-s) + % scheduler.slotName(bestSlot) % bestSlot % count << endl; + } + tictoc_print_(); + + // Solution with addMutex = 2: (20 secs) + // TC = Wed 2 (9), count = 96375041778 + // AC = Tue 2 (5), count = 4076088090 + // SJ = Mon 1 (0), count = 29596704 + // TK = Mon 3 (2), count = 755370 + // JH = Wed 4 (11), count = 12000 + // TH = Fri 2 (17), count = 220 + // MN = Fri 1 (16), count = 5 + // + // Mutex does make a difference !! + +} + +/* ************************************************************************* */ +// Sample from solution found above and evaluate cost function +bool NonZero(size_t i) { + return i > 0; +} + +DiscreteBayesNet::shared_ptr createSampler(size_t i, + size_t slot, vector& schedulers) { + Scheduler scheduler = largeExample(0); // todo: wrong nr students + addStudent(scheduler, i); + SETDEBUG("Scheduler::buildGraph", false); + scheduler.addStudentSpecificConstraints(0, slot); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + // chordal->print(scheduler[i].studentKey(0).name()); // large ! + schedulers.push_back(scheduler); + return chordal; +} + +void sampleSolutions() { + + vector schedulers; + vector samplers(7); + + // Given the time-slots, we can create 7 independent samplers + vector slots; + slots += 16, 17, 11, 2, 0, 5, 9; // given slots + for (size_t i = 0; i < 7; i++) + samplers[i] = createSampler(i, slots[i], schedulers); + + // now, sample schedules + for (size_t n = 0; n < 500; n++) { + vector stats(19, 0); + vector samples; + for (size_t i = 0; i < 7; i++) { + samples.push_back(sample(*samplers[i])); + schedulers[i].accumulateStats(samples[i], stats); + } + size_t max = *max_element(stats.begin(), stats.end()); + size_t min = *min_element(stats.begin(), stats.end()); + size_t nz = count_if(stats.begin(), stats.end(), NonZero); + if (nz >= 15 && max <= 2) { + cout << boost::format( + "Sampled schedule %d, min = %d, nz = %d, max = %d\n") % (n + 1) % min + % nz % max; + for (size_t i = 0; i < 7; i++) { + cout << schedulers[i].studentName(0) << " : " << schedulers[i].slotName( + slots[i]) << endl; + schedulers[i].printSpecial(samples[i]); + } + } + } + // Output was + // Sampled schedule 359, min = 0, nz = 15, max = 2 + // Michael N : Fri 9:00-10.30 + // Michael N AI: Frank Dellaert + // Michael N Autonomy: Charlie Kemp + // Michael N Perception: Henrik Christensen + // + // Tucker H : Fri 10:30-12:00 + // Tucker H AI: Ayanna Howard + // Tucker H Controls: Patricio Vela + // Tucker H Perception: Irfan Essa + // + // Jake H : Wed 3:00-4:30 + // Jake H AI: Mike Stilman + // Jake H Controls: Magnus Egerstedt + // Jake H Perception: Jim Rehg + // + // Tobias K : Mon 1:30-3:00 + // Tobias K AI: Ayanna Howard + // Tobias K Autonomy: Charlie Kemp + // Tobias K Controls: Magnus Egerstedt + // + // Shu J : Mon 9:00-10.30 + // Shu J AI: Mike Stilman + // Shu J Controls: Jun Ueda + // Shu J HRI: Andrea Thomaz + // + // Akansel C : Tue 10:30-12:00 + // Akansel C AI: Frank Dellaert + // Akansel C Autonomy: Tucker Balch + // Akansel C Mechanics: Harvey Lipkin + // + // Tiffany C : Wed 10:30-12:00 + // Tiffany C Controls: Patricio Vela + // Tiffany C N/A 1: N/A 1 + // Tiffany C N/A 2: N/A 2 + +} + +/* ************************************************************************* */ +void accomodateStudent() { + + // super-hack! just count... + bool debug = false; + // SETDEBUG("DiscreteConditional::COUNT",true); + SETDEBUG("DiscreteConditional::DiscreteConditional", debug); // progress + + Scheduler scheduler = largeExample(0); + // scheduler.addStudent("Victor E", "Autonomy", "HRI", "AI", + // "Henrik Christensen"); + scheduler.addStudent("Carlos N", "Perception", "AI", "Autonomy", + "Henrik Christensen"); + scheduler.print("scheduler"); + + // rule out all occupied slots + vector slots; + slots += 16, 17, 11, 2, 0, 5, 9, 14; + vector slotsAvailable(scheduler.nrTimeSlots(), 1.0); + BOOST_FOREACH(size_t s, slots) + slotsAvailable[s] = 0; + scheduler.setSlotsAvailable(slotsAvailable); + + // BUILD THE GRAPH ! + scheduler.buildGraph(1); + + // Do EXACT INFERENCE + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + + // find root node + DiscreteConditional::shared_ptr root = *(chordal->rbegin()); + if (debug) + root->print(""/*scheduler.studentName(s)*/); + // GTSAM_PRINT(*chordal); + + // solve root node only + Scheduler::Values values; + size_t bestSlot = root->solve(values); + + // get corresponding count + DiscreteKey dkey = scheduler.studentKey(0); + values[dkey.first] = bestSlot; + size_t count = (*root)(values); + cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(0) + % scheduler.slotName(bestSlot) % bestSlot % count << endl; + + // sample schedules + for (size_t n = 0; n < 10; n++) { + Scheduler::sharedValues sample0 = sample(*chordal); + scheduler.printAssignment(sample0); + } +} + +/* ************************************************************************* */ +int main() { + runLargeExample(); + solveStaged(3); +// sampleSolutions(); + // accomodateStudent(); + return 0; +} +/* ************************************************************************* */ + diff --git a/gtsam/discrete/examples/schedulingQuals12.cpp b/gtsam/discrete/examples/schedulingQuals12.cpp new file mode 100644 index 000000000..230218997 --- /dev/null +++ b/gtsam/discrete/examples/schedulingQuals12.cpp @@ -0,0 +1,264 @@ +/* + * schedulingExample.cpp + * @brief hard scheduling example + * @date March 25, 2011 + * @author Frank Dellaert + */ + +#define ENABLE_TIMING +#define ADD_NO_CACHING +#define ADD_NO_PRUNING +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +using namespace boost::assign; +using namespace std; +using namespace gtsam; + +size_t NRSTUDENTS = 9; + +bool NonZero(size_t i) { + return i > 0; +} + +/* ************************************************************************* */ +void addStudent(Scheduler& s, size_t i) { + switch (i) { + case 0: + s.addStudent("Pan, Yunpeng", "Controls", "Perception", "Mechanics", "Eric Johnson"); + break; + case 1: + s.addStudent("Sawhney, Rahul", "Controls", "AI", "Perception", "Henrik Christensen"); + break; + case 2: + s.addStudent("Akgun, Baris", "Controls", "AI", "HRI", "Andrea Thomaz"); + break; + case 3: + s.addStudent("Jiang, Shu", "Controls", "AI", "Perception", "Ron Arkin"); + break; + case 4: + s.addStudent("Grice, Phillip", "Controls", "Perception", "HRI", "Charlie Kemp"); + break; + case 5: + s.addStudent("Huaman, Ana", "Controls", "AI", "Perception", "Mike Stilman"); + break; + case 6: + s.addStudent("Levihn, Martin", "AI", "Autonomy", "Perception", "Mike Stilman"); + break; + case 7: + s.addStudent("Nieto, Carlos", "AI", "Autonomy", "Perception", "Henrik Christensen"); + break; + case 8: + s.addStudent("Robinette, Paul", "Controls", "AI", "HRI", "Ayanna Howard"); + break; + } +} + +/* ************************************************************************* */ +Scheduler largeExample(size_t nrStudents = NRSTUDENTS) { + string path("/Users/dellaert/borg/gtsam2/gtsam2/discrete/examples/"); + Scheduler s(nrStudents, path + "Doodle2012.csv"); + + s.addArea("Harvey Lipkin", "Mechanics"); + s.addArea("Jun Ueda", "Mechanics"); + + s.addArea("Patricio Vela", "Controls"); + s.addArea("Magnus Egerstedt", "Controls"); + s.addArea("Jun Ueda", "Controls"); + s.addArea("Panos Tsiotras", "Controls"); + s.addArea("Fumin Zhang", "Controls"); + + s.addArea("Henrik Christensen", "Perception"); + s.addArea("Aaron Bobick", "Perception"); + + s.addArea("Mike Stilman", "AI"); +// s.addArea("Henrik Christensen", "AI"); + s.addArea("Ayanna Howard", "AI"); + s.addArea("Charles Isbell", "AI"); + s.addArea("Tucker Balch", "AI"); + + s.addArea("Ayanna Howard", "Autonomy"); + s.addArea("Charlie Kemp", "Autonomy"); + s.addArea("Tucker Balch", "Autonomy"); + s.addArea("Ron Arkin", "Autonomy"); + + s.addArea("Andrea Thomaz", "HRI"); + s.addArea("Karen Feigh", "HRI"); + s.addArea("Charlie Kemp", "HRI"); + + // add students + for (size_t i = 0; i < nrStudents; i++) + addStudent(s, i); + + return s; +} + +/* ************************************************************************* */ +void runLargeExample() { + + Scheduler scheduler = largeExample(); + scheduler.print(); + + // BUILD THE GRAPH ! + size_t addMutex = 3; + // SETDEBUG("Scheduler::buildGraph", true); + scheduler.buildGraph(addMutex); + + // Do brute force product and output that to file + if (scheduler.nrStudents() == 1) { // otherwise too slow + DecisionTreeFactor product = scheduler.product(); + product.dot("scheduling-large", false); + } + + // Do exact inference + // SETDEBUG("timing-verbose", true); + SETDEBUG("DiscreteConditional::DiscreteConditional", true); +#define SAMPLE +#ifdef SAMPLE + tic(2, "large"); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + toc(2, "large"); + tictoc_finishedIteration(); + tictoc_print(); + for (size_t i=0;i<100;i++) { + DiscreteFactor::sharedValues assignment = sample(*chordal); + vector stats(scheduler.nrFaculty()); + scheduler.accumulateStats(assignment, stats); + size_t max = *max_element(stats.begin(), stats.end()); + size_t min = *min_element(stats.begin(), stats.end()); + size_t nz = count_if(stats.begin(), stats.end(), NonZero); +// cout << min << ", " << max << ", " << nz << endl; + if (nz >= 13 && min >=1 && max <= 4) { + cout << "======================================================\n"; + scheduler.printAssignment(assignment); + } + } +#else + tic(2, "large"); + DiscreteFactor::sharedValues MPE = scheduler.optimalAssignment(); + toc(2, "large"); + tictoc_finishedIteration(); + tictoc_print(); + scheduler.printAssignment(MPE); +#endif +} + +/* ************************************************************************* */ +// Solve a series of relaxed problems for maximum flexibility solution +void solveStaged(size_t addMutex = 2) { + + // super-hack! just count... + bool debug = false; + SETDEBUG("DiscreteConditional::COUNT", true); + SETDEBUG("DiscreteConditional::DiscreteConditional", debug); // progress + + // make a vector with slot availability, initially all 1 + // Reads file to get count :-) + vector slotsAvailable(largeExample(0).nrTimeSlots(), 1.0); + + // now, find optimal value for each student, using relaxed mutex constraints + for (size_t s = 0; s < NRSTUDENTS; s++) { + // add all students first time, then drop last one second time, etc... + Scheduler scheduler = largeExample(NRSTUDENTS - s); + //scheduler.print(str(boost::format("Scheduler %d") % (NRSTUDENTS-s))); + + // only allow slots not yet taken + scheduler.setSlotsAvailable(slotsAvailable); + + // BUILD THE GRAPH ! + scheduler.buildGraph(addMutex); + + // Do EXACT INFERENCE + tic_("eliminate"); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + toc_("eliminate"); + + // find root node + DiscreteConditional::shared_ptr root = *(chordal->rbegin()); + if (debug) + root->print(""/*scheduler.studentName(s)*/); + + // solve root node only + Scheduler::Values values; + size_t bestSlot = root->solve(values); + + // get corresponding count + DiscreteKey dkey = scheduler.studentKey(NRSTUDENTS - 1 - s); + values[dkey.first] = bestSlot; + size_t count = (*root)(values); + + // remove this slot from consideration + slotsAvailable[bestSlot] = 0.0; + cout << boost::format("%s = %d (%d), count = %d") % scheduler.studentName(NRSTUDENTS-1-s) + % scheduler.slotName(bestSlot) % bestSlot % count << endl; + } + tictoc_print_(); +} + +/* ************************************************************************* */ +// Sample from solution found above and evaluate cost function +DiscreteBayesNet::shared_ptr createSampler(size_t i, + size_t slot, vector& schedulers) { + Scheduler scheduler = largeExample(0); // todo: wrong nr students + addStudent(scheduler, i); + SETDEBUG("Scheduler::buildGraph", false); + scheduler.addStudentSpecificConstraints(0, slot); + DiscreteBayesNet::shared_ptr chordal = scheduler.eliminate(); + // chordal->print(scheduler[i].studentKey(0).name()); // large ! + schedulers.push_back(scheduler); + return chordal; +} + +void sampleSolutions() { + + vector schedulers; + vector samplers(NRSTUDENTS); + + // Given the time-slots, we can create NRSTUDENTS independent samplers + vector slots; + slots += 3, 20, 2, 6, 5, 11, 1, 4; // given slots + for (size_t i = 0; i < NRSTUDENTS; i++) + samplers[i] = createSampler(i, slots[i], schedulers); + + // now, sample schedules + for (size_t n = 0; n < 500; n++) { + vector stats(19, 0); + vector samples; + for (size_t i = 0; i < NRSTUDENTS; i++) { + samples.push_back(sample(*samplers[i])); + schedulers[i].accumulateStats(samples[i], stats); + } + size_t max = *max_element(stats.begin(), stats.end()); + size_t min = *min_element(stats.begin(), stats.end()); + size_t nz = count_if(stats.begin(), stats.end(), NonZero); + if (nz >= 15 && max <= 2) { + cout << boost::format( + "Sampled schedule %d, min = %d, nz = %d, max = %d\n") % (n + 1) % min + % nz % max; + for (size_t i = 0; i < NRSTUDENTS; i++) { + cout << schedulers[i].studentName(0) << " : " << schedulers[i].slotName( + slots[i]) << endl; + schedulers[i].printSpecial(samples[i]); + } + } + } +} + +/* ************************************************************************* */ +int main() { + runLargeExample(); +// solveStaged(3); +// sampleSolutions(); + return 0; +} +/* ************************************************************************* */ + diff --git a/gtsam/discrete/examples/small.csv b/gtsam/discrete/examples/small.csv new file mode 100644 index 000000000..144ead08c --- /dev/null +++ b/gtsam/discrete/examples/small.csv @@ -0,0 +1 @@ +,Frank,Harvey,Magnus,Andrea Mon,1,1,1, Wed,1,1,1,1 Fri,,1,1,1 \ No newline at end of file diff --git a/gtsam/discrete/label_traits.h b/gtsam/discrete/label_traits.h new file mode 100644 index 000000000..0a0c39094 --- /dev/null +++ b/gtsam/discrete/label_traits.h @@ -0,0 +1,41 @@ +/* + * label_traits.h + * @brief traits class for labels used in Decision Diagram + * @author Frank Dellaert + * @date Mar 22, 2011 + */ + +#pragma once + +#include +#include + +namespace gtsam { + + /** + * Default traits class for label type, http://www.cantrip.org/traits.html + * Override to provide non-default behavior, see example in Index + */ + template + struct label_traits { + /** default = binary label */ + static size_t cardinality(const T&) { + return 2; + } + /** default higher(a,b) = a hasher; + return hasher(a); + } + }; + + /* custom hash function for labels, no need to specialize this */ + template + std::size_t hash_value(const T& a) { + return label_traits::hash_value(a); + } +} diff --git a/gtsam/discrete/parseUAI.cpp b/gtsam/discrete/parseUAI.cpp new file mode 100644 index 000000000..417063da2 --- /dev/null +++ b/gtsam/discrete/parseUAI.cpp @@ -0,0 +1,157 @@ +/* + * parseUAI.cpp + * @brief: parse UAI 2008 format + * @date March 5, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +//#define PARSE +#ifdef PARSE +#include +#include // for parsing +#include // for ref +#include +#include +#include +#include + +#include + +using namespace std; +namespace qi = boost::spirit::qi; + +namespace gtsam { + + /* ************************************************************************* */ + // Keys are the vars of variables connected to a factor + // subclass of Indices with special constructor + struct Keys: public Indices { + Keys() { + } + // Pick correct vars based on indices + Keys(const Indices& vars, const vector& indices) { + BOOST_FOREACH(int i, indices) + push_back(vars[i]); + } + }; + + /* ************************************************************************* */ + // The UAI grammar is defined in a class + // Spirit local variables are used, see + // http://boost-spirit.com/home/2010/01/21/what-are-rule-bound-semantic-actions + /* ************************************************************************* */ + struct Grammar { + + // declare all parsers as instance variables + typedef vector Table; + typedef boost::spirit::istream_iterator It; + qi::rule uai, preamble, type, vars, factors, tables; + qi::rule > keys; + qi::rule > table; + + // Variables filled by preamble parser + size_t nrVars_, nrFactors_; + Indices vars_; + vector factors_; + + // Variables filled by tables parser + vector
tables_; + + // The constructor defines the parser rules (declared below) + // To debug, just say debug(rule) after defining the rule + Grammar() { + using boost::phoenix::val; + using boost::phoenix::ref; + using boost::phoenix::construct; + using namespace boost::spirit::qi; + + //--------------- high level parsers with side-effects :-( ----------------- + + // A uai file consists of preamble followed by tables + uai = preamble >> tables; + + // The preamble defines the variables and factors + // The parser fills in the first set of variables above, + // including the vector of factor "Neighborhoods" + preamble = type >> vars >> int_[ref(nrFactors_) = _1] >> factors; + + // type string, does not seem to matter + type = lit("BAYES") | lit("MARKOV"); + + // vars parses "3 2 2 3" and synthesizes a Keys class, in this case + // containing Indices {v0,2}, {v1,2}, and {v2,3} + vars = int_[ref(nrVars_) = _1] >> (repeat(ref(nrVars_))[int_]) // + [ref(vars_) = construct (_1)]; + + // Parse a list of Neighborhoods and fill factors_ + factors = (repeat(ref(nrFactors_))[keys])// + [ref(factors_) = _1]; + + // The tables parser fills in the tables_ + tables = (repeat(ref(nrFactors_))[table])// + [ref(tables_) = _1]; + + //----------- basic parsers with synthesized attributes :-) ----------------- + + // keys parses strings like "2 1 2", indicating + // a binary factor (2) on variables v1 and v2. + // It returns a Keys class as attribute + keys = int_[_a = _1] >> repeat(_a)[int_] // + [_val = construct (ref(vars_), _1)]; + + // The tables are a list of doubles preceded by a count, e.g. "4 1.0 2.0 3.0 4.0" + // The table parser returns a PotentialTable::Table attribute + table = int_[_a = _1] >> repeat(_a)[double_] // + [_val = construct
(_1)]; + } + + // Add the factors to the graph + void addFactorsToGraph(TypedDiscreteFactorGraph& graph) { + assert(factors_.size()==nrFactors_); + assert(tables_.size()==nrFactors_); + for (size_t i = 0; i < nrFactors_; i++) + graph.add(factors_[i], tables_[i]); + } + + }; + + /* ************************************************************************* */ + bool parseUAI(const std::string& filename, TypedDiscreteFactorGraph& graph) { + + // open file, disable skipping of whitespace + std::ifstream in(filename.c_str()); + if (!in) { + cerr << "Could not open " << filename << endl; + return false; + } + + in.unsetf(std::ios::skipws); + + // wrap istream into iterator + boost::spirit::istream_iterator first(in); + boost::spirit::istream_iterator last; + + // Parse and add factors into the graph + Grammar grammar; + bool success = qi::phrase_parse(first, last, grammar.uai, qi::space); + if (success) grammar.addFactorsToGraph(graph); + + return success; + } +/* ************************************************************************* */ + +}// gtsam +#else + +#include + +namespace gtsam { + +/** Dummy version of function - otherwise, missing symbol */ +bool parseUAI(const std::string& filename, TypedDiscreteFactorGraph& graph) { + return false; +} + +} // \namespace gtsam +#endif diff --git a/gtsam/discrete/parseUAI.h b/gtsam/discrete/parseUAI.h new file mode 100644 index 000000000..070c62e08 --- /dev/null +++ b/gtsam/discrete/parseUAI.h @@ -0,0 +1,22 @@ +/* + * parseUAI.h + * @brief: parse UAI 2008 format + * @date March 5, 2011 + * @author Duy-Nguyen Ta + * @author Frank Dellaert + */ + +#include +#include + +namespace gtsam { + + /** + * Constructor from file + * For now assumes in .uai format from UAI'08 Probablistic Inference Evaluation + * See http://graphmod.ics.uci.edu/uai08/FileFormat + */ + bool parseUAI(const std::string& filename, + gtsam::TypedDiscreteFactorGraph& graph); + +} // gtsam diff --git a/gtsam/discrete/tests/data/FG/alarm.fg b/gtsam/discrete/tests/data/FG/alarm.fg new file mode 100644 index 000000000..40fbb6f9d --- /dev/null +++ b/gtsam/discrete/tests/data/FG/alarm.fg @@ -0,0 +1,935 @@ +# ALARM network +# from http://compbio.cs.huji.ac.il/Repository/Datasets/alarm/alarm.dsc +37 + +2 +0 5 +2 2 +4 +0 0.9 +1 0.1 +2 0.01 +3 0.99 + +2 +1 4 +3 3 +9 +0 0.95 +1 0.04 +2 0.01 +3 0.04 +4 0.95 +5 0.01 +6 0.01 +7 0.29 +8 0.7 + +2 +2 4 +3 3 +9 +0 0.95 +1 0.04 +2 0.01 +3 0.04 +4 0.95 +5 0.01 +6 0.01 +7 0.04 +8 0.95 + +1 +3 +2 +2 +0 0.2 +1 0.8 + +3 +3 4 5 +2 3 2 +12 +0 0.95 +1 0.01 +2 0.04 +3 0.09 +4 0.01 +5 0.9 +6 0.98 +7 0.05 +8 0.01 +9 0.9 +10 0.01 +11 0.05 + +1 +5 +2 +2 +0 0.05 +1 0.95 + +3 +3 5 6 +2 2 3 +12 +0 0.98 +1 0.5 +2 0.95 +3 0.05 +4 0.01 +5 0.49 +6 0.04 +7 0.9 +8 0.01 +9 0.01 +10 0.01 +11 0.05 + +1 +7 +2 +2 +0 0.05 +1 0.95 + +3 +7 8 34 +2 3 3 +18 +0 0.98 +1 0.98 +2 0.01 +3 0.01 +4 0.01 +5 0.01 +6 0.4 +7 0.01 +8 0.59 +9 0.98 +10 0.01 +11 0.01 +12 0.3 +13 0.01 +14 0.4 +15 0.01 +16 0.3 +17 0.98 + +3 +9 10 34 +3 2 3 +18 +0 0.333 +1 0.333 +2 0.333 +3 0.98 +4 0.01 +5 0.01 +6 0.333 +7 0.333 +8 0.333 +9 0.01 +10 0.98 +11 0.01 +12 0.333 +13 0.333 +14 0.333 +15 0.01 +16 0.01 +17 0.98 + +1 +10 +2 +2 +0 0.1 +1 0.9 + +3 +10 11 34 +2 3 3 +18 +0 0.333 +1 0.98 +2 0.333 +3 0.01 +4 0.333 +5 0.01 +6 0.333 +7 0.01 +8 0.333 +9 0.98 +10 0.333 +11 0.01 +12 0.333 +13 0.01 +14 0.333 +15 0.01 +16 0.333 +17 0.98 + +1 +12 +2 +2 +0 0.1 +1 0.9 + +1 +13 +2 +2 +0 0.01 +1 0.99 + +2 +13 14 +2 3 +6 +0 0.98 +1 0.3 +2 0.01 +3 0.4 +4 0.01 +5 0.3 + +3 +15 30 32 +4 4 3 +48 +0 0.97 +1 0.01 +2 0.01 +3 0.01 +4 0.01 +5 0.97 +6 0.01 +7 0.01 +8 0.01 +9 0.01 +10 0.97 +11 0.01 +12 0.01 +13 0.01 +14 0.01 +15 0.97 +16 0.01 +17 0.97 +18 0.01 +19 0.01 +20 0.97 +21 0.01 +22 0.01 +23 0.01 +24 0.01 +25 0.01 +26 0.97 +27 0.01 +28 0.01 +29 0.01 +30 0.01 +31 0.97 +32 0.01 +33 0.97 +34 0.01 +35 0.01 +36 0.01 +37 0.01 +38 0.97 +39 0.01 +40 0.97 +41 0.01 +42 0.01 +43 0.01 +44 0.01 +45 0.01 +46 0.01 +47 0.97 + +1 +16 +2 +2 +0 0.04 +1 0.96 + +3 +17 24 30 +4 3 4 +48 +0 0.97 +1 0.01 +2 0.01 +3 0.01 +4 0.97 +5 0.01 +6 0.01 +7 0.01 +8 0.97 +9 0.01 +10 0.01 +11 0.01 +12 0.01 +13 0.97 +14 0.01 +15 0.01 +16 0.6 +17 0.38 +18 0.01 +19 0.01 +20 0.01 +21 0.97 +22 0.01 +23 0.01 +24 0.01 +25 0.01 +26 0.97 +27 0.01 +28 0.5 +29 0.48 +30 0.01 +31 0.01 +32 0.01 +33 0.01 +34 0.97 +35 0.01 +36 0.01 +37 0.01 +38 0.01 +39 0.97 +40 0.5 +41 0.48 +42 0.01 +43 0.01 +44 0.01 +45 0.01 +46 0.01 +47 0.97 + +1 +18 +2 +2 +0 0.05 +1 0.95 + +3 +18 19 31 +2 3 4 +19 +0 1 +1 1 +6 0.99 +7 0.95 +8 0.01 +9 0.04 +11 0.01 +12 0.95 +13 0.01 +14 0.04 +15 0.95 +16 0.01 +17 0.04 +18 0.95 +19 0.01 +20 0.04 +21 0.01 +22 0.01 +23 0.98 + +3 +19 20 23 +3 3 2 +18 +0 0.98 +1 0.01 +2 0.98 +3 0.01 +4 0.01 +5 0.01 +6 0.01 +7 0.98 +8 0.01 +9 0.01 +10 0.98 +11 0.69 +12 0.98 +13 0.01 +14 0.3 +15 0.01 +16 0.01 +17 0.01 + +2 +21 22 +3 2 +6 +0 0.01 +1 0.19 +2 0.8 +3 0.05 +4 0.9 +5 0.05 + +1 +22 +2 +2 +0 0.01 +1 0.99 + +3 +22 23 24 +2 2 3 +12 +0 0.1 +1 0.95 +2 0.9 +3 0.05 +4 0.1 +5 0.95 +6 0.9 +7 0.05 +8 0.01 +9 0.05 +10 0.99 +11 0.95 + +1 +24 +3 +3 +0 0.92 +1 0.03 +2 0.05 + +4 +16 24 25 29 +2 3 4 4 +96 +0 0.97 +1 0.97 +2 0.97 +3 0.97 +4 0.97 +5 0.97 +6 0.01 +7 0.01 +8 0.01 +9 0.01 +10 0.01 +11 0.01 +12 0.01 +13 0.01 +14 0.01 +15 0.01 +16 0.01 +17 0.01 +18 0.01 +19 0.01 +20 0.01 +21 0.01 +22 0.01 +23 0.01 +24 0.01 +25 0.01 +26 0.1 +27 0.4 +28 0.01 +29 0.01 +30 0.3 +31 0.97 +32 0.84 +33 0.58 +34 0.29 +35 0.9 +36 0.49 +37 0.01 +38 0.05 +39 0.01 +40 0.3 +41 0.08 +42 0.2 +43 0.01 +44 0.01 +45 0.01 +46 0.4 +47 0.01 +48 0.01 +49 0.01 +50 0.05 +51 0.2 +52 0.01 +53 0.01 +54 0.01 +55 0.01 +56 0.25 +57 0.75 +58 0.01 +59 0.01 +60 0.08 +61 0.97 +62 0.25 +63 0.04 +64 0.08 +65 0.38 +66 0.9 +67 0.01 +68 0.45 +69 0.01 +70 0.9 +71 0.6 +72 0.01 +73 0.01 +74 0.01 +75 0.2 +76 0.01 +77 0.01 +78 0.01 +79 0.01 +80 0.15 +81 0.7 +82 0.01 +83 0.01 +84 0.01 +85 0.01 +86 0.25 +87 0.09 +88 0.01 +89 0.01 +90 0.97 +91 0.97 +92 0.59 +93 0.01 +94 0.97 +95 0.97 + +1 +26 +2 +2 +0 0.1 +1 0.9 + +1 +27 +3 +3 +0 0.05 +1 0.9 +2 0.05 + +2 +27 28 +3 4 +12 +0 0.05 +1 0.05 +2 0.05 +3 0.93 +4 0.01 +5 0.01 +6 0.01 +7 0.93 +8 0.01 +9 0.01 +10 0.01 +11 0.93 + +3 +26 28 29 +2 4 4 +32 +0 0.97 +1 0.97 +2 0.97 +3 0.01 +4 0.97 +5 0.01 +6 0.97 +7 0.01 +8 0.01 +9 0.01 +10 0.01 +11 0.97 +12 0.01 +13 0.01 +14 0.01 +15 0.01 +16 0.01 +17 0.01 +18 0.01 +19 0.01 +20 0.01 +21 0.97 +22 0.01 +23 0.01 +24 0.01 +25 0.01 +26 0.01 +27 0.01 +28 0.01 +29 0.01 +30 0.01 +31 0.97 + +4 +16 24 29 30 +2 3 4 4 +96 +0 0.97 +1 0.97 +2 0.97 +3 0.97 +4 0.97 +5 0.97 +6 0.95 +7 0.01 +8 0.97 +9 0.97 +10 0.95 +11 0.01 +12 0.4 +13 0.01 +14 0.97 +15 0.97 +16 0.5 +17 0.01 +18 0.3 +19 0.01 +20 0.97 +21 0.97 +22 0.3 +23 0.01 +24 0.01 +25 0.01 +26 0.01 +27 0.01 +28 0.01 +29 0.01 +30 0.03 +31 0.97 +32 0.01 +33 0.01 +34 0.03 +35 0.97 +36 0.58 +37 0.01 +38 0.01 +39 0.01 +40 0.48 +41 0.01 +42 0.68 +43 0.01 +44 0.01 +45 0.01 +46 0.68 +47 0.01 +48 0.01 +49 0.01 +50 0.01 +51 0.01 +52 0.01 +53 0.01 +54 0.01 +55 0.01 +56 0.01 +57 0.01 +58 0.01 +59 0.01 +60 0.01 +61 0.97 +62 0.01 +63 0.01 +64 0.01 +65 0.97 +66 0.01 +67 0.01 +68 0.01 +69 0.01 +70 0.01 +71 0.01 +72 0.01 +73 0.01 +74 0.01 +75 0.01 +76 0.01 +77 0.01 +78 0.01 +79 0.01 +80 0.01 +81 0.01 +82 0.01 +83 0.01 +84 0.01 +85 0.01 +86 0.01 +87 0.01 +88 0.01 +89 0.01 +90 0.01 +91 0.97 +92 0.01 +93 0.01 +94 0.01 +95 0.97 + +3 +24 30 31 +3 4 4 +48 +0 0.97 +1 0.97 +2 0.97 +3 0.01 +4 0.01 +5 0.03 +6 0.01 +7 0.01 +8 0.01 +9 0.01 +10 0.01 +11 0.01 +12 0.01 +13 0.01 +14 0.01 +15 0.97 +16 0.97 +17 0.95 +18 0.01 +19 0.01 +20 0.94 +21 0.01 +22 0.01 +23 0.88 +24 0.01 +25 0.01 +26 0.01 +27 0.01 +28 0.01 +29 0.01 +30 0.97 +31 0.97 +32 0.04 +33 0.01 +34 0.01 +35 0.1 +36 0.01 +37 0.01 +38 0.01 +39 0.01 +40 0.01 +41 0.01 +42 0.01 +43 0.01 +44 0.01 +45 0.97 +46 0.97 +47 0.01 + +2 +31 32 +4 3 +12 +0 0.01 +1 0.01 +2 0.04 +3 0.9 +4 0.01 +5 0.01 +6 0.92 +7 0.09 +8 0.98 +9 0.98 +10 0.04 +11 0.01 + +5 +12 14 20 32 33 +2 3 3 3 2 +108 +0 0.01 +1 0.05 +2 0.01 +3 0.7 +4 0.01 +5 0.95 +6 0.01 +7 0.05 +8 0.01 +9 0.7 +10 0.05 +11 0.95 +12 0.01 +13 0.05 +14 0.05 +15 0.7 +16 0.05 +17 0.95 +18 0.01 +19 0.05 +20 0.01 +21 0.7 +22 0.01 +23 0.99 +24 0.01 +25 0.05 +26 0.01 +27 0.7 +28 0.05 +29 0.99 +30 0.01 +31 0.05 +32 0.05 +33 0.7 +34 0.05 +35 0.99 +36 0.01 +37 0.01 +38 0.01 +39 0.1 +40 0.01 +41 0.3 +42 0.01 +43 0.01 +44 0.01 +45 0.1 +46 0.01 +47 0.3 +48 0.01 +49 0.01 +50 0.01 +51 0.1 +52 0.01 +53 0.3 +54 0.99 +55 0.95 +56 0.99 +57 0.3 +58 0.99 +59 0.05 +60 0.99 +61 0.95 +62 0.99 +63 0.3 +64 0.95 +65 0.05 +66 0.99 +67 0.95 +68 0.95 +69 0.3 +70 0.95 +71 0.05 +72 0.99 +73 0.95 +74 0.99 +75 0.3 +76 0.99 +77 0.01 +78 0.99 +79 0.95 +80 0.99 +81 0.3 +82 0.95 +83 0.01 +84 0.99 +85 0.95 +86 0.95 +87 0.3 +88 0.95 +89 0.01 +90 0.99 +91 0.99 +92 0.99 +93 0.9 +94 0.99 +95 0.7 +96 0.99 +97 0.99 +98 0.99 +99 0.9 +100 0.99 +101 0.7 +102 0.99 +103 0.99 +104 0.99 +105 0.9 +106 0.99 +107 0.7 + +2 +33 34 +2 3 +6 +0 0.05 +1 0.01 +2 0.9 +3 0.09 +4 0.05 +5 0.9 + +3 +6 34 35 +3 3 3 +27 +0 0.98 +1 0.95 +2 0.3 +3 0.95 +4 0.04 +5 0.01 +6 0.8 +7 0.01 +8 0.01 +9 0.01 +10 0.04 +11 0.69 +12 0.04 +13 0.95 +14 0.3 +15 0.19 +16 0.04 +17 0.01 +18 0.01 +19 0.01 +20 0.01 +21 0.01 +22 0.01 +23 0.69 +24 0.01 +25 0.95 +26 0.98 + +3 +14 35 36 +3 3 3 +27 +0 0.98 +1 0.98 +2 0.3 +3 0.98 +4 0.1 +5 0.05 +6 0.9 +7 0.05 +8 0.01 +9 0.01 +10 0.01 +11 0.6 +12 0.01 +13 0.85 +14 0.4 +15 0.09 +16 0.2 +17 0.09 +18 0.01 +19 0.01 +20 0.1 +21 0.01 +22 0.05 +23 0.55 +24 0.01 +25 0.75 +26 0.9 diff --git a/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai b/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai new file mode 100644 index 000000000..aacf458ed --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai @@ -0,0 +1,18 @@ +MARKOV +3 +2 2 3 +3 +1 0 +2 0 1 +2 1 2 + +2 + 0.436 0.564 + +4 + 0.128 0.872 + 0.920 0.080 + +6 + 0.210 0.333 0.457 + 0.811 0.000 0.189 \ No newline at end of file diff --git a/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai.evid b/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai.evid new file mode 100644 index 000000000..59f3e67a5 --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/sampleMARKOV.uai.evid @@ -0,0 +1,3 @@ +2 + 1 0 + 2 1 \ No newline at end of file diff --git a/gtsam/discrete/tests/data/UAI/uai08_test1.uai b/gtsam/discrete/tests/data/UAI/uai08_test1.uai new file mode 100644 index 000000000..d205773fc --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test1.uai @@ -0,0 +1,996 @@ +BAYES +54 +2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 +54 +4 1 17 27 0 +4 49 32 38 1 +2 22 2 +2 19 3 +1 4 +2 49 5 +3 17 19 6 +4 36 26 28 7 +1 8 +5 22 51 4 28 9 +4 8 23 26 10 +2 5 11 +1 12 +3 27 53 13 +3 38 12 14 +3 32 29 15 +3 23 45 16 +4 49 38 23 17 +6 32 36 33 43 39 18 +4 23 26 44 19 +5 31 22 38 23 20 +6 22 29 34 37 40 21 +1 22 +5 49 31 2 29 23 +2 4 24 +3 49 5 25 +4 22 5 29 26 +2 2 27 +3 25 47 28 +1 29 +2 2 30 +1 31 +1 32 +3 42 41 33 +3 36 25 34 +5 32 38 5 41 35 +1 36 +10 36 27 19 53 18 46 50 35 11 37 +2 22 38 +3 38 26 39 +1 40 +1 41 +3 2 25 42 +4 1 23 4 43 +1 44 +4 34 7 0 45 +4 2 30 33 46 +5 49 23 26 27 47 +3 31 44 48 +1 49 +3 26 41 50 +1 51 +4 36 8 23 52 +3 32 26 53 + +16 + 0.285714 0.714286 + 0.461538 0.538462 + 0.307692 0.692308 + 0.300000 0.700000 + 0.333333 0.666667 + 0.714286 0.285714 + 0.588235 0.411765 + 0.588235 0.411765 + +16 + 0.625000 0.375000 + 0.750000 0.250000 + 0.625000 0.375000 + 0.166667 0.833333 + 0.555556 0.444444 + 0.545455 0.454545 + 0.500000 0.500000 + 0.428571 0.571429 + +4 + 0.666667 0.333333 + 0.571429 0.428571 + +4 + 0.461538 0.538462 + 0.272727 0.727273 + +2 + 0.230769 0.769231 + +4 + 0.625000 0.375000 + 0.583333 0.416667 + +8 + 0.800000 0.200000 + 0.500000 0.500000 + 0.333333 0.666667 + 0.250000 0.750000 + +16 + 0.411765 0.588235 + 0.500000 0.500000 + 0.769231 0.230769 + 0.692308 0.307692 + 0.625000 0.375000 + 0.600000 0.400000 + 0.833333 0.166667 + 0.571429 0.428571 + +2 + 0.700000 0.300000 + +32 + 0.833333 0.166667 + 0.250000 0.750000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.526316 0.473684 + 0.500000 0.500000 + 0.428571 0.571429 + 0.500000 0.500000 + 0.444444 0.555556 + 0.666667 0.333333 + 0.636364 0.363636 + 0.384615 0.615385 + 0.222222 0.777778 + 0.411765 0.588235 + 0.526316 0.473684 + 0.583333 0.416667 + +16 + 0.357143 0.642857 + 0.363636 0.636364 + 0.166667 0.833333 + 0.777778 0.222222 + 0.473684 0.526316 + 0.538462 0.461538 + 0.500000 0.500000 + 0.470588 0.529412 + +4 + 0.470588 0.529412 + 0.400000 0.600000 + +2 + 0.500000 0.500000 + +8 + 0.473684 0.526316 + 0.555556 0.444444 + 0.411765 0.588235 + 0.714286 0.285714 + +8 + 0.800000 0.200000 + 0.444444 0.555556 + 0.600000 0.400000 + 0.666667 0.333333 + +8 + 0.500000 0.500000 + 0.333333 0.666667 + 0.625000 0.375000 + 0.692308 0.307692 + +8 + 0.285714 0.714286 + 0.714286 0.285714 + 0.500000 0.500000 + 0.500000 0.500000 + +16 + 0.833333 0.166667 + 0.454545 0.545455 + 0.625000 0.375000 + 0.250000 0.750000 + 0.727273 0.272727 + 0.588235 0.411765 + 0.400000 0.600000 + 0.500000 0.500000 + +64 + 0.166667 0.833333 + 0.666667 0.333333 + 0.692308 0.307692 + 0.538462 0.461538 + 0.500000 0.500000 + 0.250000 0.750000 + 0.437500 0.562500 + 0.473684 0.526316 + 0.769231 0.230769 + 0.400000 0.600000 + 0.555556 0.444444 + 0.272727 0.727273 + 0.473684 0.526316 + 0.818182 0.181818 + 0.750000 0.250000 + 0.416667 0.583333 + 0.588235 0.411765 + 0.769231 0.230769 + 0.500000 0.500000 + 0.473684 0.526316 + 0.833333 0.166667 + 0.444444 0.555556 + 0.600000 0.400000 + 0.529412 0.470588 + 0.727273 0.272727 + 0.615385 0.384615 + 0.444444 0.555556 + 0.400000 0.600000 + 0.642857 0.357143 + 0.200000 0.800000 + 0.333333 0.666667 + 0.437500 0.562500 + +16 + 0.666667 0.333333 + 0.250000 0.750000 + 0.625000 0.375000 + 0.357143 0.642857 + 0.500000 0.500000 + 0.300000 0.700000 + 0.526316 0.473684 + 0.600000 0.400000 + +32 + 0.444444 0.555556 + 0.583333 0.416667 + 0.500000 0.500000 + 0.571429 0.428571 + 0.400000 0.600000 + 0.500000 0.500000 + 0.333333 0.666667 + 0.666667 0.333333 + 0.473684 0.526316 + 0.500000 0.500000 + 0.545455 0.454545 + 0.454545 0.545455 + 0.500000 0.500000 + 0.466667 0.533333 + 0.777778 0.222222 + 0.222222 0.777778 + +64 + 0.333333 0.666667 + 0.818182 0.181818 + 0.526316 0.473684 + 0.375000 0.625000 + 0.625000 0.375000 + 0.444444 0.555556 + 0.473684 0.526316 + 0.533333 0.466667 + 0.500000 0.500000 + 0.500000 0.500000 + 0.363636 0.636364 + 0.300000 0.700000 + 0.250000 0.750000 + 0.562500 0.437500 + 0.571429 0.428571 + 0.642857 0.357143 + 0.666667 0.333333 + 0.363636 0.636364 + 0.384615 0.615385 + 0.600000 0.400000 + 0.818182 0.181818 + 0.428571 0.571429 + 0.625000 0.375000 + 0.562500 0.437500 + 0.583333 0.416667 + 0.529412 0.470588 + 0.529412 0.470588 + 0.545455 0.454545 + 0.333333 0.666667 + 0.230769 0.769231 + 0.500000 0.500000 + 0.230769 0.769231 + +2 + 0.588235 0.411765 + +32 + 0.333333 0.666667 + 0.333333 0.666667 + 0.428571 0.571429 + 0.600000 0.400000 + 0.750000 0.250000 + 0.666667 0.333333 + 0.411765 0.588235 + 0.583333 0.416667 + 0.800000 0.200000 + 0.545455 0.454545 + 0.333333 0.666667 + 0.375000 0.625000 + 0.571429 0.428571 + 0.285714 0.714286 + 0.555556 0.444444 + 0.461538 0.538462 + +4 + 0.500000 0.500000 + 0.625000 0.375000 + +8 + 0.727273 0.272727 + 0.461538 0.538462 + 0.777778 0.222222 + 0.400000 0.600000 + +16 + 0.555556 0.444444 + 0.600000 0.400000 + 0.571429 0.428571 + 0.833333 0.166667 + 0.777778 0.222222 + 0.357143 0.642857 + 0.285714 0.714286 + 0.642857 0.357143 + +4 + 0.461538 0.538462 + 0.250000 0.750000 + +8 + 0.692308 0.307692 + 0.529412 0.470588 + 0.437500 0.562500 + 0.666667 0.333333 + +2 + 0.727273 0.272727 + +4 + 0.500000 0.500000 + 0.571429 0.428571 + +2 + 0.375000 0.625000 + +2 + 0.428571 0.571429 + +8 + 0.666667 0.333333 + 0.444444 0.555556 + 0.500000 0.500000 + 0.416667 0.583333 + +8 + 0.357143 0.642857 + 0.461538 0.538462 + 0.272727 0.727273 + 0.411765 0.588235 + +32 + 0.470588 0.529412 + 0.466667 0.533333 + 0.700000 0.300000 + 0.555556 0.444444 + 0.444444 0.555556 + 0.666667 0.333333 + 0.466667 0.533333 + 0.466667 0.533333 + 0.200000 0.800000 + 0.588235 0.411765 + 0.166667 0.833333 + 0.333333 0.666667 + 0.526316 0.473684 + 0.562500 0.437500 + 0.333333 0.666667 + 0.700000 0.300000 + +2 + 0.166667 0.833333 + +1024 + 0.250000 0.750000 + 0.307692 0.692308 + 0.500000 0.500000 + 0.666667 0.333333 + 0.818182 0.181818 + 0.500000 0.500000 + 0.625000 0.375000 + 0.615385 0.384615 + 0.500000 0.500000 + 0.285714 0.714286 + 0.230769 0.769231 + 0.692308 0.307692 + 0.333333 0.666667 + 0.625000 0.375000 + 0.437500 0.562500 + 0.625000 0.375000 + 0.272727 0.727273 + 0.636364 0.363636 + 0.181818 0.818182 + 0.500000 0.500000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.818182 0.181818 + 0.437500 0.562500 + 0.500000 0.500000 + 0.750000 0.250000 + 0.375000 0.625000 + 0.625000 0.375000 + 0.700000 0.300000 + 0.466667 0.533333 + 0.411765 0.588235 + 0.666667 0.333333 + 0.750000 0.250000 + 0.285714 0.714286 + 0.250000 0.750000 + 0.571429 0.428571 + 0.555556 0.444444 + 0.428571 0.571429 + 0.500000 0.500000 + 0.666667 0.333333 + 0.571429 0.428571 + 0.222222 0.777778 + 0.615385 0.384615 + 0.461538 0.538462 + 0.250000 0.750000 + 0.666667 0.333333 + 0.200000 0.800000 + 0.384615 0.615385 + 0.300000 0.700000 + 0.466667 0.533333 + 0.625000 0.375000 + 0.562500 0.437500 + 0.583333 0.416667 + 0.500000 0.500000 + 0.727273 0.272727 + 0.571429 0.428571 + 0.250000 0.750000 + 0.333333 0.666667 + 0.500000 0.500000 + 0.545455 0.454545 + 0.333333 0.666667 + 0.666667 0.333333 + 0.461538 0.538462 + 0.181818 0.818182 + 0.714286 0.285714 + 0.666667 0.333333 + 0.470588 0.529412 + 0.500000 0.500000 + 0.470588 0.529412 + 0.500000 0.500000 + 0.416667 0.583333 + 0.625000 0.375000 + 0.625000 0.375000 + 0.692308 0.307692 + 0.500000 0.500000 + 0.666667 0.333333 + 0.714286 0.285714 + 0.600000 0.400000 + 0.461538 0.538462 + 0.500000 0.500000 + 0.500000 0.500000 + 0.181818 0.818182 + 0.750000 0.250000 + 0.357143 0.642857 + 0.400000 0.600000 + 0.625000 0.375000 + 0.250000 0.750000 + 0.461538 0.538462 + 0.250000 0.750000 + 0.333333 0.666667 + 0.272727 0.727273 + 0.428571 0.571429 + 0.166667 0.833333 + 0.600000 0.400000 + 0.750000 0.250000 + 0.583333 0.416667 + 0.769231 0.230769 + 0.769231 0.230769 + 0.545455 0.454545 + 0.470588 0.529412 + 0.454545 0.545455 + 0.555556 0.444444 + 0.714286 0.285714 + 0.384615 0.615385 + 0.428571 0.571429 + 0.636364 0.363636 + 0.583333 0.416667 + 0.384615 0.615385 + 0.357143 0.642857 + 0.571429 0.428571 + 0.642857 0.357143 + 0.636364 0.363636 + 0.714286 0.285714 + 0.230769 0.769231 + 0.333333 0.666667 + 0.428571 0.571429 + 0.533333 0.466667 + 0.625000 0.375000 + 0.444444 0.555556 + 0.357143 0.642857 + 0.555556 0.444444 + 0.500000 0.500000 + 0.333333 0.666667 + 0.384615 0.615385 + 0.600000 0.400000 + 0.333333 0.666667 + 0.700000 0.300000 + 0.500000 0.500000 + 0.545455 0.454545 + 0.800000 0.200000 + 0.625000 0.375000 + 0.250000 0.750000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.666667 0.333333 + 0.666667 0.333333 + 0.692308 0.307692 + 0.400000 0.600000 + 0.692308 0.307692 + 0.666667 0.333333 + 0.555556 0.444444 + 0.666667 0.333333 + 0.222222 0.777778 + 0.562500 0.437500 + 0.500000 0.500000 + 0.666667 0.333333 + 0.230769 0.769231 + 0.555556 0.444444 + 0.307692 0.692308 + 0.800000 0.200000 + 0.400000 0.600000 + 0.666667 0.333333 + 0.285714 0.714286 + 0.500000 0.500000 + 0.444444 0.555556 + 0.555556 0.444444 + 0.272727 0.727273 + 0.600000 0.400000 + 0.428571 0.571429 + 0.400000 0.600000 + 0.526316 0.473684 + 0.333333 0.666667 + 0.750000 0.250000 + 0.636364 0.363636 + 0.333333 0.666667 + 0.750000 0.250000 + 0.500000 0.500000 + 0.818182 0.181818 + 0.375000 0.625000 + 0.333333 0.666667 + 0.625000 0.375000 + 0.583333 0.416667 + 0.230769 0.769231 + 0.769231 0.230769 + 0.800000 0.200000 + 0.636364 0.363636 + 0.384615 0.615385 + 0.562500 0.437500 + 0.727273 0.272727 + 0.250000 0.750000 + 0.600000 0.400000 + 0.538462 0.461538 + 0.750000 0.250000 + 0.428571 0.571429 + 0.300000 0.700000 + 0.555556 0.444444 + 0.692308 0.307692 + 0.230769 0.769231 + 0.333333 0.666667 + 0.454545 0.545455 + 0.666667 0.333333 + 0.583333 0.416667 + 0.454545 0.545455 + 0.562500 0.437500 + 0.666667 0.333333 + 0.500000 0.500000 + 0.250000 0.750000 + 0.625000 0.375000 + 0.588235 0.411765 + 0.818182 0.181818 + 0.500000 0.500000 + 0.250000 0.750000 + 0.636364 0.363636 + 0.181818 0.818182 + 0.333333 0.666667 + 0.411765 0.588235 + 0.500000 0.500000 + 0.428571 0.571429 + 0.230769 0.769231 + 0.333333 0.666667 + 0.562500 0.437500 + 0.666667 0.333333 + 0.600000 0.400000 + 0.333333 0.666667 + 0.500000 0.500000 + 0.333333 0.666667 + 0.714286 0.285714 + 0.333333 0.666667 + 0.714286 0.285714 + 0.454545 0.545455 + 0.181818 0.818182 + 0.400000 0.600000 + 0.750000 0.250000 + 0.636364 0.363636 + 0.300000 0.700000 + 0.222222 0.777778 + 0.200000 0.800000 + 0.777778 0.222222 + 0.500000 0.500000 + 0.384615 0.615385 + 0.411765 0.588235 + 0.818182 0.181818 + 0.357143 0.642857 + 0.588235 0.411765 + 0.285714 0.714286 + 0.562500 0.437500 + 0.529412 0.470588 + 0.466667 0.533333 + 0.454545 0.545455 + 0.800000 0.200000 + 0.571429 0.428571 + 0.250000 0.750000 + 0.500000 0.500000 + 0.400000 0.600000 + 0.444444 0.555556 + 0.600000 0.400000 + 0.500000 0.500000 + 0.200000 0.800000 + 0.642857 0.357143 + 0.666667 0.333333 + 0.600000 0.400000 + 0.250000 0.750000 + 0.500000 0.500000 + 0.600000 0.400000 + 0.300000 0.700000 + 0.363636 0.636364 + 0.727273 0.272727 + 0.250000 0.750000 + 0.500000 0.500000 + 0.666667 0.333333 + 0.615385 0.384615 + 0.642857 0.357143 + 0.473684 0.526316 + 0.437500 0.562500 + 0.545455 0.454545 + 0.411765 0.588235 + 0.466667 0.533333 + 0.666667 0.333333 + 0.333333 0.666667 + 0.562500 0.437500 + 0.700000 0.300000 + 0.500000 0.500000 + 0.473684 0.526316 + 0.357143 0.642857 + 0.571429 0.428571 + 0.416667 0.583333 + 0.555556 0.444444 + 0.833333 0.166667 + 0.727273 0.272727 + 0.181818 0.818182 + 0.750000 0.250000 + 0.200000 0.800000 + 0.470588 0.529412 + 0.583333 0.416667 + 0.625000 0.375000 + 0.800000 0.200000 + 0.400000 0.600000 + 0.437500 0.562500 + 0.400000 0.600000 + 0.444444 0.555556 + 0.454545 0.545455 + 0.181818 0.818182 + 0.615385 0.384615 + 0.533333 0.466667 + 0.428571 0.571429 + 0.625000 0.375000 + 0.777778 0.222222 + 0.333333 0.666667 + 0.588235 0.411765 + 0.285714 0.714286 + 0.500000 0.500000 + 0.636364 0.363636 + 0.428571 0.571429 + 0.727273 0.272727 + 0.500000 0.500000 + 0.285714 0.714286 + 0.818182 0.181818 + 0.250000 0.750000 + 0.555556 0.444444 + 0.181818 0.818182 + 0.727273 0.272727 + 0.529412 0.470588 + 0.625000 0.375000 + 0.555556 0.444444 + 0.777778 0.222222 + 0.714286 0.285714 + 0.727273 0.272727 + 0.300000 0.700000 + 0.411765 0.588235 + 0.222222 0.777778 + 0.800000 0.200000 + 0.642857 0.357143 + 0.769231 0.230769 + 0.562500 0.437500 + 0.600000 0.400000 + 0.400000 0.600000 + 0.600000 0.400000 + 0.461538 0.538462 + 0.500000 0.500000 + 0.461538 0.538462 + 0.750000 0.250000 + 0.307692 0.692308 + 0.444444 0.555556 + 0.400000 0.600000 + 0.666667 0.333333 + 0.727273 0.272727 + 0.250000 0.750000 + 0.666667 0.333333 + 0.500000 0.500000 + 0.473684 0.526316 + 0.727273 0.272727 + 0.444444 0.555556 + 0.428571 0.571429 + 0.285714 0.714286 + 0.500000 0.500000 + 0.470588 0.529412 + 0.500000 0.500000 + 0.363636 0.636364 + 0.428571 0.571429 + 0.615385 0.384615 + 0.500000 0.500000 + 0.555556 0.444444 + 0.500000 0.500000 + 0.250000 0.750000 + 0.642857 0.357143 + 0.400000 0.600000 + 0.411765 0.588235 + 0.250000 0.750000 + 0.700000 0.300000 + 0.500000 0.500000 + 0.416667 0.583333 + 0.692308 0.307692 + 0.500000 0.500000 + 0.357143 0.642857 + 0.750000 0.250000 + 0.181818 0.818182 + 0.166667 0.833333 + 0.250000 0.750000 + 0.714286 0.285714 + 0.769231 0.230769 + 0.666667 0.333333 + 0.714286 0.285714 + 0.333333 0.666667 + 0.285714 0.714286 + 0.750000 0.250000 + 0.166667 0.833333 + 0.500000 0.500000 + 0.466667 0.533333 + 0.714286 0.285714 + 0.545455 0.454545 + 0.166667 0.833333 + 0.428571 0.571429 + 0.750000 0.250000 + 0.307692 0.692308 + 0.428571 0.571429 + 0.818182 0.181818 + 0.375000 0.625000 + 0.625000 0.375000 + 0.250000 0.750000 + 0.700000 0.300000 + 0.300000 0.700000 + 0.625000 0.375000 + 0.642857 0.357143 + 0.428571 0.571429 + 0.500000 0.500000 + 0.777778 0.222222 + 0.444444 0.555556 + 0.333333 0.666667 + 0.428571 0.571429 + 0.307692 0.692308 + 0.333333 0.666667 + 0.166667 0.833333 + 0.571429 0.428571 + 0.333333 0.666667 + 0.500000 0.500000 + 0.538462 0.461538 + 0.250000 0.750000 + 0.416667 0.583333 + 0.500000 0.500000 + 0.500000 0.500000 + 0.625000 0.375000 + 0.473684 0.526316 + 0.375000 0.625000 + 0.470588 0.529412 + 0.454545 0.545455 + 0.500000 0.500000 + 0.333333 0.666667 + 0.500000 0.500000 + 0.363636 0.636364 + 0.600000 0.400000 + 0.166667 0.833333 + 0.769231 0.230769 + 0.588235 0.411765 + 0.642857 0.357143 + 0.636364 0.363636 + 0.833333 0.166667 + 0.166667 0.833333 + 0.470588 0.529412 + 0.700000 0.300000 + 0.700000 0.300000 + 0.666667 0.333333 + 0.714286 0.285714 + 0.384615 0.615385 + 0.500000 0.500000 + 0.777778 0.222222 + 0.454545 0.545455 + 0.500000 0.500000 + 0.181818 0.818182 + 0.526316 0.473684 + 0.700000 0.300000 + 0.777778 0.222222 + 0.529412 0.470588 + 0.714286 0.285714 + 0.428571 0.571429 + 0.500000 0.500000 + 0.588235 0.411765 + 0.571429 0.428571 + 0.750000 0.250000 + 0.500000 0.500000 + 0.666667 0.333333 + 0.363636 0.636364 + 0.571429 0.428571 + 0.454545 0.545455 + 0.444444 0.555556 + 0.250000 0.750000 + 0.363636 0.636364 + 0.272727 0.727273 + 0.333333 0.666667 + 0.615385 0.384615 + 0.615385 0.384615 + 0.333333 0.666667 + 0.583333 0.416667 + 0.166667 0.833333 + 0.428571 0.571429 + 0.400000 0.600000 + 0.454545 0.545455 + 0.500000 0.500000 + 0.714286 0.285714 + 0.500000 0.500000 + 0.800000 0.200000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.300000 0.700000 + 0.454545 0.545455 + 0.416667 0.583333 + 0.615385 0.384615 + 0.600000 0.400000 + 0.357143 0.642857 + 0.454545 0.545455 + 0.230769 0.769231 + 0.428571 0.571429 + 0.500000 0.500000 + 0.562500 0.437500 + 0.555556 0.444444 + 0.571429 0.428571 + 0.750000 0.250000 + 0.166667 0.833333 + 0.285714 0.714286 + 0.400000 0.600000 + 0.461538 0.538462 + 0.333333 0.666667 + 0.555556 0.444444 + 0.416667 0.583333 + 0.466667 0.533333 + 0.333333 0.666667 + 0.444444 0.555556 + 0.375000 0.625000 + 0.642857 0.357143 + 0.727273 0.272727 + 0.470588 0.529412 + 0.363636 0.636364 + 0.714286 0.285714 + 0.666667 0.333333 + 0.411765 0.588235 + 0.250000 0.750000 + 0.437500 0.562500 + 0.500000 0.500000 + 0.400000 0.600000 + 0.400000 0.600000 + 0.428571 0.571429 + 0.222222 0.777778 + +4 + 0.454545 0.545455 + 0.363636 0.636364 + +8 + 0.285714 0.714286 + 0.625000 0.375000 + 0.400000 0.600000 + 0.727273 0.272727 + +2 + 0.526316 0.473684 + +2 + 0.454545 0.545455 + +8 + 0.529412 0.470588 + 0.500000 0.500000 + 0.538462 0.461538 + 0.500000 0.500000 + +16 + 0.526316 0.473684 + 0.571429 0.428571 + 0.562500 0.437500 + 0.230769 0.769231 + 0.333333 0.666667 + 0.750000 0.250000 + 0.333333 0.666667 + 0.600000 0.400000 + +2 + 0.714286 0.285714 + +16 + 0.230769 0.769231 + 0.454545 0.545455 + 0.571429 0.428571 + 0.777778 0.222222 + 0.466667 0.533333 + 0.250000 0.750000 + 0.384615 0.615385 + 0.571429 0.428571 + +16 + 0.666667 0.333333 + 0.555556 0.444444 + 0.363636 0.636364 + 0.833333 0.166667 + 0.400000 0.600000 + 0.818182 0.181818 + 0.692308 0.307692 + 0.692308 0.307692 + +32 + 0.533333 0.466667 + 0.400000 0.600000 + 0.666667 0.333333 + 0.333333 0.666667 + 0.588235 0.411765 + 0.363636 0.636364 + 0.470588 0.529412 + 0.500000 0.500000 + 0.636364 0.363636 + 0.400000 0.600000 + 0.636364 0.363636 + 0.428571 0.571429 + 0.500000 0.500000 + 0.714286 0.285714 + 0.272727 0.727273 + 0.357143 0.642857 + +8 + 0.250000 0.750000 + 0.285714 0.714286 + 0.583333 0.416667 + 0.571429 0.428571 + +2 + 0.375000 0.625000 + +8 + 0.666667 0.333333 + 0.300000 0.700000 + 0.529412 0.470588 + 0.473684 0.526316 + +2 + 0.500000 0.500000 + +16 + 0.666667 0.333333 + 0.200000 0.800000 + 0.500000 0.500000 + 0.500000 0.500000 + 0.666667 0.333333 + 0.714286 0.285714 + 0.470588 0.529412 + 0.533333 0.466667 + +8 + 0.307692 0.692308 + 0.470588 0.529412 + 0.333333 0.666667 + 0.333333 0.666667 + diff --git a/gtsam/discrete/tests/data/UAI/uai08_test1.uai.evid b/gtsam/discrete/tests/data/UAI/uai08_test1.uai.evid new file mode 100644 index 000000000..5ca206d95 --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test1.uai.evid @@ -0,0 +1,11 @@ +10 + 0 1 + 2 0 + 9 0 + 16 1 + 20 1 + 21 1 + 22 0 + 26 1 + 39 0 + 41 1 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test1.uai.output b/gtsam/discrete/tests/data/UAI/uai08_test1.uai.output new file mode 100644 index 000000000..c376783fe --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test1.uai.output @@ -0,0 +1,3 @@ +z -2.7351873 +m 54 2 0.0 1.0 2 0.5995855 0.40041456 2 1.0 0.0 2 0.3761365 0.62386346 2 0.25656807 0.7434319 2 0.6449692 0.35503078 2 0.4957979 0.5042021 2 0.69854456 0.30145544 2 0.7 0.3 2 1.0 0.0 2 0.5303537 0.46964625 2 0.44570237 0.5542976 2 0.5 0.5 2 0.55686617 0.4431338 2 0.6284742 0.3715258 2 0.5607879 0.43921205 2 0.0 1.0 2 0.54289234 0.4571077 2 0.5770133 0.42298666 2 0.547688 0.452312 2 0.0 1.0 2 0.0 1.0 2 1.0 0.0 2 0.5760513 0.4239487 2 0.592929 0.40707102 2 0.63438964 0.36561036 2 0.0 1.0 2 0.52899235 0.47100765 2 0.5998554 0.40014458 2 0.7750039 0.22499608 2 0.50000435 0.49999565 2 0.36475798 0.63524204 2 0.44666538 0.55333465 2 0.43111995 0.56888 2 0.37207335 0.62792665 2 0.5581817 0.4418183 2 0.16809757 0.83190244 2 0.4813641 0.5186359 2 0.43732184 0.56267816 2 1.0 0.0 2 0.54721755 0.45278242 2 0.0 1.0 2 0.51865995 0.48134002 2 0.51229435 0.48770565 2 0.7142385 0.2857615 2 0.53666514 0.46333483 2 0.6171147 0.38288528 2 0.46532288 0.5346771 2 0.46330887 0.5366911 2 0.36718464 0.63281536 2 0.4735739 0.5264261 2 0.5244508 0.4755492 2 0.604569 0.395431 2 0.3945428 0.6054572 +s -11.533098 54 1 0 0 1 1 0 1 0 0 0 0 1 1 0 0 0 1 1 1 1 1 1 0 1 0 0 1 1 0 0 1 1 1 1 1 0 1 1 1 0 0 1 0 1 0 1 0 1 0 1 1 0 0 1 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test2.uai b/gtsam/discrete/tests/data/UAI/uai08_test2.uai new file mode 100644 index 000000000..a75b376ed --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test2.uai @@ -0,0 +1,269 @@ +BAYES +21 +4 4 4 4 4 4 4 4 4 2 2 2 2 2 2 2 2 2 2 2 2 +21 +1 0 +1 1 +1 2 +1 3 +1 4 +1 5 +1 6 +1 7 +1 8 +3 1 0 9 +3 1 3 10 +3 5 1 11 +3 2 6 12 +3 6 4 13 +3 3 6 14 +3 5 7 15 +3 7 2 16 +3 7 3 17 +3 0 8 18 +3 3 8 19 +3 8 4 20 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.25 0.25 0.25 0.25 + +4 + 0.1 0.2 0.3 0.4 + +4 + 0.25 0.25 0.25 0.25 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + +32 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.1 0.9 + 0.9 0.1 + diff --git a/gtsam/discrete/tests/data/UAI/uai08_test2.uai.evid b/gtsam/discrete/tests/data/UAI/uai08_test2.uai.evid new file mode 100644 index 000000000..1214f3c1b --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test2.uai.evid @@ -0,0 +1,13 @@ +12 + 17 0 + 10 0 + 19 0 + 18 0 + 11 0 + 13 0 + 15 0 + 20 0 + 9 0 + 12 0 + 16 0 + 14 0 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test2.uai.output b/gtsam/discrete/tests/data/UAI/uai08_test2.uai.output new file mode 100644 index 000000000..a124d2b4c --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test2.uai.output @@ -0,0 +1,3 @@ +z -5.264346 +m 21 4 0.116109975 0.20537 0.29463002 0.38389003 4 0.10865768 0.2028859 0.2971141 0.3913423 4 0.11159538 0.20386513 0.29613486 0.3884046 4 0.105094366 0.20169812 0.29830188 0.39490563 4 0.116109975 0.20537 0.29463002 0.38389003 4 0.11159538 0.20386513 0.29613486 0.3884046 4 0.10865768 0.2028859 0.2971141 0.3913423 4 0.1 0.2 0.3 0.4 4 0.10956474 0.20318824 0.29681176 0.39043528 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 2 1.0 0.0 +s -5.7635098 21 3 3 3 3 3 3 3 3 3 0 0 0 0 0 0 0 0 0 0 0 0 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test3.uai b/gtsam/discrete/tests/data/UAI/uai08_test3.uai new file mode 100644 index 000000000..2abb99bc2 --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test3.uai @@ -0,0 +1,94 @@ +MARKOV +9 +4 4 4 4 4 4 4 4 4 +13 +1 7 +2 1 0 +2 1 3 +2 5 1 +2 2 6 +2 6 4 +2 3 6 +2 5 7 +2 7 2 +2 7 3 +2 0 8 +2 3 8 +2 8 4 + +4 + 0.1 0.2 0.3 0.4 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + +16 + 0.9 0.1 0.1 0.1 + 0.1 0.9 0.1 0.1 + 0.1 0.1 0.9 0.1 + 0.1 0.1 0.1 0.9 + + diff --git a/gtsam/discrete/tests/data/UAI/uai08_test3.uai.evid b/gtsam/discrete/tests/data/UAI/uai08_test3.uai.evid new file mode 100644 index 000000000..18748286e --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test3.uai.evid @@ -0,0 +1 @@ +0 diff --git a/gtsam/discrete/tests/data/UAI/uai08_test3.uai.output b/gtsam/discrete/tests/data/UAI/uai08_test3.uai.output new file mode 100644 index 000000000..1ddb8297a --- /dev/null +++ b/gtsam/discrete/tests/data/UAI/uai08_test3.uai.output @@ -0,0 +1,3 @@ +z -0.44786617 +m 9 4 0.116109975 0.20537 0.29463002 0.38389003 4 0.10865768 0.2028859 0.2971141 0.3913423 4 0.11159538 0.20386513 0.29613486 0.3884046 4 0.105094366 0.20169812 0.29830188 0.39490563 4 0.116109975 0.20537 0.29463002 0.38389003 4 0.11159538 0.20386513 0.29613486 0.3884046 4 0.10865768 0.2028859 0.2971141 0.3913423 4 0.1 0.2 0.3 0.4 4 0.10956474 0.20318824 0.29681176 0.39043528 +s -0.9470299 9 3 3 3 3 3 3 3 3 3 diff --git a/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp new file mode 100644 index 000000000..2c3d61fe8 --- /dev/null +++ b/gtsam/discrete/tests/testAlgebraicDecisionTree.cpp @@ -0,0 +1,524 @@ +/* + * @file testDecisionTree.cpp + * @brief Develop DecisionTree + * @author Frank Dellaert + * @date Mar 6, 2011 + */ + +#include +#include // make sure we have traits +// headers first to make sure no missing headers +//#define DT_NO_PRUNING +#include +#include // for convert only +#define DISABLE_TIMING +#include // for checking whether we are using boost 1.40 +#if BOOST_VERSION >= 104200 +#define BOOST_HAVE_PARSER +#endif + +#include +#include +#include +#include +#include +using namespace boost::assign; + +#include +#include + +using namespace std; +using namespace gtsam; + +/* ******************************************************************************** */ +typedef AlgebraicDecisionTree ADT; + +template class DecisionTree; +template class AlgebraicDecisionTree; + +#define DISABLE_DOT + +template +void dot(const T&f, const string& filename) { +#ifndef DISABLE_DOT + f.dot(filename); +#endif +} + +/** I can't get this to work ! + class Mul: boost::function { + inline double operator()(const double& a, const double& b) { + return a * b; + } + }; + + // If second argument of binary op is Leaf + template + typename DecisionTree::Node::Ptr DecisionTree::Choice::apply_fC_op_gL( + Cache& cache, const Leaf& gL, Mul op) const { + Ptr h(new Choice(label(), cardinality())); + BOOST_FOREACH(const NodePtr& branch, branches_) + h->push_back(branch->apply_f_op_g(cache, gL, op)); + return Unique(cache, h); + } + */ + +/* ******************************************************************************** */ +// instrumented operators +/* ******************************************************************************** */ +size_t muls = 0, adds = 0; +boost::timer timer; +void resetCounts() { + muls = 0; + adds = 0; + timer.restart(); +} +void printCounts(const string& s) { +#ifndef DISABLE_TIMING + cout << boost::format("%s: %3d muls, %3d adds, %g ms.") % s % muls % adds + % (1000 * timer.elapsed()) << endl; +#endif + resetCounts(); +} +double mul(const double& a, const double& b) { + muls++; + return a * b; +} +double add_(const double& a, const double& b) { + adds++; + return a + b; +} + +/* ******************************************************************************** */ +// test ADT +TEST(ADT, example3) +{ + // Create labels + DiscreteKey A(0,2), B(1,2), C(2,2), D(3,2), E(4,2); + + // Literals + ADT a(A, 0.5, 0.5); + ADT notb(B, 1, 0); + ADT c(C, 0.1, 0.9); + ADT d(D, 0.1, 0.9); + ADT note(E, 0.9, 0.1); + + ADT cnotb = c * notb; + dot(cnotb, "ADT-cnotb"); + +// a.print("a: "); +// cnotb.print("cnotb: "); + ADT acnotb = a * cnotb; +// acnotb.print("acnotb: "); +// acnotb.printCache("acnotb Cache:"); + + dot(acnotb, "ADT-acnotb"); + + + ADT big = apply(apply(d, note, &mul), acnotb, &add_); + dot(big, "ADT-big"); +} + +/* ******************************************************************************** */ +// Asia Bayes Network +/* ******************************************************************************** */ + +/** Convert Signature into CPT */ +ADT create(const Signature& signature) { + ADT p(signature.discreteKeysParentsFirst(), signature.cpt()); + static size_t count = 0; + const DiscreteKey& key = signature.key(); + string dotfile = (boost::format("CPT-%03d-%d") % ++count % key.first).str(); + dot(p, dotfile); + return p; +} + +/* ************************************************************************* */ +// test Asia Joint +TEST(ADT, joint) +{ + DiscreteKey A(0, 2), S(1, 2), T(2, 2), L(3, 2), B(4, 2), E(5, 2), X(6, 2), D(7, 2); + +#ifdef BOOST_HAVE_PARSER + resetCounts(); + ADT pA = create(A % "99/1"); + ADT pS = create(S % "50/50"); + ADT pT = create(T | A = "99/1 95/5"); + ADT pL = create(L | S = "99/1 90/10"); + ADT pB = create(B | S = "70/30 40/60"); + ADT pE = create((E | T, L) = "F T T T"); + ADT pX = create(X | E = "95/5 2/98"); + ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); + printCounts("Asia CPTs"); + + // Create joint + resetCounts(); + ADT joint = pA; + dot(joint, "Asia-A"); + joint = apply(joint, pS, &mul); + dot(joint, "Asia-AS"); + joint = apply(joint, pT, &mul); + dot(joint, "Asia-AST"); + joint = apply(joint, pL, &mul); + dot(joint, "Asia-ASTL"); + joint = apply(joint, pB, &mul); + dot(joint, "Asia-ASTLB"); + joint = apply(joint, pE, &mul); + dot(joint, "Asia-ASTLBE"); + joint = apply(joint, pX, &mul); + dot(joint, "Asia-ASTLBEX"); + joint = apply(joint, pD, &mul); + dot(joint, "Asia-ASTLBEXD"); + EXPECT_LONGS_EQUAL(346, muls); + printCounts("Asia joint"); + + ADT pASTL = pA; + pASTL = apply(pASTL, pS, &mul); + pASTL = apply(pASTL, pT, &mul); + pASTL = apply(pASTL, pL, &mul); + + // test combine + ADT fAa = pASTL.combine(L, &add_).combine(T, &add_).combine(S, &add_); + EXPECT(assert_equal(pA, fAa)); + ADT fAb = pASTL.combine(S, &add_).combine(T, &add_).combine(L, &add_); + EXPECT(assert_equal(pA, fAb)); +#endif +} + +/* ************************************************************************* */ +// test Inference with joint +TEST(ADT, inference) +{ + DiscreteKey A(0,2), D(1,2),// + B(2,2), L(3,2), E(4,2), S(5,2), T(6,2), X(7,2); + +#ifdef BOOST_HAVE_PARSER + resetCounts(); + ADT pA = create(A % "99/1"); + ADT pS = create(S % "50/50"); + ADT pT = create(T | A = "99/1 95/5"); + ADT pL = create(L | S = "99/1 90/10"); + ADT pB = create(B | S = "70/30 40/60"); + ADT pE = create((E | T, L) = "F T T T"); + ADT pX = create(X | E = "95/5 2/98"); + ADT pD = create((D | E, B) = "9/1 2/8 3/7 1/9"); + // printCounts("Inference CPTs"); + + // Create joint + resetCounts(); + ADT joint = pA; + dot(joint, "Joint-Product-A"); + joint = apply(joint, pS, &mul); + dot(joint, "Joint-Product-AS"); + joint = apply(joint, pT, &mul); + dot(joint, "Joint-Product-AST"); + joint = apply(joint, pL, &mul); + dot(joint, "Joint-Product-ASTL"); + joint = apply(joint, pB, &mul); + dot(joint, "Joint-Product-ASTLB"); + joint = apply(joint, pE, &mul); + dot(joint, "Joint-Product-ASTLBE"); + joint = apply(joint, pX, &mul); + dot(joint, "Joint-Product-ASTLBEX"); + joint = apply(joint, pD, &mul); + dot(joint, "Joint-Product-ASTLBEXD"); + EXPECT_LONGS_EQUAL(370, muls); // different ordering + printCounts("Asia product"); + + ADT marginal = joint; + marginal = marginal.combine(X, &add_); + dot(marginal, "Joint-Sum-ADBLEST"); + marginal = marginal.combine(T, &add_); + dot(marginal, "Joint-Sum-ADBLES"); + marginal = marginal.combine(S, &add_); + dot(marginal, "Joint-Sum-ADBLE"); + marginal = marginal.combine(E, &add_); + dot(marginal, "Joint-Sum-ADBL"); + EXPECT_LONGS_EQUAL(161, adds); + printCounts("Asia sum"); +#endif +} + +/* ************************************************************************* */ +TEST(ADT, factor_graph) +{ + DiscreteKey B(0,2), L(1,2), E(2,2), S(3,2), T(4,2), X(5,2); + +#ifdef BOOST_HAVE_PARSER + resetCounts(); + ADT pS = create(S % "50/50"); + ADT pT = create(T % "95/5"); + ADT pL = create(L | S = "99/1 90/10"); + ADT pE = create((E | T, L) = "F T T T"); + ADT pX = create(X | E = "95/5 2/98"); + ADT pD = create(B | E = "1/8 7/9"); + ADT pB = create(B | S = "70/30 40/60"); + // printCounts("Create CPTs"); + + // Create joint + resetCounts(); + ADT fg = pS; + fg = apply(fg, pT, &mul); + fg = apply(fg, pL, &mul); + fg = apply(fg, pB, &mul); + fg = apply(fg, pE, &mul); + fg = apply(fg, pX, &mul); + fg = apply(fg, pD, &mul); + dot(fg, "FactorGraph"); + EXPECT_LONGS_EQUAL(158, muls); + printCounts("Asia FG"); + + fg = fg.combine(X, &add_); + dot(fg, "Marginalized-6X"); + fg = fg.combine(T, &add_); + dot(fg, "Marginalized-5T"); + fg = fg.combine(S, &add_); + dot(fg, "Marginalized-4S"); + fg = fg.combine(E, &add_); + dot(fg, "Marginalized-3E"); + fg = fg.combine(L, &add_); + dot(fg, "Marginalized-2L"); + EXPECT(adds = 54); + printCounts("marginalize"); + + // BLESTX + + // Eliminate X + ADT fE = pX; + dot(fE, "Eliminate-01-fEX"); + fE = fE.combine(X, &add_); + dot(fE, "Eliminate-02-fE"); + printCounts("Eliminate X"); + + // Eliminate T + ADT fLE = pT; + fLE = apply(fLE, pE, &mul); + dot(fLE, "Eliminate-03-fLET"); + fLE = fLE.combine(T, &add_); + dot(fLE, "Eliminate-04-fLE"); + printCounts("Eliminate T"); + + // Eliminate S + ADT fBL = pS; + fBL = apply(fBL, pL, &mul); + fBL = apply(fBL, pB, &mul); + dot(fBL, "Eliminate-05-fBLS"); + fBL = fBL.combine(S, &add_); + dot(fBL, "Eliminate-06-fBL"); + printCounts("Eliminate S"); + + // Eliminate E + ADT fBL2 = fE; + fBL2 = apply(fBL2, fLE, &mul); + fBL2 = apply(fBL2, pD, &mul); + dot(fBL2, "Eliminate-07-fBLE"); + fBL2 = fBL2.combine(E, &add_); + dot(fBL2, "Eliminate-08-fBL2"); + printCounts("Eliminate E"); + + // Eliminate L + ADT fB = fBL; + fB = apply(fB, fBL2, &mul); + dot(fB, "Eliminate-09-fBL"); + fB = fB.combine(L, &add_); + dot(fB, "Eliminate-10-fB"); + printCounts("Eliminate L"); +#endif +} + +/* ************************************************************************* */ +// test equality +TEST(ADT, equality_noparser) +{ + DiscreteKey A(0,2), B(1,2); + Signature::Table tableA, tableB; + Signature::Row rA, rB; + rA += 80, 20; rB += 60, 40; + tableA += rA; tableB += rB; + + // Check straight equality + ADT pA1 = create(A % tableA); + ADT pA2 = create(A % tableA); + EXPECT(pA1 == pA2); // should be equal + + // Check equality after apply + ADT pB = create(B % tableB); + ADT pAB1 = apply(pA1, pB, &mul); + ADT pAB2 = apply(pB, pA1, &mul); + EXPECT(pAB2 == pAB1); +} + +/* ************************************************************************* */ +#ifdef BOOST_HAVE_PARSER +// test equality +TEST(ADT, equality_parser) +{ + DiscreteKey A(0,2), B(1,2); + // Check straight equality + ADT pA1 = create(A % "80/20"); + ADT pA2 = create(A % "80/20"); + EXPECT(pA1 == pA2); // should be equal + + // Check equality after apply + ADT pB = create(B % "60/40"); + ADT pAB1 = apply(pA1, pB, &mul); + ADT pAB2 = apply(pB, pA1, &mul); + EXPECT(pAB2 == pAB1); +} +#endif + +/* ******************************************************************************** */ +// Factor graph construction +// test constructor from strings +TEST(ADT, constructor) +{ + DiscreteKey v0(0,2), v1(1,3); + Assignment x00, x01, x02, x10, x11, x12; + x00[0] = 0, x00[1] = 0; + x01[0] = 0, x01[1] = 1; + x02[0] = 0, x02[1] = 2; + x10[0] = 1, x10[1] = 0; + x11[0] = 1, x11[1] = 1; + x12[0] = 1, x12[1] = 2; + + ADT f1(v0 & v1, "0 1 2 3 4 5"); + EXPECT_DOUBLES_EQUAL(0, f1(x00), 1e-9); + EXPECT_DOUBLES_EQUAL(1, f1(x01), 1e-9); + EXPECT_DOUBLES_EQUAL(2, f1(x02), 1e-9); + EXPECT_DOUBLES_EQUAL(3, f1(x10), 1e-9); + EXPECT_DOUBLES_EQUAL(4, f1(x11), 1e-9); + EXPECT_DOUBLES_EQUAL(5, f1(x12), 1e-9); + + ADT f2(v1 & v0, "0 1 2 3 4 5"); + EXPECT_DOUBLES_EQUAL(0, f2(x00), 1e-9); + EXPECT_DOUBLES_EQUAL(2, f2(x01), 1e-9); + EXPECT_DOUBLES_EQUAL(4, f2(x02), 1e-9); + EXPECT_DOUBLES_EQUAL(1, f2(x10), 1e-9); + EXPECT_DOUBLES_EQUAL(3, f2(x11), 1e-9); + EXPECT_DOUBLES_EQUAL(5, f2(x12), 1e-9); + + DiscreteKey z0(0,5), z1(1,4), z2(2,3), z3(3,2); + vector table(5 * 4 * 3 * 2); + double x = 0; + BOOST_FOREACH(double& t, table) + t = x++; + ADT f3(z0 & z1 & z2 & z3, table); + Assignment assignment; + assignment[0] = 0; + assignment[1] = 0; + assignment[2] = 0; + assignment[3] = 1; + EXPECT_DOUBLES_EQUAL(1, f3(assignment), 1e-9); +} + +/* ************************************************************************* */ +// test conversion to integer indices +// Only works if DiscreteKeys are binary, as size_t has binary cardinality! +TEST(ADT, conversion) +{ + DiscreteKey X(0,2), Y(1,2); + ADT fDiscreteKey(X & Y, "0.2 0.5 0.3 0.6"); + dot(fDiscreteKey, "conversion-f1"); + + std::map ordering; + ordering[0] = 5; + ordering[1] = 2; + + AlgebraicDecisionTree fIndexKey(fDiscreteKey, ordering); + // f1.print("f1"); + // f2.print("f2"); + dot(fIndexKey, "conversion-f2"); + + Assignment x00, x01, x02, x10, x11, x12; + x00[5] = 0, x00[2] = 0; + x01[5] = 0, x01[2] = 1; + x10[5] = 1, x10[2] = 0; + x11[5] = 1, x11[2] = 1; + EXPECT_DOUBLES_EQUAL(0.2, fIndexKey(x00), 1e-9); + EXPECT_DOUBLES_EQUAL(0.5, fIndexKey(x01), 1e-9); + EXPECT_DOUBLES_EQUAL(0.3, fIndexKey(x10), 1e-9); + EXPECT_DOUBLES_EQUAL(0.6, fIndexKey(x11), 1e-9); +} + +/* ******************************************************************************** */ +// test operations in elimination +TEST(ADT, elimination) +{ + DiscreteKey A(0,2), B(1,3), C(2,2); + ADT f1(A & B & C, "1 2 3 4 5 6 1 8 3 3 5 5"); + dot(f1, "elimination-f1"); + + { + // sum out lower key + ADT actualSum = f1.sum(C); + ADT expectedSum(A & B, "3 7 11 9 6 10"); + CHECK(assert_equal(expectedSum,actualSum)); + + // normalize + ADT actual = f1 / actualSum; + vector cpt; + cpt += 1.0 / 3, 2.0 / 3, 3.0 / 7, 4.0 / 7, 5.0 / 11, 6.0 / 11, // + 1.0 / 9, 8.0 / 9, 3.0 / 6, 3.0 / 6, 5.0 / 10, 5.0 / 10; + ADT expected(A & B & C, cpt); + CHECK(assert_equal(expected,actual)); + } + + { + // sum out lower 2 keys + ADT actualSum = f1.sum(C).sum(B); + ADT expectedSum(A, 21, 25); + CHECK(assert_equal(expectedSum,actualSum)); + + // normalize + ADT actual = f1 / actualSum; + vector cpt; + cpt += 1.0 / 21, 2.0 / 21, 3.0 / 21, 4.0 / 21, 5.0 / 21, 6.0 / 21, // + 1.0 / 25, 8.0 / 25, 3.0 / 25, 3.0 / 25, 5.0 / 25, 5.0 / 25; + ADT expected(A & B & C, cpt); + CHECK(assert_equal(expected,actual)); + } +} + +/* ******************************************************************************** */ +// Test non-commutative op +TEST(ADT, div) +{ + DiscreteKey A(0,2), B(1,2); + + // Literals + ADT a(A, 8, 16); + ADT b(B, 2, 4); + ADT expected_a_div_b(A & B, "4 2 8 4"); // 8/2 8/4 16/2 16/4 + ADT expected_b_div_a(A & B, "0.25 0.5 0.125 0.25"); // 2/8 4/8 2/16 4/16 + EXPECT(assert_equal(expected_a_div_b, a / b)); + EXPECT(assert_equal(expected_b_div_a, b / a)); +} + +/* ******************************************************************************** */ +// test zero shortcut +TEST(ADT, zero) +{ + DiscreteKey A(0,2), B(1,2); + + // Literals + ADT a(A, 0, 1); + ADT notb(B, 1, 0); + ADT anotb = a * notb; + // GTSAM_PRINT(anotb); + Assignment x00, x01, x10, x11; + x00[0] = 0, x00[1] = 0; + x01[0] = 0, x01[1] = 1; + x10[0] = 1, x10[1] = 0; + x11[0] = 1, x11[1] = 1; + EXPECT_DOUBLES_EQUAL(0, anotb(x00), 1e-9); + EXPECT_DOUBLES_EQUAL(0, anotb(x01), 1e-9); + EXPECT_DOUBLES_EQUAL(1, anotb(x10), 1e-9); + EXPECT_DOUBLES_EQUAL(0, anotb(x11), 1e-9); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/discrete/tests/testCSP.cpp b/gtsam/discrete/tests/testCSP.cpp new file mode 100644 index 000000000..51ddca098 --- /dev/null +++ b/gtsam/discrete/tests/testCSP.cpp @@ -0,0 +1,224 @@ +/* + * testCSP.cpp + * @brief develop code for CSP solver + * @date Feb 5, 2012 + * @author Frank Dellaert + */ + +#include +#include +#include +#include +#include + +using namespace std; +using namespace gtsam; + +/* ************************************************************************* */ +TEST_UNSAFE( BinaryAllDif, allInOne) +{ + // Create keys and ordering + size_t nrColors = 2; +// DiscreteKey ID("Idaho", nrColors), UT("Utah", nrColors), AZ("Arizona", nrColors); + DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + + // Check construction and conversion + BinaryAllDiff c1(ID, UT); + DecisionTreeFactor f1(ID & UT, "0 1 1 0"); + EXPECT(assert_equal(f1,(DecisionTreeFactor)c1)); + + // Check construction and conversion + BinaryAllDiff c2(UT, AZ); + DecisionTreeFactor f2(UT & AZ, "0 1 1 0"); + EXPECT(assert_equal(f2,(DecisionTreeFactor)c2)); + + DecisionTreeFactor f3 = f1*f2; + EXPECT(assert_equal(f3,c1*f2)); + EXPECT(assert_equal(f3,c2*f1)); +} + +/* ************************************************************************* */ +TEST_UNSAFE( CSP, allInOne) +{ + // Create keys and ordering + size_t nrColors = 2; + DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + + // Create the CSP + CSP csp; + csp.addAllDiff(ID,UT); + csp.addAllDiff(UT,AZ); + + // Check an invalid combination, with ID==UT==AZ all same color + DiscreteFactor::Values invalid; + invalid[ID.first] = 0; + invalid[UT.first] = 0; + invalid[AZ.first] = 0; + EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); + + // Check a valid combination + DiscreteFactor::Values valid; + valid[ID.first] = 0; + valid[UT.first] = 1; + valid[AZ.first] = 0; + EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); + + // Just for fun, create the product and check it + DecisionTreeFactor product = csp.product(); + // product.dot("product"); + DecisionTreeFactor expectedProduct(ID & AZ & UT, "0 1 0 0 0 0 1 0"); + EXPECT(assert_equal(expectedProduct,product)); + + // Solve + CSP::sharedValues mpe = csp.optimalAssignment(); + CSP::Values expected; + insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 1); + EXPECT(assert_equal(expected,*mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); +} + +/* ************************************************************************* */ +TEST_UNSAFE( CSP, WesternUS) +{ + // Create keys + size_t nrColors = 4; + DiscreteKey + // Create ordering according to example in ND-CSP.lyx + WA(0, nrColors), OR(3, nrColors), CA(1, nrColors),NV(2, nrColors), + ID(8, nrColors), UT(9, nrColors), AZ(10, nrColors), + MT(4, nrColors), WY(5, nrColors), CO(7, nrColors), NM(6, nrColors); + + // Create the CSP + CSP csp; + csp.addAllDiff(WA,ID); + csp.addAllDiff(WA,OR); + csp.addAllDiff(OR,ID); + csp.addAllDiff(OR,CA); + csp.addAllDiff(OR,NV); + csp.addAllDiff(CA,NV); + csp.addAllDiff(CA,AZ); + csp.addAllDiff(ID,MT); + csp.addAllDiff(ID,WY); + csp.addAllDiff(ID,UT); + csp.addAllDiff(ID,NV); + csp.addAllDiff(NV,UT); + csp.addAllDiff(NV,AZ); + csp.addAllDiff(UT,WY); + csp.addAllDiff(UT,CO); + csp.addAllDiff(UT,NM); + csp.addAllDiff(UT,AZ); + csp.addAllDiff(AZ,CO); + csp.addAllDiff(AZ,NM); + csp.addAllDiff(MT,WY); + csp.addAllDiff(WY,CO); + csp.addAllDiff(CO,NM); + + // Solve + CSP::sharedValues mpe = csp.optimalAssignment(); + // GTSAM_PRINT(*mpe); + CSP::Values expected; + insert(expected) + (WA.first,1)(CA.first,1)(NV.first,3)(OR.first,0) + (MT.first,1)(WY.first,0)(NM.first,3)(CO.first,2) + (ID.first,2)(UT.first,1)(AZ.first,0); + EXPECT(assert_equal(expected,*mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + + // Write out the dual graph for hmetis +#ifdef DUAL + VariableIndex index(csp); + index.print("index"); + ofstream os("/Users/dellaert/src/hmetis-1.5-osx-i686/US-West-dual.txt"); + index.outputMetisFormat(os); +#endif +} + +/* ************************************************************************* */ +TEST_UNSAFE( CSP, AllDiff) +{ + // Create keys and ordering + size_t nrColors = 3; + DiscreteKey ID(0, nrColors), UT(2, nrColors), AZ(1, nrColors); + + // Create the CSP + CSP csp; + vector dkeys; + dkeys += ID,UT,AZ; + csp.addAllDiff(dkeys); + csp.addSingleValue(AZ,2); + //GTSAM_PRINT(csp); + + // Check construction and conversion + SingleValue s(AZ,2); + DecisionTreeFactor f1(AZ,"0 0 1"); + EXPECT(assert_equal(f1,(DecisionTreeFactor)s)); + + // Check construction and conversion + AllDiff alldiff(dkeys); + DecisionTreeFactor actual = (DecisionTreeFactor)alldiff; +// GTSAM_PRINT(actual); +// actual.dot("actual"); + DecisionTreeFactor f2(ID & AZ & UT, + "0 0 0 0 0 1 0 1 0 0 0 1 0 0 0 1 0 0 0 1 0 1 0 0 0 0 0"); + EXPECT(assert_equal(f2,actual)); + + // Check an invalid combination, with ID==UT==AZ all same color + DiscreteFactor::Values invalid; + invalid[ID.first] = 0; + invalid[UT.first] = 1; + invalid[AZ.first] = 0; + EXPECT_DOUBLES_EQUAL(0, csp(invalid), 1e-9); + + // Check a valid combination + DiscreteFactor::Values valid; + valid[ID.first] = 0; + valid[UT.first] = 1; + valid[AZ.first] = 2; + EXPECT_DOUBLES_EQUAL(1, csp(valid), 1e-9); + + // Solve + CSP::sharedValues mpe = csp.optimalAssignment(); + CSP::Values expected; + insert(expected)(ID.first, 1)(UT.first, 0)(AZ.first, 2); + EXPECT(assert_equal(expected,*mpe)); + EXPECT_DOUBLES_EQUAL(1, csp(*mpe), 1e-9); + + // Arc-consistency + vector domains; + domains += Domain(ID), Domain(AZ), Domain(UT); + SingleValue singleValue(AZ,2); + EXPECT(singleValue.ensureArcConsistency(1,domains)); + EXPECT(alldiff.ensureArcConsistency(0,domains)); + EXPECT(!alldiff.ensureArcConsistency(1,domains)); + EXPECT(alldiff.ensureArcConsistency(2,domains)); + LONGS_EQUAL(2,domains[0].nrValues()); + LONGS_EQUAL(1,domains[1].nrValues()); + LONGS_EQUAL(2,domains[2].nrValues()); + + // Parial application, version 1 + DiscreteFactor::Values known; + known[AZ.first] = 2; + DiscreteFactor::shared_ptr reduced1 = alldiff.partiallyApply(known); + DecisionTreeFactor f3(ID & UT, "0 1 1 1 0 1 1 1 0"); + EXPECT(assert_equal(f3,reduced1->operator DecisionTreeFactor())); + DiscreteFactor::shared_ptr reduced2 = singleValue.partiallyApply(known); + DecisionTreeFactor f4(AZ, "0 0 1"); + EXPECT(assert_equal(f4,reduced2->operator DecisionTreeFactor())); + + // Parial application, version 2 + DiscreteFactor::shared_ptr reduced3 = alldiff.partiallyApply(domains); + EXPECT(assert_equal(f3,reduced3->operator DecisionTreeFactor())); + DiscreteFactor::shared_ptr reduced4 = singleValue.partiallyApply(domains); + EXPECT(assert_equal(f4,reduced4->operator DecisionTreeFactor())); + + // full arc-consistency test + csp.runArcConsistency(nrColors); +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ + diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp new file mode 100644 index 000000000..216d4b965 --- /dev/null +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -0,0 +1,226 @@ +/* + * @file testDecisionTree.cpp + * @brief Develop DecisionTree + * @author Frank Dellaert + * @author Can Erdogan + * @date Jan 30, 2012 + */ + +#include +#include +using namespace boost::assign; + +#include +#include +#include + +//#define DT_DEBUG_MEMORY +//#define DT_NO_PRUNING +#define DISABLE_DOT +#include +using namespace std; +using namespace gtsam; + +template +void dot(const T&f, const string& filename) { +#ifndef DISABLE_DOT + f.dot(filename); +#endif +} + +#define DOT(x)(dot(x,#x)) + +/* ******************************************************************************** */ +// Test string labels and int range +/* ******************************************************************************** */ + +typedef DecisionTree DT; + +struct Ring { + static inline int zero() { + return 0; + } + static inline int one() { + return 1; + } + static inline int add(const int& a, const int& b) { + return a + b; + } + static inline int mul(const int& a, const int& b) { + return a * b; + } +}; + +/* ******************************************************************************** */ +// test DT +TEST(DT, example) +{ + // Create labels + string A("A"), B("B"), C("C"); + + // create a value + Assignment x00, x01, x10, x11; + x00[A] = 0, x00[B] = 0; + x01[A] = 0, x01[B] = 1; + x10[A] = 1, x10[B] = 0; + x11[A] = 1, x11[B] = 1; + + // A + DT a(A, 0, 5); + LONGS_EQUAL(0,a(x00)) + LONGS_EQUAL(5,a(x10)) + DOT(a); + + // pruned + DT p(A, 2, 2); + LONGS_EQUAL(2,p(x00)) + LONGS_EQUAL(2,p(x10)) + DOT(p); + + // \neg B + DT notb(B, 5, 0); + LONGS_EQUAL(5,notb(x00)) + LONGS_EQUAL(5,notb(x10)) + DOT(notb); + + // apply, two nodes, in natural order + DT anotb = apply(a, notb, &Ring::mul); + LONGS_EQUAL(0,anotb(x00)) + LONGS_EQUAL(0,anotb(x01)) + LONGS_EQUAL(25,anotb(x10)) + LONGS_EQUAL(0,anotb(x11)) + DOT(anotb); + + // check pruning + DT pnotb = apply(p, notb, &Ring::mul); + LONGS_EQUAL(10,pnotb(x00)) + LONGS_EQUAL( 0,pnotb(x01)) + LONGS_EQUAL(10,pnotb(x10)) + LONGS_EQUAL( 0,pnotb(x11)) + DOT(pnotb); + + // check pruning + DT zeros = apply(DT(A, 0, 0), notb, &Ring::mul); + LONGS_EQUAL(0,zeros(x00)) + LONGS_EQUAL(0,zeros(x01)) + LONGS_EQUAL(0,zeros(x10)) + LONGS_EQUAL(0,zeros(x11)) + DOT(zeros); + + // apply, two nodes, in switched order + DT notba = apply(a, notb, &Ring::mul); + LONGS_EQUAL(0,notba(x00)) + LONGS_EQUAL(0,notba(x01)) + LONGS_EQUAL(25,notba(x10)) + LONGS_EQUAL(0,notba(x11)) + DOT(notba); + + // Test choose 0 + DT actual0 = notba.choose(A, 0); + EXPECT(assert_equal(DT(0.0), actual0)); + DOT(actual0); + + // Test choose 1 + DT actual1 = notba.choose(A, 1); + EXPECT(assert_equal(DT(B, 25, 0), actual1)); + DOT(actual1); + + // apply, two nodes at same level + DT a_and_a = apply(a, a, &Ring::mul); + LONGS_EQUAL(0,a_and_a(x00)) + LONGS_EQUAL(0,a_and_a(x01)) + LONGS_EQUAL(25,a_and_a(x10)) + LONGS_EQUAL(25,a_and_a(x11)) + DOT(a_and_a); + + // create a function on C + DT c(C, 0, 5); + + // and a model assigning stuff to C + Assignment x101; + x101[A] = 1, x101[B] = 0, x101[C] = 1; + + // mul notba with C + DT notbac = apply(notba, c, &Ring::mul); + LONGS_EQUAL(125,notbac(x101)) + DOT(notbac); + + // mul now in different order + DT acnotb = apply(apply(a, c, &Ring::mul), notb, &Ring::mul); + LONGS_EQUAL(125,acnotb(x101)) + DOT(acnotb); +} + +/* ******************************************************************************** */ +// test Conversion +enum Label { + U, V, X, Y, Z +}; +typedef DecisionTree BDT; +bool convert(const int& y) { + return y != 0; +} + +TEST(DT, conversion) +{ + // Create labels + string A("A"), B("B"); + + // apply, two nodes, in natural order + DT f1 = apply(DT(A, 0, 5), DT(B, 5, 0), &Ring::mul); + + // convert + map ordering; + ordering[A] = X; + ordering[B] = Y; + boost::function op = convert; + BDT f2(f1, ordering, op); + // f1.print("f1"); + // f2.print("f2"); + + // create a value + Assignment