gtsam/gtsam/discrete/DiscreteConditional.cpp

516 lines
18 KiB
C++
Raw Normal View History

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file DiscreteConditional.cpp
2012-04-16 06:35:28 +08:00
* @date Feb 14, 2011
* @author Duy-Nguyen Ta
* @author Frank Dellaert
*/
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
2012-04-16 07:12:17 +08:00
#include <gtsam/discrete/DiscreteConditional.h>
2024-12-09 00:45:10 +08:00
#include <gtsam/discrete/Ring.h>
2012-04-16 07:12:17 +08:00
#include <gtsam/discrete/Signature.h>
#include <gtsam/hybrid/HybridValues.h>
2012-04-16 06:35:28 +08:00
#include <algorithm>
#include <cassert>
#include <random>
#include <set>
2012-04-16 06:35:28 +08:00
#include <stdexcept>
2020-07-12 01:16:35 +08:00
#include <string>
2022-01-09 22:10:08 +08:00
#include <utility>
#include <vector>
2012-04-16 06:35:28 +08:00
using namespace std;
using std::pair;
2022-01-09 22:10:08 +08:00
using std::stringstream;
using std::vector;
2012-04-16 06:35:28 +08:00
namespace gtsam {
// Explicit (exported) instantiation of the Conditional base used by
// DiscreteConditional, so it is compiled once in this translation unit.
template class GTSAM_EXPORT Conditional<DecisionTreeFactor,
                                        DiscreteConditional>;
2012-09-16 12:37:04 +08:00
/* ************************************************************************** */
// Build a conditional from factor `f` whose first `nrFrontals` keys are the
// frontal variables. The table is normalized by dividing `f` by the factor
// returned from `f.sum(nrFrontals)` (presumably `f` with the frontals summed
// out — confirm in DecisionTreeFactor).
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
                                         const DecisionTreeFactor& f)
    : BaseFactor(f / *f.sum(nrFrontals)),
      BaseConditional(nrFrontals) {}
2012-09-16 12:37:04 +08:00
/* ************************************************************************** */
// Wrap factor `f` as-is (no normalization) with `nrFrontals` frontal keys,
// replacing the factor's key order with `orderedKeys`.
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
                                         const DecisionTreeFactor& f,
                                         const Ordering& orderedKeys)
    : BaseFactor(f), BaseConditional(nrFrontals) {
  // Only the key order changes; the underlying potentials are untouched.
  keys_.assign(orderedKeys.begin(), orderedKeys.end());
}
/* ************************************************************************** */
// Construct directly from keys and a decision tree of potentials; the
// potentials are stored as given, with no normalization performed here.
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
                                         const DiscreteKeys& keys,
                                         const ADT& potentials)
    : BaseFactor(keys, potentials),
      BaseConditional(nrFrontals) {}
/* ************************************************************************** */
// Conditional as the ratio joint / marginal. The number of frontal keys is
// taken as the difference in key counts, joint.size() - marginal.size().
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
                                         const DecisionTreeFactor& marginal)
    : BaseFactor(joint / marginal),
      BaseConditional(joint.size() - marginal.size()) {}
/* ************************************************************************** */
// Same as the (joint, marginal) constructor, but afterwards the key order is
// overwritten with `orderedKeys`.
DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint,
                                         const DecisionTreeFactor& marginal,
                                         const Ordering& orderedKeys)
    : DiscreteConditional(joint, marginal) {
  keys_.assign(orderedKeys.begin(), orderedKeys.end());
}
/* ************************************************************************** */
// Build from a Signature: keys and CPT come from the signature, and the
// conditional has exactly one frontal variable.
DiscreteConditional::DiscreteConditional(const Signature& signature)
    : BaseFactor(signature.discreteKeys(), signature.cpt()),
      BaseConditional(1) {}
