Merge pull request #1013 from borglab/feature/remove_potentials

Remove Potentials
release/4.3a0
Fan Jiang 2022-01-07 14:09:19 -05:00 committed by GitHub
commit 4dafcc50e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 57 additions and 219 deletions

View File

@ -179,5 +179,6 @@ namespace gtsam {
};
// AlgebraicDecisionTree
template<typename T> struct traits<AlgebraicDecisionTree<T>> : public Testable<AlgebraicDecisionTree<T>> {};
}
// namespace gtsam

View File

@ -34,12 +34,13 @@ namespace gtsam {
/* ******************************************************************************** */
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys& keys,
const ADT& potentials) :
DiscreteFactor(keys.indices()), Potentials(keys, potentials) {
DiscreteFactor(keys.indices()), ADT(potentials),
cardinalities_(keys.cardinalities()) {
}
/* *************************************************************************/
DecisionTreeFactor::DecisionTreeFactor(const DiscreteConditional& c) :
DiscreteFactor(c.keys()), Potentials(c) {
DiscreteFactor(c.keys()), AlgebraicDecisionTree<Key>(c), cardinalities_(c.cardinalities_) {
}
/* ************************************************************************* */
@ -48,16 +49,24 @@ namespace gtsam {
return false;
}
else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return Potentials::equals(f, tol);
const auto& f(static_cast<const DecisionTreeFactor&>(other));
return ADT::equals(f, tol);
}
}
/* ************************************************************************* */
double DecisionTreeFactor::safe_div(const double &a, const double &b) {
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ************************************************************************* */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
cout << s;
Potentials::print("Potentials:",formatter);
ADT::print("Potentials:",formatter);
}
/* ************************************************************************* */
@ -162,20 +171,20 @@ namespace gtsam {
void DecisionTreeFactor::dot(std::ostream& os,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(os, keyFormatter, valueFormatter, showZero);
ADT::dot(os, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format, open a file */
void DecisionTreeFactor::dot(const std::string& name,
const KeyFormatter& keyFormatter,
bool showZero) const {
Potentials::dot(name, keyFormatter, valueFormatter, showZero);
ADT::dot(name, keyFormatter, valueFormatter, showZero);
}
/** output to graphviz format string */
std::string DecisionTreeFactor::dot(const KeyFormatter& keyFormatter,
bool showZero) const {
return Potentials::dot(keyFormatter, valueFormatter, showZero);
return ADT::dot(keyFormatter, valueFormatter, showZero);
}
/* ************************************************************************* */
@ -209,5 +218,15 @@ namespace gtsam {
return ss.str();
}
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const vector<double> &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
DecisionTreeFactor::DecisionTreeFactor(const DiscreteKeys &keys, const string &table) :
DiscreteFactor(keys.indices()), AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {
}
/* ************************************************************************* */
} // namespace gtsam

View File

@ -19,7 +19,8 @@
#pragma once
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/Potentials.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Ordering.h>
#include <boost/shared_ptr.hpp>
@ -35,7 +36,7 @@ namespace gtsam {
/**
* A discrete probabilistic factor
*/
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public Potentials {
class GTSAM_EXPORT DecisionTreeFactor: public DiscreteFactor, public AlgebraicDecisionTree<Key> {
public:
@ -43,6 +44,10 @@ namespace gtsam {
typedef DecisionTreeFactor This;
typedef DiscreteFactor Base; ///< Typedef to base class
typedef boost::shared_ptr<DecisionTreeFactor> shared_ptr;
typedef AlgebraicDecisionTree<Key> ADT;
protected:
std::map<Key,size_t> cardinalities_;
public:
@ -55,11 +60,11 @@ namespace gtsam {
/** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */
DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials);
/** Constructor from Indices and (string or doubles) */
template<class SOURCE>
DecisionTreeFactor(const DiscreteKeys& keys, SOURCE table) :
DiscreteFactor(keys.indices()), Potentials(keys, table) {
}
/** Constructor from doubles */
DecisionTreeFactor(const DiscreteKeys& keys, const std::vector<double>& table);
/** Constructor from string */
DecisionTreeFactor(const DiscreteKeys& keys, const std::string& table);
/// Single-key specialization
template <class SOURCE>
@ -71,7 +76,7 @@ namespace gtsam {
: DecisionTreeFactor(DiscreteKeys{key}, row) {}
/** Construct from a DiscreteConditional type */
DecisionTreeFactor(const DiscreteConditional& c);
explicit DecisionTreeFactor(const DiscreteConditional& c);
/// @}
/// @name Testable
@ -90,7 +95,7 @@ namespace gtsam {
/// Value is just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values);
return ADT::operator()(values);
}
/// multiply two factors
@ -98,6 +103,10 @@ namespace gtsam {
return apply(f, ADT::Ring::mul);
}
static double safe_div(const double& a, const double& b);
size_t cardinality(Key j) const { return cardinalities_.at(j);}
/// divide by factor f (safely)
DecisionTreeFactor operator/(const DecisionTreeFactor& f) const {
return apply(f, safe_div);

View File

@ -80,7 +80,7 @@ void DiscreteConditional::print(const string& s,
}
}
cout << ")";
Potentials::print("");
ADT::print("");
cout << endl;
}

View File

@ -128,7 +128,7 @@ public:
/// Evaluate, just look up in AlgebraicDecisonTree
double operator()(const DiscreteValues& values) const override {
return Potentials::operator()(values);
return ADT::operator()(values);
}
/** Convert to a factor */

View File

@ -1,96 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file Potentials.cpp
* @date March 24, 2011
* @author Frank Dellaert
*/
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Potentials.h>
#include <boost/format.hpp>
#include <string>
using namespace std;
namespace gtsam {
/* ************************************************************************* */
double Potentials::safe_div(const double& a, const double& b) {
// cout << boost::format("%g / %g = %g\n") % a % b % ((a == 0) ? 0 : (a / b));
// The use for safe_div is when we divide the product factor by the sum
// factor. If the product or sum is zero, we accord zero probability to the
// event.
return (a == 0 || b == 0) ? 0 : (a / b);
}
/* ********************************************************************************
*/
Potentials::Potentials() : ADT(1.0) {}
/* ********************************************************************************
*/
Potentials::Potentials(const DiscreteKeys& keys, const ADT& decisionTree)
: ADT(decisionTree), cardinalities_(keys.cardinalities()) {}
/* ************************************************************************* */
bool Potentials::equals(const Potentials& other, double tol) const {
return ADT::equals(other, tol);
}
/* ************************************************************************* */
void Potentials::print(const string& s, const KeyFormatter& formatter) const {
cout << s << "\n Cardinalities: { ";
for (const std::pair<const Key,size_t>& key : cardinalities_)
cout << formatter(key.first) << ":" << key.second << ", ";
cout << "}" << endl;
ADT::print(" ", formatter);
}
//
// /* ************************************************************************* */
// template<class P>
// void Potentials::remapIndices(const P& remapping) {
// // Permute the _cardinalities (TODO: Inefficient Consider Improving)
// DiscreteKeys keys;
// map<Key, Key> ordering;
//
// // Get the original keys from cardinalities_
// for(const DiscreteKey& key: cardinalities_)
// keys & key;
//
// // Perform Permutation
// for(DiscreteKey& key: keys) {
// ordering[key.first] = remapping[key.first];
// key.first = ordering[key.first];
// }
//
// // Change *this
// AlgebraicDecisionTree<Key> permuted((*this), ordering);
// *this = permuted;
// cardinalities_ = keys.cardinalities();
// }
//
// /* ************************************************************************* */
// void Potentials::permuteWithInverse(const Permutation& inversePermutation) {
// remapIndices(inversePermutation);
// }
//
// /* ************************************************************************* */
// void Potentials::reduceWithInverse(const internal::Reduction& inverseReduction) {
// remapIndices(inverseReduction);
// }
/* ************************************************************************* */
} // namespace gtsam

View File

@ -1,97 +0,0 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file Potentials.h
* @date March 24, 2011
* @author Frank Dellaert
*/
#pragma once
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Key.h>
#include <boost/shared_ptr.hpp>
#include <set>
namespace gtsam {
/**
* A base class for both DiscreteFactor and DiscreteConditional
*/
class GTSAM_EXPORT Potentials: public AlgebraicDecisionTree<Key> {
public:
typedef AlgebraicDecisionTree<Key> ADT;
protected:
/// Cardinality for each key, used in combine
std::map<Key,size_t> cardinalities_;
/** Constructor from ColumnIndex, and ADT */
Potentials(const ADT& potentials) :
ADT(potentials) {
}
// Safe division for probabilities
static double safe_div(const double& a, const double& b);
// // Apply either a permutation or a reduction
// template<class P>
// void remapIndices(const P& remapping);
public:
/** Default constructor for I/O */
Potentials();
/** Constructor from Indices and ADT */
Potentials(const DiscreteKeys& keys, const ADT& decisionTree);
/** Constructor from Indices and (string or doubles) */
template<class SOURCE>
Potentials(const DiscreteKeys& keys, SOURCE table) :
ADT(keys, table), cardinalities_(keys.cardinalities()) {
}
// Testable
bool equals(const Potentials& other, double tol = 1e-9) const;
void print(const std::string& s = "Potentials: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
size_t cardinality(Key j) const { return cardinalities_.at(j);}
// /**
// * @brief Permutes the keys in Potentials
// *
// * This permutes the Indices and performs necessary re-ordering of ADD.
// * This is virtual so that derived types e.g. DecisionTreeFactor can
// * re-implement it.
// */
// GTSAM_EXPORT virtual void permuteWithInverse(const Permutation& inversePermutation);
//
// /**
// * Apply a reduction, which is a remapping of variable indices.
// */
// GTSAM_EXPORT virtual void reduceWithInverse(const internal::Reduction& inverseReduction);
}; // Potentials
// traits
template<> struct traits<Potentials> : public Testable<Potentials> {};
template<> struct traits<Potentials::ADT> : public Testable<Potentials::ADT> {};
} // namespace gtsam

View File

@ -41,21 +41,23 @@ using namespace gtsam;
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
using ADT = AlgebraicDecisionTree<Key>;
/* ************************************************************************* */
TEST(DiscreteBayesNet, bayesNet) {
DiscreteBayesNet bayesNet;
DiscreteKey Parent(0, 2), Child(1, 2);
auto prior = boost::make_shared<DiscreteConditional>(Parent % "6/4");
CHECK(assert_equal(Potentials::ADT({Parent}, "0.6 0.4"),
(Potentials::ADT)*prior));
CHECK(assert_equal(ADT({Parent}, "0.6 0.4"),
(ADT)*prior));
bayesNet.push_back(prior);
auto conditional =
boost::make_shared<DiscreteConditional>(Child | Parent = "7/3 8/2");
EXPECT_LONGS_EQUAL(1, *(conditional->beginFrontals()));
Potentials::ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
CHECK(assert_equal(expected, (Potentials::ADT)*conditional));
ADT expected(Child & Parent, "0.7 0.8 0.3 0.2");
CHECK(assert_equal(expected, (ADT)*conditional));
bayesNet.push_back(conditional);
DiscreteFactorGraph fg(bayesNet);

View File

@ -143,7 +143,7 @@ namespace gtsam {
const Nodes& nodes() const { return nodes_; }
/** Access node by variable */
const sharedNode operator[](Key j) const { return nodes_.at(j); }
sharedClique operator[](Key j) const { return nodes_.at(j); }
/** return root cliques */
const Roots& roots() const { return roots_; }

View File

@ -130,9 +130,9 @@ void Scheduler::addStudentSpecificConstraints(size_t i,
// get all constraints then specialize to slot
size_t dummyIndex = maxNrStudents_ * 3 + maxNrStudents_;
DiscreteKey dummy(dummyIndex, nrTimeSlots());
Potentials::ADT p(dummy & areaKey,
AlgebraicDecisionTree<Key> p(dummy & areaKey,
available_); // available_ is Doodle string
Potentials::ADT q = p.choose(dummyIndex, *slot);
auto q = p.choose(dummyIndex, *slot);
CSP::add(areaKey, q);
} else {
DiscreteKeys keys {s.key_, areaKey};