diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index ad81d9eb1..f90af56dd 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -70,18 +70,6 @@ namespace gtsam { DecisionTreeFactor(const DiscreteKey& key, const std::vector& row) : DecisionTreeFactor(DiscreteKeys{key}, row) {} - /// Two-key specialization - template - DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2, - SOURCE table) - : DecisionTreeFactor({key1, key2}, table) {} - - /// Three-key specialization - template - DecisionTreeFactor(const DiscreteKey& key1, const DiscreteKey& key2, - const DiscreteKey& key3, SOURCE table) - : DecisionTreeFactor({key1, key2, key3}, table) {} - /** Construct from a DiscreteConditional type */ DecisionTreeFactor(const DiscreteConditional& c); diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 3437a80a0..da3179a25 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -32,16 +32,16 @@ class DiscreteFactor { #include virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(); - DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::vector& spec); - DecisionTreeFactor(const gtsam::DiscreteKey& key, const std::string& spec); - DecisionTreeFactor(const gtsam::DiscreteKey& key1, - const gtsam::DiscreteKey& key2, const std::string& spec); - DecisionTreeFactor(const gtsam::DiscreteKey& key1, - const gtsam::DiscreteKey& key2, - const gtsam::DiscreteKey& key3, const std::string& spec); + DecisionTreeFactor(const gtsam::DiscreteKey& key, string table); + + DecisionTreeFactor(const gtsam::DiscreteKeys& keys, string table); + DecisionTreeFactor(const std::vector& keys, string table); + DecisionTreeFactor(const gtsam::DiscreteConditional& c); + void print(string s = "DecisionTreeFactor\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; @@ -174,12 +174,13 @@ class DotWriter { class DiscreteFactorGraph { DiscreteFactorGraph(); DiscreteFactorGraph(const gtsam::DiscreteBayesNet& bayesNet); - - void add(const gtsam::DiscreteKey& j, const std::vector& spec); + void add(const gtsam::DiscreteKey& j, string table); - void add(const gtsam::DiscreteKey& j1, const gtsam::DiscreteKey& j2, string table); + void add(const gtsam::DiscreteKey& j, const std::vector& spec); + void add(const gtsam::DiscreteKeys& keys, string table); - + void add(const std::vector& keys, string table); + bool empty() const; size_t size() const; gtsam::KeySet keys() const; diff --git a/gtsam_unstable/discrete/Scheduler.cpp b/gtsam_unstable/discrete/Scheduler.cpp index 36c1ddda5..e34613c3b 100644 --- a/gtsam_unstable/discrete/Scheduler.cpp +++ b/gtsam_unstable/discrete/Scheduler.cpp @@ -133,10 +133,10 @@ void Scheduler::addStudentSpecificConstraints(size_t i, Potentials::ADT p(dummy & areaKey, available_); // available_ is Doodle string Potentials::ADT q = p.choose(dummyIndex, *slot); - DiscreteFactor::shared_ptr f(new DecisionTreeFactor(areaKey, q)); - CSP::push_back(f); + CSP::add(areaKey, q); } else { - CSP::add(s.key_, areaKey, available_); // available_ is Doodle string + DiscreteKeys keys {s.key_, areaKey}; + CSP::add(keys, available_); // available_ is Doodle string } } diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 586d1d142..12a60d5cb 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -23,7 +23,7 @@ class TestDecisionTreeFactor(GtsamTestCase): def setUp(self): A = (12, 3) B = (5, 2) - self.factor = DecisionTreeFactor(A, B, "1 2 3 4 5 6") + self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6") def test_enumerate(self): actual = self.factor.enumerate() diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index dc2c7a4f5..1ba145e09 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -36,7 +36,7 @@ class TestDiscreteFactorGraph(GtsamTestCase): graph.add(P2, "0.9 0.6") # Add a binary factor - graph.add(P1, P2, "4 1 10 4") + graph.add([P1, P2], "4 1 10 4") # Instantiate Values assignment = DiscreteValues() @@ -85,8 +85,8 @@ class TestDiscreteFactorGraph(GtsamTestCase): # A simple factor graph (A)-fAC-(C)-fBC-(B) # with smoothness priors graph = DiscreteFactorGraph() - graph.add(A, C, "3 1 1 3") - graph.add(C, B, "3 1 1 3") + graph.add([A, C], "3 1 1 3") + graph.add([C, B], "3 1 1 3") # Test optimization expectedValues = DiscreteValues() @@ -105,8 +105,8 @@ class TestDiscreteFactorGraph(GtsamTestCase): # Create Factor graph graph = DiscreteFactorGraph() - graph.add(C, A, "0.2 0.8 0.3 0.7") - graph.add(C, B, "0.1 0.9 0.4 0.6") + graph.add([C, A], "0.2 0.8 0.3 0.7") + graph.add([C, B], "0.1 0.9 0.4 0.6") actualMPE = graph.optimize()