diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 768a37ab4..7aed00c57 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -141,6 +141,7 @@ namespace gtsam { for (auto& key : keys()) { pairs.emplace_back(key, cardinalities_.at(key)); } + // Reverse to make cartesianProduct output a more natural ordering. std::vector> rpairs(pairs.rbegin(), pairs.rend()); const auto assignments = cartesianProduct(rpairs); diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 328af1ca3..5279b2b8c 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -107,7 +107,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional, try { value = parentsValues.at(j); adt = adt.choose(j, value); // ADT keeps getting smaller. - } catch (exception&) { + } catch (std::out_of_range&) { parentsValues.print("parentsValues: "); throw runtime_error("DiscreteConditional::choose: parent value missing"); }; @@ -251,7 +251,11 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const { ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues) // TODO(Duy): only works for one key now, seems horribly slow this way - assert(nrFrontals() == 1); + if (nrFrontals() != 1) { + throw std::invalid_argument( + "DiscreteConditional::sample can only be called on single variable " + "conditionals"); + } Key key = firstFrontalKey(); size_t nj = cardinality(key); vector p(nj); diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index c72f076d8..ea7f3de32 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -85,16 +85,6 @@ public: DiscreteConditional(const DiscreteKey& key, const std::string& spec) : DiscreteConditional(Signature(key, {}, spec)) {} - /// Single-parent specialization - DiscreteConditional(const DiscreteKey& key, const std::string& spec, - const DiscreteKey& parent1) - : DiscreteConditional(Signature(key, {parent1}, spec)) {} - - /// Two-parent specialization - DiscreteConditional(const DiscreteKey& key, const std::string& spec, - const DiscreteKey& parent1, const DiscreteKey& parent2) - : DiscreteConditional(Signature(key, {parent1, parent2}, spec)) {} - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 3b39374cb..3437a80a0 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -57,13 +57,10 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(); DiscreteConditional(size_t nFrontals, const gtsam::DecisionTreeFactor& f); DiscreteConditional(const gtsam::DiscreteKey& key, string spec); - DiscreteConditional(const gtsam::DiscreteKey& key, string spec, - const gtsam::DiscreteKey& parent1); - DiscreteConditional(const gtsam::DiscreteKey& key, string spec, - const gtsam::DiscreteKey& parent1, - const gtsam::DiscreteKey& parent2); DiscreteConditional(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, string spec); + DiscreteConditional(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal); DiscreteConditional(const gtsam::DecisionTreeFactor& joint, @@ -109,13 +106,10 @@ class DiscreteBayesNet { DiscreteBayesNet(); void add(const gtsam::DiscreteConditional& s); void add(const gtsam::DiscreteKey& key, string spec); - void add(const gtsam::DiscreteKey& key, string spec, - const gtsam::DiscreteKey& parent1); - void add(const gtsam::DiscreteKey& key, string spec, - const gtsam::DiscreteKey& parent1, - const gtsam::DiscreteKey& parent2); void add(const gtsam::DiscreteKey& key, const gtsam::DiscreteKeys& parents, string spec); + void add(const gtsam::DiscreteKey& key, + const std::vector& parents, string spec); bool empty() const; size_t size() const; gtsam::KeySet keys() const; diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index 706cdf93d..bdd5a0546 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -57,14 +57,14 @@ class TestDiscreteBayesNet(GtsamTestCase): asia.add(Asia, "99/1") asia.add(Smoking, "50/50") - asia.add(Tuberculosis, "99/1 95/5", Asia) - asia.add(LungCancer, "99/1 90/10", Smoking) - asia.add(Bronchitis, "70/30 40/60", Smoking) + asia.add(Tuberculosis, [Asia], "99/1 95/5") + asia.add(LungCancer, [Smoking], "99/1 90/10") + asia.add(Bronchitis, [Smoking], "70/30 40/60") - asia.add(Either, "F T T T", Tuberculosis, LungCancer) + asia.add(Either, [Tuberculosis, LungCancer], "F T T T") - asia.add(XRay, "95/5 2/98", Either) - asia.add(Dyspnea, "9/1 2/8 3/7 1/9", Either, Bronchitis) + asia.add(XRay, [Either], "95/5 2/98") + asia.add(Dyspnea, [Either, Bronchitis], "9/1 2/8 3/7 1/9") # Convert to factor graph fg = DiscreteFactorGraph(asia) diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index d87734de9..b1ed4fe69 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -14,20 +14,10 @@ Author: Frank Dellaert import unittest from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, - DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, - Ordering) + DiscreteConditional, DiscreteFactorGraph, Ordering) from gtsam.utils.test_case import GtsamTestCase -def P(*args): - """ Create a DiscreteKeys instances from a variable number of DiscreteKey pairs.""" - # TODO: We can make life easier by providing variable argument functions in C++ itself. - dks = DiscreteKeys() - for key in args: - dks.push_back(key) - return dks - - class TestDiscreteBayesNet(GtsamTestCase): """Tests for Discrete Bayes Nets.""" @@ -40,25 +30,25 @@ class TestDiscreteBayesNet(GtsamTestCase): # Create thin-tree Bayesnet. bayesNet = DiscreteBayesNet() - bayesNet.add(keys[0], P(keys[8], keys[12]), "2/3 1/4 3/2 4/1") - bayesNet.add(keys[1], P(keys[8], keys[12]), "4/1 2/3 3/2 1/4") - bayesNet.add(keys[2], P(keys[9], keys[12]), "1/4 8/2 2/3 4/1") - bayesNet.add(keys[3], P(keys[9], keys[12]), "1/4 2/3 3/2 4/1") + bayesNet.add(keys[0], [keys[8], keys[12]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[1], [keys[8], keys[12]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[2], [keys[9], keys[12]], "1/4 8/2 2/3 4/1") + bayesNet.add(keys[3], [keys[9], keys[12]], "1/4 2/3 3/2 4/1") - bayesNet.add(keys[4], P(keys[10], keys[13]), "2/3 1/4 3/2 4/1") - bayesNet.add(keys[5], P(keys[10], keys[13]), "4/1 2/3 3/2 1/4") - bayesNet.add(keys[6], P(keys[11], keys[13]), "1/4 3/2 2/3 4/1") - bayesNet.add(keys[7], P(keys[11], keys[13]), "1/4 2/3 3/2 4/1") + bayesNet.add(keys[4], [keys[10], keys[13]], "2/3 1/4 3/2 4/1") + bayesNet.add(keys[5], [keys[10], keys[13]], "4/1 2/3 3/2 1/4") + bayesNet.add(keys[6], [keys[11], keys[13]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[7], [keys[11], keys[13]], "1/4 2/3 3/2 4/1") - bayesNet.add(keys[8], P(keys[12], keys[14]), "T 1/4 3/2 4/1") - bayesNet.add(keys[9], P(keys[12], keys[14]), "4/1 2/3 F 1/4") - bayesNet.add(keys[10], P(keys[13], keys[14]), "1/4 3/2 2/3 4/1") - bayesNet.add(keys[11], P(keys[13], keys[14]), "1/4 2/3 3/2 4/1") + bayesNet.add(keys[8], [keys[12], keys[14]], "T 1/4 3/2 4/1") + bayesNet.add(keys[9], [keys[12], keys[14]], "4/1 2/3 F 1/4") + bayesNet.add(keys[10], [keys[13], keys[14]], "1/4 3/2 2/3 4/1") + bayesNet.add(keys[11], [keys[13], keys[14]], "1/4 2/3 3/2 4/1") - bayesNet.add(keys[12], P(keys[14]), "3/1 3/1") - bayesNet.add(keys[13], P(keys[14]), "1/3 3/1") + bayesNet.add(keys[12], [keys[14]], "3/1 3/1") + bayesNet.add(keys[13], [keys[14]], "1/3 3/1") - bayesNet.add(keys[14], P(), "1/3") + bayesNet.add(keys[14], "1/3") # Create a factor graph out of the Bayes net. factorGraph = DiscreteFactorGraph(bayesNet) diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 0cd02ce6a..44d25461f 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -23,7 +23,7 @@ class TestDiscreteConditional(GtsamTestCase): def test_likelihood(self): X = (0, 2) Y = (1, 3) - conditional = DiscreteConditional(X, "2/8 4/6 5/5", Y) + conditional = DiscreteConditional(X, [Y], "2/8 4/6 5/5") actual0 = conditional.likelihood(0) expected0 = DecisionTreeFactor(Y, "0.2 0.4 0.5")