diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 9ec3b0ac5..a57915e45 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -195,7 +195,7 @@ namespace gtsam { // Construct unordered_map with values std::vector> result; for (const auto& assignment : assignments) { - result.emplace_back(assignment, operator()(assignment)); + result.emplace_back(assignment, evaluate(assignment)); } return result; } diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 9a3cde96d..741715e43 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -130,16 +130,14 @@ namespace gtsam { /// @name Standard Interface /// @{ - /// Calculate probability for given values `x`, + /// Calculate probability for given values, /// is just look up in AlgebraicDecisionTree. - double evaluate(const Assignment& values) const { + virtual double evaluate(const Assignment& values) const override { return ADT::operator()(values); } - /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override { - return ADT::operator()(values); - } + /// Disambiguate to use DiscreteFactor version. Mainly for wrapper + using DiscreteFactor::operator(); /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index f59e29285..858623301 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -169,12 +169,12 @@ class GTSAM_EXPORT DiscreteConditional } /// Evaluate, just look up in AlgebraicDecisionTree - double evaluate(const DiscreteValues& values) const { + virtual double evaluate(const Assignment& values) const override { return ADT::operator()(values); } using DecisionTreeFactor::error; ///< DiscreteValues version - using DecisionTreeFactor::operator(); ///< DiscreteValues version + using DiscreteFactor::operator(); ///< DiscreteValues version /** * @brief restrict to given *parent* values. diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 19af5bd13..2ba670004 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -92,8 +92,21 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { size_t cardinality(Key j) const { return cardinalities_.at(j); } + /** + * @brief Calculate probability for given values. + * Calls specialized evaluation under the hood. + * + * Note: Uses Assignment as it is the base class of DiscreteValues. + * + * @param values Discrete assignment. + * @return double + */ + virtual double evaluate(const Assignment& values) const = 0; + /// Find value for given assignment of values to variables - virtual double operator()(const DiscreteValues&) const = 0; + double operator()(const DiscreteValues& values) const { + return evaluate(values); + } /// Error is just -log(value) virtual double error(const DiscreteValues& values) const; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index f4e023a4d..ea51a996c 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -133,7 +133,7 @@ bool TableFactor::equals(const DiscreteFactor& other, double tol) const { } /* ************************************************************************ */ -double TableFactor::operator()(const DiscreteValues& values) const { +double TableFactor::evaluate(const Assignment& values) const { // a b c d => D * (C * (B * (a) + b) + c) + d uint64_t idx = 0, card = 1; for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { @@ -180,6 +180,7 @@ DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { for (auto i = 0; i < sparse_table_.size(); i++) { table.push_back(sparse_table_.coeff(i)); } + // NOTE(Varun): This constructor is really expensive!! DecisionTreeFactor f(dkeys, table); return f; } diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index 7015353e1..1aecc1669 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -155,14 +155,8 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { // /// @name Standard Interface // /// @{ - /// Calculate probability for given values `x`, - /// is just look up in TableFactor. - double evaluate(const DiscreteValues& values) const { - return operator()(values); - } - - /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override; + /// Evaluate probability distribution, is just look up in TableFactor. + double evaluate(const Assignment& values) const override; /// Calculate error for DiscreteValues `x`, is -log(probability). double error(const DiscreteValues& values) const override; diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 55c8f9e22..b2e2524f8 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -61,14 +61,14 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { DecisionTreeFactor(const std::vector& keys, string table); DecisionTreeFactor(const gtsam::DiscreteConditional& c); - + void print(string s = "DecisionTreeFactor\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; size_t cardinality(gtsam::Key j) const; - + double operator()(const gtsam::DiscreteValues& values) const; gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; size_t cardinality(gtsam::Key j) const; diff --git a/gtsam_unstable/discrete/AllDiff.cpp b/gtsam_unstable/discrete/AllDiff.cpp index 2bd9e6dfd..585ca8103 100644 --- a/gtsam_unstable/discrete/AllDiff.cpp +++ b/gtsam_unstable/discrete/AllDiff.cpp @@ -26,7 +26,7 @@ void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double AllDiff::operator()(const DiscreteValues& values) const { +double AllDiff::evaluate(const Assignment& values) const { std::set taken; // record values taken by keys for (Key dkey : keys_) { size_t value = values.at(dkey); // get the value for that key diff --git a/gtsam_unstable/discrete/AllDiff.h b/gtsam_unstable/discrete/AllDiff.h index d7a63eae0..1180abad4 100644 --- a/gtsam_unstable/discrete/AllDiff.h +++ b/gtsam_unstable/discrete/AllDiff.h @@ -45,7 +45,7 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint { } /// Calculate value = expensive ! - double operator()(const DiscreteValues& values) const override; + double evaluate(const Assignment& values) const override; /// Convert into a decisiontree, can be *very* expensive ! DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/BinaryAllDiff.h b/gtsam_unstable/discrete/BinaryAllDiff.h index 18b335092..e96bfdfde 100644 --- a/gtsam_unstable/discrete/BinaryAllDiff.h +++ b/gtsam_unstable/discrete/BinaryAllDiff.h @@ -47,7 +47,7 @@ class BinaryAllDiff : public Constraint { } /// Calculate value - double operator()(const DiscreteValues& values) const override { + double evaluate(const Assignment& values) const override { return (double)(values.at(keys_[0]) != values.at(keys_[1])); } diff --git a/gtsam_unstable/discrete/Domain.cpp b/gtsam_unstable/discrete/Domain.cpp index bbbc87667..74f621dc7 100644 --- a/gtsam_unstable/discrete/Domain.cpp +++ b/gtsam_unstable/discrete/Domain.cpp @@ -30,7 +30,7 @@ string Domain::base1Str() const { } /* ************************************************************************* */ -double Domain::operator()(const DiscreteValues& values) const { +double Domain::evaluate(const Assignment& values) const { return contains(values.at(key())); } diff --git a/gtsam_unstable/discrete/Domain.h b/gtsam_unstable/discrete/Domain.h index 7f7b717c2..23a566d24 100644 --- a/gtsam_unstable/discrete/Domain.h +++ b/gtsam_unstable/discrete/Domain.h @@ -82,7 +82,7 @@ class GTSAM_UNSTABLE_EXPORT Domain : public Constraint { bool contains(size_t value) const { return values_.count(value) > 0; } /// Calculate value - double operator()(const DiscreteValues& values) const override; + double evaluate(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/gtsam_unstable/discrete/SingleValue.cpp b/gtsam_unstable/discrete/SingleValue.cpp index 6b78f38f5..220bc9c06 100644 --- a/gtsam_unstable/discrete/SingleValue.cpp +++ b/gtsam_unstable/discrete/SingleValue.cpp @@ -22,7 +22,7 @@ void SingleValue::print(const string& s, const KeyFormatter& formatter) const { } /* ************************************************************************* */ -double SingleValue::operator()(const DiscreteValues& values) const { +double SingleValue::evaluate(const Assignment& values) const { return (double)(values.at(keys_[0]) == value_); } diff --git a/gtsam_unstable/discrete/SingleValue.h b/gtsam_unstable/discrete/SingleValue.h index 3f7f22d6a..3df1209b8 100644 --- a/gtsam_unstable/discrete/SingleValue.h +++ b/gtsam_unstable/discrete/SingleValue.h @@ -55,7 +55,7 @@ class GTSAM_UNSTABLE_EXPORT SingleValue : public Constraint { } /// Calculate value - double operator()(const DiscreteValues& values) const override; + double evaluate(const Assignment& values) const override; /// Convert into a decisiontree DecisionTreeFactor toDecisionTreeFactor() const override; diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 12308bb3c..a78d9c94a 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -13,9 +13,10 @@ Author: Frank Dellaert import unittest +from gtsam.utils.test_case import GtsamTestCase + from gtsam import (DecisionTreeFactor, DiscreteDistribution, DiscreteValues, Ordering) -from gtsam.utils.test_case import GtsamTestCase class TestDecisionTreeFactor(GtsamTestCase): diff --git a/python/gtsam/tests/test_DiscreteBayesTree.py b/python/gtsam/tests/test_DiscreteBayesTree.py index 2a9b6ea09..e08491fab 100644 --- a/python/gtsam/tests/test_DiscreteBayesTree.py +++ b/python/gtsam/tests/test_DiscreteBayesTree.py @@ -19,8 +19,8 @@ from gtsam.utils.test_case import GtsamTestCase import gtsam from gtsam import (DiscreteBayesNet, DiscreteBayesTreeClique, - DiscreteConditional, DiscreteFactorGraph, - DiscreteValues, Ordering) + DiscreteConditional, DiscreteFactorGraph, DiscreteValues, + Ordering) class TestDiscreteBayesNet(GtsamTestCase): diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 241a5f0be..6c9eb9aec 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -13,9 +13,10 @@ Author: Varun Agrawal import unittest -from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys from gtsam.utils.test_case import GtsamTestCase +from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys + # Some DiscreteKeys for binary variables: A = 0, 2 B = 1, 2 diff --git a/python/gtsam/tests/test_DiscreteFactorGraph.py b/python/gtsam/tests/test_DiscreteFactorGraph.py index d725ceac8..3053087b4 100644 --- a/python/gtsam/tests/test_DiscreteFactorGraph.py +++ b/python/gtsam/tests/test_DiscreteFactorGraph.py @@ -14,9 +14,12 @@ Author: Frank Dellaert import unittest import numpy as np -from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, Symbol from gtsam.utils.test_case import GtsamTestCase +from gtsam import (DecisionTreeFactor, DiscreteConditional, + DiscreteFactorGraph, DiscreteKeys, DiscreteValues, Ordering, + Symbol) + OrderingType = Ordering.OrderingType