/* ************************************************************************** */
// Product of two conditionals with disjoint frontal sets. The resulting keys
// are the sorted union of frontals, followed by the (sorted, deduplicated)
// parents of either operand that are not frontal in the product.
// Throws std::invalid_argument if the frontal sets overlap.
DiscreteConditional DiscreteConditional::operator*(
    const DiscreteConditional& other) const {
  // Union of the frontal keys of both operands.
  std::set<Key> newFrontals(beginFrontals(), endFrontals());
  newFrontals.insert(other.beginFrontals(), other.endFrontals());

  // A union smaller than the combined count means some key is frontal in both.
  if (newFrontals.size() < nrFrontals() + other.nrFrontals())
    throw std::invalid_argument(
        "DiscreteConditional::operator* called with overlapping frontal keys.");

  // Frontal keys with their cardinalities, in sorted order.
  DiscreteKeys discreteKeys;
  for (Key key : frontals()) discreteKeys.emplace_back(key, cardinality(key));
  for (Key key : other.frontals())
    discreteKeys.emplace_back(key, other.cardinality(key));
  std::sort(discreteKeys.begin(), discreteKeys.end());

  // Parents that did not become frontal; std::set dedups and orders them.
  std::set<DiscreteKey> newParents;
  for (Key key : parents())
    if (newFrontals.count(key) == 0) newParents.emplace(key, cardinality(key));
  for (Key key : other.parents())
    if (newFrontals.count(key) == 0)
      newParents.emplace(key, other.cardinality(key));

  // Parents go after the frontals.
  discreteKeys.insert(discreteKeys.end(), newParents.begin(), newParents.end());

  const ADT product = ADT::apply(other, Ring::mul);
  return DiscreteConditional(newFrontals.size(), discreteKeys, product);
}
2022-01-16 05:28:34 +08:00
/* ************************************************************************** */
// Marginal over a single frontal `key`, obtained by summing out every other
// frontal variable. Only valid when there are no parents; otherwise throws
// std::invalid_argument.
DiscreteConditional DiscreteConditional::marginal(Key key) const {
  if (nrParents() > 0)
    throw std::invalid_argument(
        "DiscreteConditional::marginal: single argument version only valid for "
        "fully specified joint distributions (i.e., no parents).");

  // The result is a prior on `key` alone.
  const DiscreteKeys marginalKeys{{key, cardinality(key)}};

  ADT summed(*this);
  for (Key other : frontals()) {
    if (other == key) continue;  // keep the variable we marginalize onto
    summed = summed.sum(other, cardinality(other));
  }
  return DiscreteConditional(1, marginalKeys, summed);
}
/* ************************************************************************** */
2020-07-12 01:16:35 +08:00
void DiscreteConditional::print(const string& s,
const KeyFormatter& formatter) const {
cout << s << " P( ";
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
cout << formatter(*it) << " ";
}
if (nrParents()) {
cout << "| ";
for (const_iterator it = beginParents(); it != endParents(); ++it) {
cout << formatter(*it) << " ";
}
}
cout << "):\n";
2022-01-20 04:14:22 +08:00
ADT::print("", formatter);
2020-07-12 01:16:35 +08:00
cout << endl;
}
/* ************************************************************************** */
2013-10-12 04:26:50 +08:00
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
2013-10-12 04:26:50 +08:00
return false;
} else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
2013-10-12 04:26:50 +08:00
return DecisionTreeFactor::equals(f, tol);
}
}
/* ************************************************************************** */
// Restrict the full tree by fixing each parent that has a value in `given`.
// Parents missing from `given` are skipped, unless `forceComplete` is set, in
// which case `given` is printed and std::runtime_error is thrown.
DiscreteConditional::ADT DiscreteConditional::choose(
    const DiscreteValues& given, bool forceComplete) const {
  ADT restricted(*this);
  for (Key parent : parents()) {
    try {
      const size_t parentValue = given.at(parent);
      restricted = restricted.choose(parent, parentValue);  // tree shrinks
    } catch (const std::out_of_range&) {
      if (!forceComplete) continue;
      given.print("parentsValues: ");
      throw std::runtime_error(
          "DiscreteConditional::choose: parent value missing");
    }
  }
  return restricted;
}
/* ************************************************************************** */
// Condition on whichever parent values appear in `given`, returning a new
// conditional over all frontals plus the parents that remain unassigned.
DiscreteConditional::shared_ptr DiscreteConditional::choose(
    const DiscreteValues& given) const {
  ADT restricted = choose(given, false);  // P(F|S=given)

  DiscreteKeys remaining;
  for (Key frontal : frontals())
    remaining.emplace_back(frontal, cardinality(frontal));
  for (Key parent : parents())
    if (given.count(parent) == 0)
      remaining.emplace_back(parent, cardinality(parent));

  return std::make_shared<DiscreteConditional>(nrFrontals(), remaining,
                                               restricted);
}
/* ************************************************************************** */
2021-12-28 02:01:29 +08:00
// Likelihood over the parents obtained by fixing every frontal variable to its
// value in `frontalValues`.
// @param frontalValues must contain a value for every frontal key.
// @return factor over the parent keys.
// @throws std::runtime_error if a frontal value is missing (after printing
//         `frontalValues`).
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
    const DiscreteValues& frontalValues) const {
  // Descend the full tree one frontal at a time; explicit lookup avoids using
  // exceptions for control flow and no longer mislabels unrelated failures.
  ADT adt(*this);
  for (Key j : frontals()) {
    const auto it = frontalValues.find(j);
    if (it == frontalValues.end()) {
      frontalValues.print("frontalValues: ");
      throw std::runtime_error(
          "DiscreteConditional::likelihood: frontal value missing");
    }
    adt = adt.choose(j, it->second);  // ADT keeps getting smaller.
  }

  // What remains is a function of the parents only.
  DiscreteKeys discreteKeys;
  for (Key j : parents()) discreteKeys.emplace_back(j, cardinality(j));
  return std::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}
