added doc for disceteKey in .h file, formatted in Google style.
parent
361f9fa391
commit
7b3ce2fe34
|
|
@ -16,10 +16,10 @@
|
||||||
* @author Yoonwoo Kim
|
* @author Yoonwoo Kim
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
|
||||||
#include <gtsam/base/FastSet.h>
|
#include <gtsam/base/FastSet.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/TableFactor.h>
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <boost/format.hpp>
|
#include <boost/format.hpp>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
@ -84,8 +84,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
bool TableFactor::equals(const DiscreteFactor& other,
|
bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
|
||||||
double tol) const {
|
|
||||||
if (!dynamic_cast<const TableFactor*>(&other)) {
|
if (!dynamic_cast<const TableFactor*>(&other)) {
|
||||||
return false;
|
return false;
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -105,7 +104,6 @@ namespace gtsam {
|
||||||
card *= it->second;
|
card *= it->second;
|
||||||
}
|
}
|
||||||
return sparse_table_.coeff(idx);
|
return sparse_table_.coeff(idx);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
@ -165,8 +163,9 @@ namespace gtsam {
|
||||||
// Find child DiscreteKeys
|
// Find child DiscreteKeys
|
||||||
DiscreteKeys child_dkeys;
|
DiscreteKeys child_dkeys;
|
||||||
std::sort(parent_keys.begin(), parent_keys.end());
|
std::sort(parent_keys.begin(), parent_keys.end());
|
||||||
std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), parent_keys.begin(),
|
std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
||||||
parent_keys.end(), std::back_inserter(child_dkeys));
|
parent_keys.begin(), parent_keys.end(),
|
||||||
|
std::back_inserter(child_dkeys));
|
||||||
|
|
||||||
// Create child sparse table to populate.
|
// Create child sparse table to populate.
|
||||||
uint64_t child_card = 1;
|
uint64_t child_card = 1;
|
||||||
|
|
@ -274,15 +273,15 @@ namespace gtsam {
|
||||||
DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const {
|
DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const {
|
||||||
// Find union modes.
|
// Find union modes.
|
||||||
DiscreteKeys union_dkeys;
|
DiscreteKeys union_dkeys;
|
||||||
set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(),
|
set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(),
|
||||||
f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
|
f.sorted_dkeys_.end(), back_inserter(union_dkeys));
|
||||||
back_inserter(union_dkeys));
|
|
||||||
return union_dkeys;
|
return union_dkeys;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys,
|
uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys,
|
||||||
const DiscreteValues& f_free, const uint64_t idx) const {
|
const DiscreteValues& f_free,
|
||||||
|
const uint64_t idx) const {
|
||||||
uint64_t union_idx = 0, card = 1;
|
uint64_t union_idx = 0, card = 1;
|
||||||
for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
|
for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
|
||||||
if (f_free.find(it->first) == f_free.end()) {
|
if (f_free.find(it->first) == f_free.end()) {
|
||||||
|
|
@ -306,8 +305,8 @@ namespace gtsam {
|
||||||
uint64_t unique_rep = uniqueRep(contract, it.index());
|
uint64_t unique_rep = uniqueRep(contract, it.index());
|
||||||
// 4. Create assignment for free modes.
|
// 4. Create assignment for free modes.
|
||||||
DiscreteValues free_assignments;
|
DiscreteValues free_assignments;
|
||||||
for (auto& key : free) free_assignments[key.first]
|
for (auto& key : free)
|
||||||
= keyValueForIndex(key.first, it.index());
|
free_assignments[key.first] = keyValueForIndex(key.first, it.index());
|
||||||
// 5. Populate map.
|
// 5. Populate map.
|
||||||
if (map_f.find(unique_rep) == map_f.end()) {
|
if (map_f.find(unique_rep) == map_f.end()) {
|
||||||
map_f[unique_rep] = {make_pair(free_assignments, it.value())};
|
map_f[unique_rep] = {make_pair(free_assignments, it.value())};
|
||||||
|
|
@ -319,7 +318,8 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, const uint64_t idx) const {
|
uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys,
|
||||||
|
const uint64_t idx) const {
|
||||||
if (dkeys.empty()) return 0;
|
if (dkeys.empty()) return 0;
|
||||||
uint64_t unique_rep = 0, card = 1;
|
uint64_t unique_rep = 0, card = 1;
|
||||||
for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
|
for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
|
||||||
|
|
@ -350,8 +350,8 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::shared_ptr TableFactor::combine(
|
TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals,
|
||||||
size_t nrFrontals, Binary op) const {
|
Binary op) const {
|
||||||
if (nrFrontals > size()) {
|
if (nrFrontals > size()) {
|
||||||
throw invalid_argument(
|
throw invalid_argument(
|
||||||
"TableFactor::combine: invalid number of frontal "
|
"TableFactor::combine: invalid number of frontal "
|
||||||
|
|
@ -381,14 +381,14 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
TableFactor::shared_ptr TableFactor::combine(
|
TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys,
|
||||||
const Ordering& frontalKeys, Binary op) const {
|
Binary op) const {
|
||||||
if (frontalKeys.size() > size()) {
|
if (frontalKeys.size() > size()) {
|
||||||
throw invalid_argument(
|
throw invalid_argument(
|
||||||
"TableFactor::combine: invalid number of frontal "
|
"TableFactor::combine: invalid number of frontal "
|
||||||
"keys " +
|
"keys " +
|
||||||
std::to_string(frontalKeys.size()) + ", nr.keys=" +
|
std::to_string(frontalKeys.size()) +
|
||||||
std::to_string(size()));
|
", nr.keys=" + std::to_string(size()));
|
||||||
}
|
}
|
||||||
// Find remaining keys.
|
// Find remaining keys.
|
||||||
DiscreteKeys remain_dkeys;
|
DiscreteKeys remain_dkeys;
|
||||||
|
|
@ -422,8 +422,7 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate()
|
std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
|
||||||
const {
|
|
||||||
// Get all possible assignments
|
// Get all possible assignments
|
||||||
std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
|
std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
|
||||||
// Reverse to make cartesian product output a more natural ordering.
|
// Reverse to make cartesian product output a more natural ordering.
|
||||||
|
|
@ -529,8 +528,8 @@ namespace gtsam {
|
||||||
if (probabilities.size() <= N) return *this;
|
if (probabilities.size() <= N) return *this;
|
||||||
|
|
||||||
// Sort the vector in descending order based on the element values.
|
// Sort the vector in descending order based on the element values.
|
||||||
sort(probabilities.begin(), probabilities.end(), [] (
|
sort(probabilities.begin(), probabilities.end(),
|
||||||
const std::pair<Eigen::Index, double>& a,
|
[](const std::pair<Eigen::Index, double>& a,
|
||||||
const std::pair<Eigen::Index, double>& b) {
|
const std::pair<Eigen::Index, double>& b) {
|
||||||
return a.second > b.second;
|
return a.second > b.second;
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -23,8 +23,8 @@
|
||||||
|
|
||||||
#include <Eigen/Sparse>
|
#include <Eigen/Sparse>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
@ -44,15 +44,20 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
|
||||||
protected:
|
protected:
|
||||||
std::map<Key, size_t> cardinalities_; /// Map of Keys and their cardinalities.
|
/// Map of Keys and their cardinalities.
|
||||||
Eigen::SparseVector<double> sparse_table_; /// SparseVector of nonzero probabilities.
|
std::map<Key, size_t> cardinalities_;
|
||||||
|
/// SparseVector of nonzero probabilities.
|
||||||
|
Eigen::SparseVector<double> sparse_table_;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::map<Key, size_t> denominators_; /// Map of Keys and their denominators used in keyValueForIndex.
|
/// Map of Keys and their denominators used in keyValueForIndex.
|
||||||
DiscreteKeys sorted_dkeys_; /// Sorted DiscreteKeys to use internally.
|
std::map<Key, size_t> denominators_;
|
||||||
|
/// Sorted DiscreteKeys to use internally.
|
||||||
|
DiscreteKeys sorted_dkeys_;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Uses lazy cartesian product to find nth entry in the cartesian product of arrays in O(1)
|
* @brief Uses lazy cartesian product to find nth entry in the cartesian
|
||||||
|
* product of arrays in O(1)
|
||||||
* Example)
|
* Example)
|
||||||
* v0 | v1 | val
|
* v0 | v1 | val
|
||||||
* 0 | 0 | 10
|
* 0 | 0 | 10
|
||||||
|
|
@ -66,6 +71,11 @@ namespace gtsam {
|
||||||
*/
|
*/
|
||||||
size_t keyValueForIndex(Key target_key, uint64_t index) const;
|
size_t keyValueForIndex(Key target_key, uint64_t index) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return ith key in keys_ as a DiscreteKey
|
||||||
|
* @param i ith key in keys_
|
||||||
|
* @return DiscreteKey
|
||||||
|
* */
|
||||||
DiscreteKey discreteKey(size_t i) const {
|
DiscreteKey discreteKey(size_t i) const {
|
||||||
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
|
||||||
}
|
}
|
||||||
|
|
@ -131,7 +141,6 @@ namespace gtsam {
|
||||||
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
TableFactor(const DiscreteKey& key, const std::vector<double>& row)
|
||||||
: TableFactor(DiscreteKeys{key}, row) {}
|
: TableFactor(DiscreteKeys{key}, row) {}
|
||||||
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -228,8 +237,8 @@ namespace gtsam {
|
||||||
DiscreteKeys unionDkeys(const TableFactor& f) const;
|
DiscreteKeys unionDkeys(const TableFactor& f) const;
|
||||||
|
|
||||||
/// Create unique representation of union modes.
|
/// Create unique representation of union modes.
|
||||||
uint64_t unionRep(const DiscreteKeys& keys,
|
uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign,
|
||||||
const DiscreteValues& assign, const uint64_t idx) const;
|
const uint64_t idx) const;
|
||||||
|
|
||||||
/// Create a hash map of input factor with assignment of contract modes as
|
/// Create a hash map of input factor with assignment of contract modes as
|
||||||
/// keys and vector of hashed assignment of free modes and value as values.
|
/// keys and vector of hashed assignment of free modes and value as values.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue