From 7b3ce2fe3400a74ae4bd0a8eca518f27d815857f Mon Sep 17 00:00:00 2001 From: Yoonwoo Kim Date: Mon, 29 May 2023 01:17:50 +0900 Subject: [PATCH] added doc for disceteKey in .h file, formatted in Google style. --- gtsam/discrete/TableFactor.cpp | 893 ++++++++++++++++----------------- gtsam/discrete/TableFactor.h | 503 ++++++++++--------- 2 files changed, 702 insertions(+), 694 deletions(-) diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index e79f32bbc..acb59a8be 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -16,10 +16,10 @@ * @author Yoonwoo Kim */ -#include #include -#include +#include #include +#include #include #include @@ -28,528 +28,527 @@ using namespace std; namespace gtsam { - /* ************************************************************************ */ - TableFactor::TableFactor() {} +/* ************************************************************************ */ +TableFactor::TableFactor() {} - /* ************************************************************************ */ - TableFactor::TableFactor(const DiscreteKeys& dkeys, - const TableFactor& potentials) - : DiscreteFactor(dkeys.indices()), - cardinalities_(potentials .cardinalities_) { +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const TableFactor& potentials) + : DiscreteFactor(dkeys.indices()), + cardinalities_(potentials.cardinalities_) { sparse_table_ = potentials.sparse_table_; denominators_ = potentials.denominators_; sorted_dkeys_ = discreteKeys(); sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); - } +} - /* ************************************************************************ */ - TableFactor::TableFactor(const DiscreteKeys& dkeys, - const Eigen::SparseVector& table) - : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { - sparse_table_ = table; - double denom = table.size(); - for (const DiscreteKey& dkey : dkeys) { - cardinalities_.insert(dkey); - denom /= dkey.second; - denominators_.insert(std::pair(dkey.first, denom)); - } - sorted_dkeys_ = discreteKeys(); - sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); +/* ************************************************************************ */ +TableFactor::TableFactor(const DiscreteKeys& dkeys, + const Eigen::SparseVector& table) + : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { + sparse_table_ = table; + double denom = table.size(); + for (const DiscreteKey& dkey : dkeys) { + cardinalities_.insert(dkey); + denom /= dkey.second; + denominators_.insert(std::pair(dkey.first, denom)); } + sorted_dkeys_ = discreteKeys(); + sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); +} - /* ************************************************************************ */ - Eigen::SparseVector TableFactor::Convert( +/* ************************************************************************ */ +Eigen::SparseVector TableFactor::Convert( const std::vector& table) { - Eigen::SparseVector sparse_table(table.size()); - // Count number of nonzero elements in table and reserving the space. - const uint64_t nnz = std::count_if(table.begin(), table.end(), - [](uint64_t i) { return i != 0; }); - sparse_table.reserve(nnz); - for (uint64_t i = 0; i < table.size(); i++) { - if (table[i] != 0) sparse_table.insert(i) = table[i]; + Eigen::SparseVector sparse_table(table.size()); + // Count number of nonzero elements in table and reserving the space. + const uint64_t nnz = std::count_if(table.begin(), table.end(), + [](uint64_t i) { return i != 0; }); + sparse_table.reserve(nnz); + for (uint64_t i = 0; i < table.size(); i++) { + if (table[i] != 0) sparse_table.insert(i) = table[i]; + } + sparse_table.pruned(); + sparse_table.data().squeeze(); + return sparse_table; +} + +/* ************************************************************************ */ +Eigen::SparseVector TableFactor::Convert(const std::string& table) { + // Convert string to doubles. + std::vector ys; + std::istringstream iss(table); + std::copy(std::istream_iterator(iss), std::istream_iterator(), + std::back_inserter(ys)); + return Convert(ys); +} + +/* ************************************************************************ */ +bool TableFactor::equals(const DiscreteFactor& other, double tol) const { + if (!dynamic_cast(&other)) { + return false; + } else { + const auto& f(static_cast(other)); + return sparse_table_.isApprox(f.sparse_table_, tol); + } +} + +/* ************************************************************************ */ +double TableFactor::operator()(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { + if (values.find(it->first) != values.end()) { + idx += card * values.at(it->first); } - sparse_table.pruned(); - sparse_table.data().squeeze(); - return sparse_table; + card *= it->second; } + return sparse_table_.coeff(idx); +} - /* ************************************************************************ */ - Eigen::SparseVector TableFactor::Convert(const std::string& table) { - // Convert string to doubles. - std::vector ys; - std::istringstream iss(table); - std::copy(std::istream_iterator(iss), std::istream_iterator(), - std::back_inserter(ys)); - return Convert(ys); - } - - /* ************************************************************************ */ - bool TableFactor::equals(const DiscreteFactor& other, - double tol) const { - if (!dynamic_cast(&other)) { - return false; - } else { - const auto& f(static_cast(other)); - return sparse_table_.isApprox(f.sparse_table_, tol); +/* ************************************************************************ */ +double TableFactor::findValue(const DiscreteValues& values) const { + // a b c d => D * (C * (B * (a) + b) + c) + d + uint64_t idx = 0, card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (values.find(*it) != values.end()) { + idx += card * values.at(*it); } + card *= cardinality(*it); } + return sparse_table_.coeff(idx); +} - /* ************************************************************************ */ - double TableFactor::operator()(const DiscreteValues& values) const { - // a b c d => D * (C * (B * (a) + b) + c) + d - uint64_t idx = 0, card = 1; - for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { - if (values.find(it->first) != values.end()) { - idx += card * values.at(it->first); - } - card *= it->second; - } - return sparse_table_.coeff(idx); +/* ************************************************************************ */ +double TableFactor::error(const DiscreteValues& values) const { + return -log(evaluate(values)); +} +/* ************************************************************************ */ +double TableFactor::error(const HybridValues& values) const { + return error(values.discrete()); +} + +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { + return toDecisionTreeFactor() * f; +} + +/* ************************************************************************ */ +DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { + DiscreteKeys dkeys = discreteKeys(); + std::vector table; + for (auto i = 0; i < sparse_table_.size(); i++) { + table.push_back(sparse_table_.coeff(i)); } + DecisionTreeFactor f(dkeys, table); + return f; +} - /* ************************************************************************ */ - double TableFactor::findValue(const DiscreteValues& values) const { - // a b c d => D * (C * (B * (a) + b) + c) + d - uint64_t idx = 0, card = 1; - for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { - if (values.find(*it) != values.end()) { - idx += card * values.at(*it); - } +/* ************************************************************************ */ +TableFactor TableFactor::choose(const DiscreteValues parent_assign, + DiscreteKeys parent_keys) const { + if (parent_keys.empty()) return *this; + + // Unique representation of parent values. + uint64_t unique = 0; + uint64_t card = 1; + for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { + if (parent_assign.find(*it) != parent_assign.end()) { + unique += parent_assign.at(*it) * card; card *= cardinality(*it); } - return sparse_table_.coeff(idx); } - /* ************************************************************************ */ - double TableFactor::error(const DiscreteValues& values) const { - return -log(evaluate(values)); - } - - /* ************************************************************************ */ - double TableFactor::error(const HybridValues& values) const { - return error(values.discrete()); - } + // Find child DiscreteKeys + DiscreteKeys child_dkeys; + std::sort(parent_keys.begin(), parent_keys.end()); + std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + parent_keys.begin(), parent_keys.end(), + std::back_inserter(child_dkeys)); - /* ************************************************************************ */ - DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { - return toDecisionTreeFactor() * f; - } + // Create child sparse table to populate. + uint64_t child_card = 1; + for (const DiscreteKey& child_dkey : child_dkeys) + child_card *= child_dkey.second; + Eigen::SparseVector child_sparse_table_(child_card); + child_sparse_table_.reserve(child_card); - /* ************************************************************************ */ - DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { - DiscreteKeys dkeys = discreteKeys(); - std::vector table; - for (auto i = 0; i < sparse_table_.size(); i++) { - table.push_back(sparse_table_.coeff(i)); + // Populate child sparse table. + for (SparseIt it(sparse_table_); it; ++it) { + // Create unique representation of parent keys + uint64_t parent_unique = uniqueRep(parent_keys, it.index()); + // Populate the table + if (parent_unique == unique) { + uint64_t idx = uniqueRep(child_dkeys, it.index()); + child_sparse_table_.insert(idx) = it.value(); } - DecisionTreeFactor f(dkeys, table); + } + + child_sparse_table_.pruned(); + child_sparse_table_.data().squeeze(); + return TableFactor(child_dkeys, child_sparse_table_); +} + +/* ************************************************************************ */ +double TableFactor::safe_div(const double& a, const double& b) { + // The use for safe_div is when we divide the product factor by the sum + // factor. If the product or sum is zero, we accord zero probability to the + // event. + return (a == 0 || b == 0) ? 0 : (a / b); +} + +/* ************************************************************************ */ +void TableFactor::print(const string& s, const KeyFormatter& formatter) const { + cout << s; + cout << " f["; + for (auto&& key : keys()) + cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); + cout << " ]" << endl; + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); + for (auto&& kv : assignment) { + cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; + } + cout << " | " << it.value() << " | " << it.index() << endl; + } + cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; +} + +/* ************************************************************************ */ +TableFactor TableFactor::apply(const TableFactor& f, Binary op) const { + if (keys_.empty() && sparse_table_.nonZeros() == 0) return f; - } - - /* ************************************************************************ */ - TableFactor TableFactor::choose(const DiscreteValues parent_assign, - DiscreteKeys parent_keys) const { - if (parent_keys.empty()) return *this; - - // Unique representation of parent values. - uint64_t unique = 0; - uint64_t card = 1; - for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { - if (parent_assign.find(*it) != parent_assign.end()) { - unique += parent_assign.at(*it) * card; - card *= cardinality(*it); - } - } - - // Find child DiscreteKeys - DiscreteKeys child_dkeys; - std::sort(parent_keys.begin(), parent_keys.end()); - std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(), - parent_keys.end(), std::back_inserter(child_dkeys)); - - // Create child sparse table to populate. - uint64_t child_card = 1; - for (const DiscreteKey& child_dkey : child_dkeys) - child_card *= child_dkey.second; - Eigen::SparseVector child_sparse_table_(child_card); - child_sparse_table_.reserve(child_card); - - // Populate child sparse table. - for (SparseIt it(sparse_table_); it; ++it) { - // Create unique representation of parent keys - uint64_t parent_unique = uniqueRep(parent_keys, it.index()); - // Populate the table - if (parent_unique == unique) { - uint64_t idx = uniqueRep(child_dkeys, it.index()); - child_sparse_table_.insert(idx) = it.value(); - } - } - - child_sparse_table_.pruned(); - child_sparse_table_.data().squeeze(); - return TableFactor(child_dkeys, child_sparse_table_); - } - - /* ************************************************************************ */ - double TableFactor::safe_div(const double& a, const double& b) { - // The use for safe_div is when we divide the product factor by the sum - // factor. If the product or sum is zero, we accord zero probability to the - // event. - return (a == 0 || b == 0) ? 0 : (a / b); - } - - /* ************************************************************************ */ - void TableFactor::print(const string& s, const KeyFormatter& formatter) const { - cout << s; - cout << " f["; - for (auto&& key : keys()) - cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); - cout << " ]" << endl; - for (SparseIt it(sparse_table_); it; ++it) { - DiscreteValues assignment = findAssignments(it.index()); - for (auto&& kv : assignment) { - cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; - } - cout << " | " << it.value() << " | " << it.index() << endl; - } - cout << "number of nnzs: " < map_f = + else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0) + return *this; + // 1. Identify keys for contract and free modes. + DiscreteKeys contract_dkeys = contractDkeys(f); + DiscreteKeys f_free_dkeys = f.freeDkeys(*this); + DiscreteKeys union_dkeys = unionDkeys(f); + // 2. Create hash table for input factor f + unordered_map map_f = f.createMap(contract_dkeys, f_free_dkeys); - // 3. Initialize multiplied factor. - uint64_t card = 1; - for (auto u_dkey : union_dkeys) card *= u_dkey.second; - Eigen::SparseVector mult_sparse_table(card); - mult_sparse_table.reserve(card); - // 3. Multiply. - for (SparseIt it(sparse_table_); it; ++it) { - uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); - if (map_f.find(contract_unique) == map_f.end()) continue; - for (auto assignVal : map_f[contract_unique]) { - uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); - mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); - } + // 3. Initialize multiplied factor. + uint64_t card = 1; + for (auto u_dkey : union_dkeys) card *= u_dkey.second; + Eigen::SparseVector mult_sparse_table(card); + mult_sparse_table.reserve(card); + // 3. Multiply. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); + if (map_f.find(contract_unique) == map_f.end()) continue; + for (auto assignVal : map_f[contract_unique]) { + uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); + mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); } - // 4. Free unused memory. - mult_sparse_table.pruned(); - mult_sparse_table.data().squeeze(); - // 5. Create union keys and return. - return TableFactor(union_dkeys, mult_sparse_table); } + // 4. Free unused memory. + mult_sparse_table.pruned(); + mult_sparse_table.data().squeeze(); + // 5. Create union keys and return. + return TableFactor(union_dkeys, mult_sparse_table); +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { - // Find contract modes. - DiscreteKeys contract; - set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), - f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), - back_inserter(contract)); - return contract; - } +/* ************************************************************************ */ +DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { + // Find contract modes. + DiscreteKeys contract; + set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(contract)); + return contract; +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { - // Find free modes. - DiscreteKeys free; - set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), - f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), - back_inserter(free)); - return free; - } +/* ************************************************************************ */ +DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { + // Find free modes. + DiscreteKeys free; + set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), + f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), + back_inserter(free)); + return free; +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { - // Find union modes. - DiscreteKeys union_dkeys; - set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), - f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), - back_inserter(union_dkeys)); - return union_dkeys; - } +/* ************************************************************************ */ +DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { + // Find union modes. + DiscreteKeys union_dkeys; + set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(), + f.sorted_dkeys_.end(), back_inserter(union_dkeys)); + return union_dkeys; +} - /* ************************************************************************ */ - uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, - const DiscreteValues& f_free, const uint64_t idx) const { - uint64_t union_idx = 0, card = 1; - for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { - if (f_free.find(it->first) == f_free.end()) { - union_idx += keyValueForIndex(it->first, idx) * card; - } else { - union_idx += f_free.at(it->first) * card; - } - card *= it->second; +/* ************************************************************************ */ +uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, + const DiscreteValues& f_free, + const uint64_t idx) const { + uint64_t union_idx = 0, card = 1; + for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { + if (f_free.find(it->first) == f_free.end()) { + union_idx += keyValueForIndex(it->first, idx) * card; + } else { + union_idx += f_free.at(it->first) * card; } - return union_idx; + card *= it->second; } + return union_idx; +} - /* ************************************************************************ */ - unordered_map TableFactor::createMap( +/* ************************************************************************ */ +unordered_map TableFactor::createMap( const DiscreteKeys& contract, const DiscreteKeys& free) const { - // 1. Initialize map. - unordered_map map_f; - // 2. Iterate over nonzero elements. - for (SparseIt it(sparse_table_); it; ++it) { - // 3. Create unique representation of contract modes. - uint64_t unique_rep = uniqueRep(contract, it.index()); - // 4. Create assignment for free modes. - DiscreteValues free_assignments; - for (auto& key : free) free_assignments[key.first] - = keyValueForIndex(key.first, it.index()); - // 5. Populate map. - if (map_f.find(unique_rep) == map_f.end()) { - map_f[unique_rep] = {make_pair(free_assignments, it.value())}; - } else { - map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); - } + // 1. Initialize map. + unordered_map map_f; + // 2. Iterate over nonzero elements. + for (SparseIt it(sparse_table_); it; ++it) { + // 3. Create unique representation of contract modes. + uint64_t unique_rep = uniqueRep(contract, it.index()); + // 4. Create assignment for free modes. + DiscreteValues free_assignments; + for (auto& key : free) + free_assignments[key.first] = keyValueForIndex(key.first, it.index()); + // 5. Populate map. + if (map_f.find(unique_rep) == map_f.end()) { + map_f[unique_rep] = {make_pair(free_assignments, it.value())}; + } else { + map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); } - return map_f; } + return map_f; +} - /* ************************************************************************ */ - uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const { - if (dkeys.empty()) return 0; - uint64_t unique_rep = 0, card = 1; - for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { - unique_rep += keyValueForIndex(it->first, idx) * card; - card *= it->second; - } - return unique_rep; +/* ************************************************************************ */ +uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, + const uint64_t idx) const { + if (dkeys.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { + unique_rep += keyValueForIndex(it->first, idx) * card; + card *= it->second; } + return unique_rep; +} - /* ************************************************************************ */ - uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { - if (assignments.empty()) return 0; - uint64_t unique_rep = 0, card = 1; - for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { - unique_rep += it->second * card; - card *= cardinalities_.at(it->first); - } - return unique_rep; +/* ************************************************************************ */ +uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { + if (assignments.empty()) return 0; + uint64_t unique_rep = 0, card = 1; + for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { + unique_rep += it->second * card; + card *= cardinalities_.at(it->first); } + return unique_rep; +} - /* ************************************************************************ */ - DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { - DiscreteValues assignment; - for (Key key : keys_) { - assignment[key] = keyValueForIndex(key, idx); - } - return assignment; +/* ************************************************************************ */ +DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { + DiscreteValues assignment; + for (Key key : keys_) { + assignment[key] = keyValueForIndex(key, idx); } + return assignment; +} - /* ************************************************************************ */ - TableFactor::shared_ptr TableFactor::combine( - size_t nrFrontals, Binary op) const { - if (nrFrontals > size()) { - throw invalid_argument( - "TableFactor::combine: invalid number of frontal " - "keys " + - to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); - } - // Find remaining keys. - DiscreteKeys remain_dkeys; - uint64_t card = 1; - for (auto i = nrFrontals; i < keys_.size(); i++) { - remain_dkeys.push_back(discreteKey(i)); - card *= cardinality(keys_[i]); - } - // Create combined table. - Eigen::SparseVector combined_table(card); - combined_table.reserve(sparse_table_.nonZeros()); - // Populate combined table. - for (SparseIt it(sparse_table_); it; ++it) { - uint64_t idx = uniqueRep(remain_dkeys, it.index()); - double new_val = op(combined_table.coeff(idx), it.value()); - combined_table.coeffRef(idx) = new_val; +/* ************************************************************************ */ +TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals, + Binary op) const { + if (nrFrontals > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); + } + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (auto i = nrFrontals; i < keys_.size(); i++) { + remain_dkeys.push_back(discreteKey(i)); + card *= cardinality(keys_[i]); + } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; } // Free unused memory. combined_table.pruned(); combined_table.data().squeeze(); return std::make_shared(remain_dkeys, combined_table); - } +} - /* ************************************************************************ */ - TableFactor::shared_ptr TableFactor::combine( - const Ordering& frontalKeys, Binary op) const { - if (frontalKeys.size() > size()) { - throw invalid_argument( - "TableFactor::combine: invalid number of frontal " - "keys " + - std::to_string(frontalKeys.size()) + ", nr.keys=" + - std::to_string(size())); - } - // Find remaining keys. - DiscreteKeys remain_dkeys; - uint64_t card = 1; - for (Key key : keys_) { - if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == - frontalKeys.end()) { - remain_dkeys.emplace_back(key, cardinality(key)); - card *= cardinality(key); - } - } - // Create combined table. - Eigen::SparseVector combined_table(card); - combined_table.reserve(sparse_table_.nonZeros()); - // Populate combined table. - for (SparseIt it(sparse_table_); it; ++it) { - uint64_t idx = uniqueRep(remain_dkeys, it.index()); - double new_val = op(combined_table.coeff(idx), it.value()); - combined_table.coeffRef(idx) = new_val; - } - // Free unused memory. - combined_table.pruned(); - combined_table.data().squeeze(); - return std::make_shared(remain_dkeys, combined_table); +/* ************************************************************************ */ +TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys, + Binary op) const { + if (frontalKeys.size() > size()) { + throw invalid_argument( + "TableFactor::combine: invalid number of frontal " + "keys " + + std::to_string(frontalKeys.size()) + + ", nr.keys=" + std::to_string(size())); } - - /* ************************************************************************ */ - size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { - // http://phrogz.net/lazy-cartesian-product - return (index / denominators_.at(target_key)) % cardinality(target_key); + // Find remaining keys. + DiscreteKeys remain_dkeys; + uint64_t card = 1; + for (Key key : keys_) { + if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == + frontalKeys.end()) { + remain_dkeys.emplace_back(key, cardinality(key)); + card *= cardinality(key); + } } + // Create combined table. + Eigen::SparseVector combined_table(card); + combined_table.reserve(sparse_table_.nonZeros()); + // Populate combined table. + for (SparseIt it(sparse_table_); it; ++it) { + uint64_t idx = uniqueRep(remain_dkeys, it.index()); + double new_val = op(combined_table.coeff(idx), it.value()); + combined_table.coeffRef(idx) = new_val; + } + // Free unused memory. + combined_table.pruned(); + combined_table.data().squeeze(); + return std::make_shared(remain_dkeys, combined_table); +} - /* ************************************************************************ */ - std::vector> TableFactor::enumerate() - const { - // Get all possible assignments - std::vector> pairs = discreteKeys(); - // Reverse to make cartesian product output a more natural ordering. - std::vector> rpairs(pairs.rbegin(), pairs.rend()); - const auto assignments = DiscreteValues::CartesianProduct(rpairs); - // Construct unordered_map with values - std::vector> result; - for (const auto& assignment : assignments) { - result.emplace_back(assignment, operator()(assignment)); - } - return result; - } +/* ************************************************************************ */ +size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { + // http://phrogz.net/lazy-cartesian-product + return (index / denominators_.at(target_key)) % cardinality(target_key); +} - /* ************************************************************************ */ - DiscreteKeys TableFactor::discreteKeys() const { - DiscreteKeys result; - for (auto&& key : keys()) { - DiscreteKey dkey(key, cardinality(key)); - if (std::find(result.begin(), result.end(), dkey) == result.end()) { - result.push_back(dkey); - } - } - return result; +/* ************************************************************************ */ +std::vector> TableFactor::enumerate() const { + // Get all possible assignments + std::vector> pairs = discreteKeys(); + // Reverse to make cartesian product output a more natural ordering. + std::vector> rpairs(pairs.rbegin(), pairs.rend()); + const auto assignments = DiscreteValues::CartesianProduct(rpairs); + // Construct unordered_map with values + std::vector> result; + for (const auto& assignment : assignments) { + result.emplace_back(assignment, operator()(assignment)); } + return result; +} + +/* ************************************************************************ */ +DiscreteKeys TableFactor::discreteKeys() const { + DiscreteKeys result; + for (auto&& key : keys()) { + DiscreteKey dkey(key, cardinality(key)); + if (std::find(result.begin(), result.end(), dkey) == result.end()) { + result.push_back(dkey); + } + } + return result; +} + +// Print out header. +/* ************************************************************************ */ +string TableFactor::markdown(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; // Print out header. - /* ************************************************************************ */ - string TableFactor::markdown(const KeyFormatter& keyFormatter, - const Names& names) const { - stringstream ss; + ss << "|"; + for (auto& key : keys()) { + ss << keyFormatter(key) << "|"; + } + ss << "value|\n"; - // Print out header. + // Print out separator with alignment hints. + ss << "|"; + for (size_t j = 0; j < size(); j++) ss << ":-:|"; + ss << ":-:|\n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); ss << "|"; for (auto& key : keys()) { - ss << keyFormatter(key) << "|"; + size_t index = assignment.at(key); + ss << DiscreteValues::Translate(names, key, index) << "|"; } - ss << "value|\n"; - - // Print out separator with alignment hints. - ss << "|"; - for (size_t j = 0; j < size(); j++) ss << ":-:|"; - ss << ":-:|\n"; - - // Print out all rows. - for (SparseIt it(sparse_table_); it; ++it) { - DiscreteValues assignment = findAssignments(it.index()); - ss << "|"; - for (auto& key : keys()) { - size_t index = assignment.at(key); - ss << DiscreteValues::Translate(names, key, index) << "|"; - } - ss << it.value() << "|\n"; - } - return ss.str(); + ss << it.value() << "|\n"; } + return ss.str(); +} - /* ************************************************************************ */ - string TableFactor::html(const KeyFormatter& keyFormatter, - const Names& names) const { - stringstream ss; +/* ************************************************************************ */ +string TableFactor::html(const KeyFormatter& keyFormatter, + const Names& names) const { + stringstream ss; - // Print out preamble. - ss << "
\n\n \n"; + // Print out preamble. + ss << "
\n
\n \n"; - // Print out header row. + // Print out header row. + ss << " "; + for (auto& key : keys()) { + ss << ""; + } + ss << "\n"; + + // Finish header and start body. + ss << " \n \n"; + + // Print out all rows. + for (SparseIt it(sparse_table_); it; ++it) { + DiscreteValues assignment = findAssignments(it.index()); ss << " "; for (auto& key : keys()) { - ss << ""; + size_t index = assignment.at(key); + ss << ""; } - ss << "\n"; + ss << ""; // value + ss << "\n"; + } + ss << " \n
" << keyFormatter(key) << "value
" << keyFormatter(key) << "" << DiscreteValues::Translate(names, key, index) << "value
" << it.value() << "
\n
"; + return ss.str(); +} - // Finish header and start body. - ss << " \n \n"; - - // Print out all rows. - for (SparseIt it(sparse_table_); it; ++it) { - DiscreteValues assignment = findAssignments(it.index()); - ss << " "; - for (auto& key : keys()) { - size_t index = assignment.at(key); - ss << "" << DiscreteValues::Translate(names, key, index) << ""; - } - ss << "" << it.value() << ""; // value - ss << "\n"; - } - ss << " \n\n"; - return ss.str(); +/* ************************************************************************ */ +TableFactor TableFactor::prune(size_t maxNrAssignments) const { + const size_t N = maxNrAssignments; + + // Get the probabilities in the TableFactor so we can threshold. + vector> probabilities; + + // Store non-zero probabilities along with their indices in a vector. + for (SparseIt it(sparse_table_); it; ++it) { + probabilities.emplace_back(it.index(), it.value()); } - /* ************************************************************************ */ - TableFactor TableFactor::prune(size_t maxNrAssignments) const { - const size_t N = maxNrAssignments; + // The number of probabilities can be lower than max_leaves. + if (probabilities.size() <= N) return *this; - // Get the probabilities in the TableFactor so we can threshold. - vector> probabilities; - - // Store non-zero probabilities along with their indices in a vector. - for (SparseIt it(sparse_table_); it; ++it) { - probabilities.emplace_back(it.index(), it.value()); - } - - // The number of probabilities can be lower than max_leaves. - if (probabilities.size() <= N) return *this; - - // Sort the vector in descending order based on the element values. - sort(probabilities.begin(), probabilities.end(), [] ( - const std::pair& a, - const std::pair& b) { - return a.second > b.second; - }); - - // Keep the largest N probabilities in the vector. - if (probabilities.size() > N) probabilities.resize(N); + // Sort the vector in descending order based on the element values. + sort(probabilities.begin(), probabilities.end(), + [](const std::pair& a, + const std::pair& b) { + return a.second > b.second; + }); - // Create pruned sparse vector. - Eigen::SparseVector pruned_vec(sparse_table_.size()); - pruned_vec.reserve(probabilities.size()); + // Keep the largest N probabilities in the vector. + if (probabilities.size() > N) probabilities.resize(N); - // Populate pruned sparse vector. - for (const auto& prob : probabilities) { - pruned_vec.insert(prob.first) = prob.second; - } + // Create pruned sparse vector. + Eigen::SparseVector pruned_vec(sparse_table_.size()); + pruned_vec.reserve(probabilities.size()); - // Create pruned decision tree factor and return. - return TableFactor(this->discreteKeys(), pruned_vec); + // Populate pruned sparse vector. + for (const auto& prob : probabilities) { + pruned_vec.insert(prob.first) = prob.second; } - /* ************************************************************************ */ + // Create pruned decision tree factor and return. + return TableFactor(this->discreteKeys(), pruned_vec); +} + +/* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index c565cbe6b..d73dc1c9d 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -23,8 +23,8 @@ #include #include -#include #include +#include #include #include #include @@ -32,287 +32,296 @@ namespace gtsam { - class HybridValues; +class HybridValues; + +/** + * A discrete probabilistic factor optimized for sparsity. + * Uses sparse_table_ to store only the nonzero probabilities. + * Computes the assigned value for the key using the ordering which the + * nonzero probabilties are stored in. (lazy cartesian product) + * + * @ingroup discrete + */ +class GTSAM_EXPORT TableFactor : public DiscreteFactor { + protected: + /// Map of Keys and their cardinalities. + std::map cardinalities_; + /// SparseVector of nonzero probabilities. + Eigen::SparseVector sparse_table_; + + private: + /// Map of Keys and their denominators used in keyValueForIndex. + std::map denominators_; + /// Sorted DiscreteKeys to use internally. + DiscreteKeys sorted_dkeys_; /** - * A discrete probabilistic factor optimized for sparsity. - * Uses sparse_table_ to store only the nonzero probabilities. - * Computes the assigned value for the key using the ordering which the - * nonzero probabilties are stored in. (lazy cartesian product) - * - * @ingroup discrete + * @brief Uses lazy cartesian product to find nth entry in the cartesian + * product of arrays in O(1) + * Example) + * v0 | v1 | val + * 0 | 0 | 10 + * 0 | 1 | 21 + * 1 | 0 | 32 + * 1 | 1 | 43 + * keyValueForIndex(v1, 2) = 0 + * @param target_key nth entry's key to find out its assigned value + * @param index nth entry in the sparse vector + * @return TableFactor */ - class GTSAM_EXPORT TableFactor : public DiscreteFactor { - protected: - std::map cardinalities_; /// Map of Keys and their cardinalities. - Eigen::SparseVector sparse_table_; /// SparseVector of nonzero probabilities. - - private: - std::map denominators_; /// Map of Keys and their denominators used in keyValueForIndex. - DiscreteKeys sorted_dkeys_; /// Sorted DiscreteKeys to use internally. - - /** - * @brief Uses lazy cartesian product to find nth entry in the cartesian product of arrays in O(1) - * Example) - * v0 | v1 | val - * 0 | 0 | 10 - * 0 | 1 | 21 - * 1 | 0 | 32 - * 1 | 1 | 43 - * keyValueForIndex(v1, 2) = 0 - * @param target_key nth entry's key to find out its assigned value - * @param index nth entry in the sparse vector - * @return TableFactor - */ - size_t keyValueForIndex(Key target_key, uint64_t index) const; + size_t keyValueForIndex(Key target_key, uint64_t index) const; - DiscreteKey discreteKey(size_t i) const { - return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + /** + * @brief Return ith key in keys_ as a DiscreteKey + * @param i ith key in keys_ + * @return DiscreteKey + * */ + DiscreteKey discreteKey(size_t i) const { + return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); + } + + /// Convert probability table given as doubles to SparseVector. + static Eigen::SparseVector Convert(const std::vector& table); + + /// Convert probability table given as string to SparseVector. + static Eigen::SparseVector Convert(const std::string& table); + + public: + // typedefs needed to play nice with gtsam + typedef TableFactor This; + typedef DiscreteFactor Base; ///< Typedef to base class + typedef std::shared_ptr shared_ptr; + typedef Eigen::SparseVector::InnerIterator SparseIt; + typedef std::vector> AssignValList; + using Binary = std::function; + + public: + /** The Real ring with addition and multiplication */ + struct Ring { + static inline double zero() { return 0.0; } + static inline double one() { return 1.0; } + static inline double add(const double& a, const double& b) { return a + b; } + static inline double max(const double& a, const double& b) { + return std::max(a, b); } - - /// Convert probability table given as doubles to SparseVector. - static Eigen::SparseVector Convert(const std::vector& table); - - /// Convert probability table given as string to SparseVector. - static Eigen::SparseVector Convert(const std::string& table); - - public: - // typedefs needed to play nice with gtsam - typedef TableFactor This; - typedef DiscreteFactor Base; ///< Typedef to base class - typedef std::shared_ptr shared_ptr; - typedef Eigen::SparseVector::InnerIterator SparseIt; - typedef std::vector> AssignValList; - using Binary = std::function; - - public: - /** The Real ring with addition and multiplication */ - struct Ring { - static inline double zero() { return 0.0; } - static inline double one() { return 1.0; } - static inline double add(const double& a, const double& b) { return a + b; } - static inline double max(const double& a, const double& b) { - return std::max(a, b); - } - static inline double mul(const double& a, const double& b) { return a * b; } - static inline double div(const double& a, const double& b) { - return (a == 0 || b == 0) ? 0 : (a / b); - } - static inline double id(const double& x) { return x; } - }; - - /// @name Standard Constructors - /// @{ - - /** Default constructor for I/O */ - TableFactor(); - - /** Constructor from DiscreteKeys and TableFactor */ - TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); - - /** Constructor from sparse_table */ - TableFactor(const DiscreteKeys& keys, - const Eigen::SparseVector& table); - - /** Constructor from doubles */ - TableFactor(const DiscreteKeys& keys, const std::vector& table) - : TableFactor(keys, Convert(table)) {} - - /** Constructor from string */ - TableFactor(const DiscreteKeys& keys, const std::string& table) - : TableFactor(keys, Convert(table)) {} - - /// Single-key specialization - template - TableFactor(const DiscreteKey& key, SOURCE table) - : TableFactor(DiscreteKeys{key}, table) {} - - /// Single-key specialization, with vector of doubles. - TableFactor(const DiscreteKey& key, const std::vector& row) - : TableFactor(DiscreteKeys{key}, row) {} - - - /// @} - /// @name Testable - /// @{ - - /// equality - bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; - - // print - void print( - const std::string& s = "TableFactor:\n", - const KeyFormatter& formatter = DefaultKeyFormatter) const override; - - // /// @} - // /// @name Standard Interface - // /// @{ - - /// Calculate probability for given values `x`, - /// is just look up in TableFactor. - double evaluate(const DiscreteValues& values) const { - return operator()(values); + static inline double mul(const double& a, const double& b) { return a * b; } + static inline double div(const double& a, const double& b) { + return (a == 0 || b == 0) ? 0 : (a / b); } + static inline double id(const double& x) { return x; } + }; - /// Evaluate probability distribution, sugar. - double operator()(const DiscreteValues& values) const override; + /// @name Standard Constructors + /// @{ - /// Calculate error for DiscreteValues `x`, is -log(probability). - double error(const DiscreteValues& values) const; + /** Default constructor for I/O */ + TableFactor(); - /// multiply two TableFactors - TableFactor operator*(const TableFactor& f) const { - return apply(f, Ring::mul); - }; + /** Constructor from DiscreteKeys and TableFactor */ + TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); - /// multiple with DecisionTreeFactor - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /** Constructor from sparse_table */ + TableFactor(const DiscreteKeys& keys, + const Eigen::SparseVector& table); - static double safe_div(const double& a, const double& b); + /** Constructor from doubles */ + TableFactor(const DiscreteKeys& keys, const std::vector& table) + : TableFactor(keys, Convert(table)) {} - size_t cardinality(Key j) const { return cardinalities_.at(j); } + /** Constructor from string */ + TableFactor(const DiscreteKeys& keys, const std::string& table) + : TableFactor(keys, Convert(table)) {} - /// divide by factor f (safely) - TableFactor operator/(const TableFactor& f) const { - return apply(f, safe_div); - } + /// Single-key specialization + template + TableFactor(const DiscreteKey& key, SOURCE table) + : TableFactor(DiscreteKeys{key}, table) {} - /// Convert into a decisiontree - DecisionTreeFactor toDecisionTreeFactor() const override; + /// Single-key specialization, with vector of doubles. + TableFactor(const DiscreteKey& key, const std::vector& row) + : TableFactor(DiscreteKeys{key}, row) {} - /// Generate TableFactor from TableFactor - // TableFactor toTableFactor() const override { return *this; } + /// @} + /// @name Testable + /// @{ - /// Create a TableFactor that is a subset of this TableFactor - TableFactor choose(const DiscreteValues assignments, - DiscreteKeys parent_keys) const; + /// equality + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; - /// Create new factor by summing all values with the same separator values - shared_ptr sum(size_t nrFrontals) const { - return combine(nrFrontals, Ring::add); - } + // print + void print( + const std::string& s = "TableFactor:\n", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; - /// Create new factor by summing all values with the same separator values - shared_ptr sum(const Ordering& keys) const { - return combine(keys, Ring::add); - } + // /// @} + // /// @name Standard Interface + // /// @{ - /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(size_t nrFrontals) const { - return combine(nrFrontals, Ring::max); - } + /// Calculate probability for given values `x`, + /// is just look up in TableFactor. + double evaluate(const DiscreteValues& values) const { + return operator()(values); + } - /// Create new factor by maximizing over all values with the same separator. - shared_ptr max(const Ordering& keys) const { - return combine(keys, Ring::max); - } + /// Evaluate probability distribution, sugar. + double operator()(const DiscreteValues& values) const override; - /// @} - /// @name Advanced Interface - /// @{ + /// Calculate error for DiscreteValues `x`, is -log(probability). + double error(const DiscreteValues& values) const; - /** - * Apply binary operator (*this) "op" f - * @param f the second argument for op - * @param op a binary operator that operates on TableFactor - */ - TableFactor apply(const TableFactor& f, Binary op) const; + /// multiply two TableFactors + TableFactor operator*(const TableFactor& f) const { + return apply(f, Ring::mul); + }; - /// Return keys in contract mode. - DiscreteKeys contractDkeys(const TableFactor& f) const; - - /// Return keys in free mode. - DiscreteKeys freeDkeys(const TableFactor& f) const; + /// multiple with DecisionTreeFactor + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; - /// Return union of DiscreteKeys in two factors. - DiscreteKeys unionDkeys(const TableFactor& f) const; + static double safe_div(const double& a, const double& b); - /// Create unique representation of union modes. - uint64_t unionRep(const DiscreteKeys& keys, - const DiscreteValues& assign, const uint64_t idx) const; - - /// Create a hash map of input factor with assignment of contract modes as - /// keys and vector of hashed assignment of free modes and value as values. - std::unordered_map createMap( + size_t cardinality(Key j) const { return cardinalities_.at(j); } + + /// divide by factor f (safely) + TableFactor operator/(const TableFactor& f) const { + return apply(f, safe_div); + } + + /// Convert into a decisiontree + DecisionTreeFactor toDecisionTreeFactor() const override; + + /// Generate TableFactor from TableFactor + // TableFactor toTableFactor() const override { return *this; } + + /// Create a TableFactor that is a subset of this TableFactor + TableFactor choose(const DiscreteValues assignments, + DiscreteKeys parent_keys) const; + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(size_t nrFrontals) const { + return combine(nrFrontals, Ring::add); + } + + /// Create new factor by summing all values with the same separator values + shared_ptr sum(const Ordering& keys) const { + return combine(keys, Ring::add); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(size_t nrFrontals) const { + return combine(nrFrontals, Ring::max); + } + + /// Create new factor by maximizing over all values with the same separator. + shared_ptr max(const Ordering& keys) const { + return combine(keys, Ring::max); + } + + /// @} + /// @name Advanced Interface + /// @{ + + /** + * Apply binary operator (*this) "op" f + * @param f the second argument for op + * @param op a binary operator that operates on TableFactor + */ + TableFactor apply(const TableFactor& f, Binary op) const; + + /// Return keys in contract mode. + DiscreteKeys contractDkeys(const TableFactor& f) const; + + /// Return keys in free mode. + DiscreteKeys freeDkeys(const TableFactor& f) const; + + /// Return union of DiscreteKeys in two factors. + DiscreteKeys unionDkeys(const TableFactor& f) const; + + /// Create unique representation of union modes. + uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign, + const uint64_t idx) const; + + /// Create a hash map of input factor with assignment of contract modes as + /// keys and vector of hashed assignment of free modes and value as values. + std::unordered_map createMap( const DiscreteKeys& contract, const DiscreteKeys& free) const; - /// Create unique representation - uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; - - /// Create unique representation with DiscreteValues - uint64_t uniqueRep(const DiscreteValues& assignments) const; + /// Create unique representation + uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; - /// Find DiscreteValues for corresponding index. - DiscreteValues findAssignments(const uint64_t idx) const; - - /// Find value for corresponding DiscreteValues. - double findValue(const DiscreteValues& values) const; + /// Create unique representation with DiscreteValues + uint64_t uniqueRep(const DiscreteValues& assignments) const; - /** - * Combine frontal variables using binary operator "op" - * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on TableFactor - * @return shared pointer to newly created TableFactor - */ - shared_ptr combine(size_t nrFrontals, Binary op) const; + /// Find DiscreteValues for corresponding index. + DiscreteValues findAssignments(const uint64_t idx) const; - /** - * Combine frontal variables in an Ordering using binary operator "op" - * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on TableFactor - * @return shared pointer to newly created TableFactor - */ - shared_ptr combine(const Ordering& keys, Binary op) const; + /// Find value for corresponding DiscreteValues. + double findValue(const DiscreteValues& values) const; - /// Enumerate all values into a map from values to double. - std::vector> enumerate() const; + /** + * Combine frontal variables using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(size_t nrFrontals, Binary op) const; - /// Return all the discrete keys associated with this factor. - DiscreteKeys discreteKeys() const; + /** + * Combine frontal variables in an Ordering using binary operator "op" + * @param nrFrontals nr. of frontal to combine variables in this factor + * @param op a binary operator that operates on TableFactor + * @return shared pointer to newly created TableFactor + */ + shared_ptr combine(const Ordering& keys, Binary op) const; - /** - * @brief Prune the decision tree of discrete variables. - * - * Pruning will set the values to be "pruned" to 0 indicating a 0 - * probability. An assignment is pruned if it is not in the top - * `maxNrAssignments` values. - * - * A violation can occur if there are more - * duplicate values than `maxNrAssignments`. A violation here is the need to - * un-prune the decision tree (e.g. all assignment values are 1.0). We could - * have another case where some subset of duplicates exist (e.g. for a tree - * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is - * not a violation since the for `maxNrAssignments=5` the top values are (1, - * 0.8). - * - * @param maxNrAssignments The maximum number of assignments to keep. - * @return TableFactor - */ - TableFactor prune(size_t maxNrAssignments) const; + /// Enumerate all values into a map from values to double. + std::vector> enumerate() const; - /// @} - /// @name Wrapper support - /// @{ + /// Return all the discrete keys associated with this factor. + DiscreteKeys discreteKeys() const; - /** - * @brief Render as markdown table - * - * @param keyFormatter GTSAM-style Key formatter. - * @param names optional, category names corresponding to choices. - * @return std::string a markdown string. - */ - std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the values to be "pruned" to 0 indicating a 0 + * probability. An assignment is pruned if it is not in the top + * `maxNrAssignments` values. + * + * A violation can occur if there are more + * duplicate values than `maxNrAssignments`. A violation here is the need to + * un-prune the decision tree (e.g. all assignment values are 1.0). We could + * have another case where some subset of duplicates exist (e.g. for a tree + * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is + * not a violation since the for `maxNrAssignments=5` the top values are (1, + * 0.8). + * + * @param maxNrAssignments The maximum number of assignments to keep. + * @return TableFactor + */ + TableFactor prune(size_t maxNrAssignments) const; - /** - * @brief Render as html table - * - * @param keyFormatter GTSAM-style Key formatter. - * @param names optional, category names corresponding to choices. - * @return std::string a html string. - */ - std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, - const Names& names = {}) const override; + /// @} + /// @name Wrapper support + /// @{ + + /** + * @brief Render as markdown table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a markdown string. + */ + std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; + + /** + * @brief Render as html table + * + * @param keyFormatter GTSAM-style Key formatter. + * @param names optional, category names corresponding to choices. + * @return std::string a html string. + */ + std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, + const Names& names = {}) const override; /// @} /// @name HybridValues methods. @@ -325,7 +334,7 @@ namespace gtsam { double error(const HybridValues& values) const override; /// @} - }; +}; // traits template <>