/* ****************************************************************************/
2021-12-28 02:01:29 +08:00
// Convenience overload for single-frontal conditionals: fix the frontal to
// `frontal` and return the resulting likelihood over the parents.
// Throws std::invalid_argument if there is not exactly one frontal.
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
    size_t frontal) const {
  if (nrFrontals() != 1) {
    throw std::invalid_argument(
        "Single value likelihood can only be invoked on single-variable "
        "conditional");
  }
  DiscreteValues fixed;
  fixed.emplace(firstFrontalKey(), frontal);
  return likelihood(fixed);
}
2022-01-22 07:10:47 +08:00
/* ************************************************************************** */
size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
// Initialize
2022-01-22 07:10:47 +08:00
size_t maxValue = 0;
double maxP = 0;
DiscreteValues values = parentsValues;
2022-01-22 07:10:47 +08:00
assert(nrFrontals() == 1);
Key j = firstFrontalKey();
for (size_t value = 0; value < cardinality(j); value++) {
values[j] = value;
2024-07-16 06:45:15 +08:00
double pValueS = (*this)(values);
2022-01-22 07:10:47 +08:00
// Update MPE solution if better
if (pValueS > maxP) {
maxP = pValueS;
maxValue = value;
}
}
return maxValue;
}
2022-01-22 02:18:46 +08:00
/* ************************************************************************** */
// Sample the single frontal variable given the parent values already in
// `*values`, and store the sample in `*values`.
// Throws std::invalid_argument if there is not exactly one frontal or if
// `*values` already holds a value for it.
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
  if (nrFrontals() != 1) {
    throw std::invalid_argument(
        "DiscreteConditional::sampleInPlace can only be called on single "
        "variable conditionals");
  }
  const Key frontal = firstFrontalKey();
  if (values->count(frontal) > 0) {
    throw std::invalid_argument(
        "DiscreteConditional::sampleInPlace: values already contains j");
  }
  // Sample first (reads parents), then record the result.
  const size_t drawn = sample(*values);
  values->emplace(frontal, drawn);
}
/* ************************************************************************** */
2021-12-14 02:46:53 +08:00
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
2020-07-13 11:25:07 +08:00
static mt19937 rng(2); // random number generator
2023-02-05 06:54:08 +08:00
// Get the correct conditional distribution
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)
2020-07-13 11:25:07 +08:00
// TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) {
throw std::invalid_argument(
"DiscreteConditional::sample can only be called on single variable "
"conditionals");
}
2020-07-13 11:25:07 +08:00
Key key = firstFrontalKey();
size_t nj = cardinality(key);
vector<double> p(nj);
2021-12-14 02:46:53 +08:00
DiscreteValues frontals;
for (size_t value = 0; value < nj; value++) {
2020-07-13 11:25:07 +08:00
frontals[key] = value;
p[value] = pFS(frontals); // P(F=value|S=parentsValues)
if (p[value] == 1.0) {
return value; // shortcut exit
}
}
2020-07-13 11:25:07 +08:00
std::discrete_distribution<size_t> distribution(p.begin(), p.end());
return distribution(rng);
}
2022-01-22 06:39:06 +08:00
/* ************************************************************************** */
2021-12-29 06:49:18 +08:00
size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1)
throw std::invalid_argument(
"Single value sample() can only be invoked on single-parent "
"conditional");
DiscreteValues values;
values.emplace(keys_.back(), parent_value);
return sample(values);
}
2022-01-22 06:39:06 +08:00
/* ************************************************************************** */
2022-01-03 12:23:51 +08:00
size_t DiscreteConditional::sample() const {
if (nrParents() != 0)
throw std::invalid_argument(
"sample() can only be invoked on no-parent prior");
DiscreteValues values;
return sample(values);
}
2021-12-24 04:57:55 +08:00
/* ************************************************************************* */
// Every joint assignment of the frontal variables, produced by
// CartesianProduct over (key, cardinality) pairs in reverse frontal order.
vector<DiscreteValues> DiscreteConditional::frontalAssignments() const {
  vector<pair<Key, size_t>> keyCardinalities;
  for (Key key : frontals())
    keyCardinalities.emplace_back(key, cardinalities_.at(key));
  std::reverse(keyCardinalities.begin(), keyCardinalities.end());
  return DiscreteValues::CartesianProduct(keyCardinalities);
}
/* ************************************************************************* */
// Every joint assignment of parents and frontals. The (key, cardinality)
// list is parents then frontals, reversed before calling CartesianProduct.
vector<DiscreteValues> DiscreteConditional::allAssignments() const {
  vector<pair<Key, size_t>> keyCardinalities;
  keyCardinalities.reserve(size());
  for (Key key : parents())
    keyCardinalities.emplace_back(key, cardinalities_.at(key));
  for (Key key : frontals())
    keyCardinalities.emplace_back(key, cardinalities_.at(key));
  std::reverse(keyCardinalities.begin(), keyCardinalities.end());
  return DiscreteValues::CartesianProduct(keyCardinalities);
}
2021-12-24 04:57:55 +08:00
2022-01-09 22:10:08 +08:00
/* ************************************************************************* */
// Print out signature.
static void streamSignature(const DiscreteConditional& conditional,
const KeyFormatter& keyFormatter,
stringstream* ss) {
*ss << "P(";
2021-12-25 00:00:28 +08:00
bool first = true;
2022-01-09 22:10:08 +08:00
for (Key key : conditional.frontals()) {
if (!first) *ss << ",";
*ss << keyFormatter(key);
first = false;
}
2022-01-09 22:10:08 +08:00
if (conditional.nrParents() > 0) {
*ss << "|";
bool first = true;
for (Key parent : conditional.parents()) {
if (!first) *ss << ",";
*ss << keyFormatter(parent);
first = false;
}
}
*ss << "):";
}
/* ************************************************************************* */
std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
const Names& names) const {
stringstream ss;
ss << " *";
streamSignature(*this, keyFormatter, &ss);
ss << "*\n" << std::endl;
if (nrParents() == 0) {
2022-01-09 22:10:08 +08:00
// We have no parents, call factor method.
ss << DecisionTreeFactor::markdown(keyFormatter, names);
return ss.str();
}
2022-01-09 23:16:25 +08:00
// Print out header.
2021-12-24 04:57:55 +08:00
ss << "|";
2022-01-09 22:10:08 +08:00
for (Key parent : parents()) {
2022-01-03 11:38:39 +08:00
ss << "*" << keyFormatter(parent) << "*|";
2021-12-25 00:00:28 +08:00
}
2022-01-09 23:16:25 +08:00
auto frontalAssignments = this->frontalAssignments();
for (const auto& a : frontalAssignments) {
2022-01-09 22:10:08 +08:00
for (auto&& it = beginFrontals(); it != endFrontals(); ++it) {
size_t index = a.at(*it);
ss << DiscreteValues::Translate(names, *it, index);
}
2021-12-25 00:00:28 +08:00
ss << "|";
2021-12-24 04:57:55 +08:00
}
2021-12-25 00:00:28 +08:00
ss << "\n";
2021-12-24 04:57:55 +08:00
// Print out separator with alignment hints.
ss << "|";
2022-01-09 23:16:25 +08:00
size_t n = frontalAssignments.size();
for (size_t j = 0; j < nrParents() + n; j++) ss << ":-:|";
2021-12-25 00:00:28 +08:00
ss << "\n";
2021-12-24 04:57:55 +08:00
// Print out all rows.
2021-12-25 00:00:28 +08:00
size_t count = 0;
2022-01-09 22:10:08 +08:00
for (const auto& a : allAssignments()) {
2021-12-25 00:00:28 +08:00
if (count == 0) {
ss << "|";
2022-01-09 22:10:08 +08:00
for (auto&& it = beginParents(); it != endParents(); ++it) {
size_t index = a.at(*it);
ss << DiscreteValues::Translate(names, *it, index) << "|";
}
2021-12-25 00:00:28 +08:00
}
ss << operator()(a) << "|";
count = (count + 1) % n;
if (count == 0) ss << "\n";
2021-12-24 04:57:55 +08:00
}
return ss.str();
}
2022-01-09 22:10:08 +08:00
/* ************************************************************************ */
// Render as an HTML table: one row per parent assignment, one column per
// frontal assignment. Priors (no parents) defer to the factor's HTML.
string DiscreteConditional::html(const KeyFormatter& keyFormatter,
                                 const Names& names) const {
  stringstream out;
  out << "<div>\n<p> <i>";
  streamSignature(*this, keyFormatter, &out);
  out << "</i></p>\n";

  if (nrParents() == 0) {
    // We have no parents, call factor method.
    out << DecisionTreeFactor::html(keyFormatter, names);
    return out.str();
  }

  const auto columns = frontalAssignments();
  const size_t nrColumns = columns.size();

  // Preamble and header row: parent names, then one <th> per frontal
  // assignment.
  out << "<table class='DiscreteConditional'>\n <thead>\n";
  out << " <tr>";
  for (Key parent : parents())
    out << "<th><i>" << keyFormatter(parent) << "</i></th>";
  for (const DiscreteValues& column : columns) {
    out << "<th>";
    for (Key frontal : frontals())
      out << DiscreteValues::Translate(names, frontal, column.at(frontal));
    out << "</th>";
  }
  out << "</tr>\n";
  out << " </thead>\n <tbody>\n";

  // Body: each group of nrColumns consecutive assignments forms one row,
  // opened with the parent values as <th> cells.
  size_t cell = 0;
  for (const DiscreteValues& assignment : allAssignments()) {
    if (cell == 0) {
      out << " <tr>";
      for (Key parent : parents())
        out << "<th>"
            << DiscreteValues::Translate(names, parent, assignment.at(parent))
            << "</th>";
    }
    out << "<td>" << operator()(assignment) << "</td>";  // value
    if (++cell == nrColumns) {
      cell = 0;
      out << "</tr>\n";
    }
  }

  out << " </tbody>\n</table>\n</div>";
  return out.str();
}
/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete());
}
2024-12-31 12:02:26 +08:00
/* ************************************************************************* */
2024-12-31 13:11:02 +08:00
// Replace this conditional's decision-tree root with the one from `dc`.
// Only root_ is assigned; keys and frontal count are left as they are.
void DiscreteConditional::setData(const DiscreteConditional::shared_ptr& dc) {
  root_ = dc->root_;
}
2024-12-31 13:16:44 +08:00
/* ************************************************************************* */
// Max over the given keys via the base factor, wrapped as a conditional in
// which every remaining key is frontal.
DiscreteConditional::shared_ptr DiscreteConditional::max(
    const Ordering& keys) const {
  const auto maxed = BaseFactor::max(keys);
  // Use the result through its pointer rather than copying the factor first.
  return std::make_shared<DiscreteConditional>(maxed->discreteKeys().size(),
                                               *maxed);
}
/* ************************************************************************* */
// Prune the underlying table to at most `maxNrAssignments` assignments (via
// the base factor), then wrap it with the same number of frontals.
DiscreteConditional::shared_ptr DiscreteConditional::prune(
    size_t maxNrAssignments) const {
  auto pruned = BaseFactor::prune(maxNrAssignments);
  return std::make_shared<DiscreteConditional>(nrFrontals(), pruned);
}
/* ************************************************************************* */
// Negative log normalization constant; this implementation always returns 0.
double DiscreteConditional::negLogConstant() const {
  return 0.0;
}
2021-12-25 00:00:28 +08:00
/* ************************************************************************* */
2012-04-16 06:35:28 +08:00
2021-12-25 00:00:28 +08:00
} // namespace gtsam