Merge branch 'develop' into decisiontree-improvements
						commit
						d74e41a1c3
					
				|  | @ -19,7 +19,8 @@ option(GTSAM_FORCE_STATIC_LIB               "Force gtsam to be a static library, | |||
| option(GTSAM_USE_QUATERNIONS                "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF) | ||||
| option(GTSAM_POSE3_EXPMAP                   "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON) | ||||
| option(GTSAM_ROT3_EXPMAP                    "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON) | ||||
| option(GTSAM_ENABLE_CONSISTENCY_CHECKS      "Enable/Disable expensive consistency checks"       OFF) | ||||
| option(GTSAM_ENABLE_CONSISTENCY_CHECKS      "Enable/Disable expensive consistency checks" OFF) | ||||
| option(GTSAM_ENABLE_MEMORY_SANITIZER        "Enable/Disable memory sanitizer" OFF) | ||||
| option(GTSAM_WITH_TBB                       "Use Intel Threaded Building Blocks (TBB) if available" ON) | ||||
| option(GTSAM_WITH_EIGEN_MKL                 "Eigen will use Intel MKL if available" OFF) | ||||
| option(GTSAM_WITH_EIGEN_MKL_OPENMP          "Eigen, when using Intel MKL, will also use OpenMP for multithreading if available" OFF) | ||||
|  |  | |||
|  | @ -50,3 +50,10 @@ if(GTSAM_ENABLE_CONSISTENCY_CHECKS) | |||
|   # This should be made PUBLIC if GTSAM_EXTRA_CONSISTENCY_CHECKS is someday used in a public .h | ||||
|   list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE GTSAM_EXTRA_CONSISTENCY_CHECKS) | ||||
| endif() | ||||
| 
 | ||||
| if(GTSAM_ENABLE_MEMORY_SANITIZER) | ||||
|   set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address  -fsanitize=leak -g") | ||||
|   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address  -fsanitize=leak -g") | ||||
|   set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address  -fsanitize=leak") | ||||
|   set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -fsanitize=address  -fsanitize=leak") | ||||
| endif() | ||||
|  |  | |||
|  | @ -87,6 +87,7 @@ print_config("CPack Generator" "${CPACK_GENERATOR}") | |||
| message(STATUS "GTSAM flags                                               ") | ||||
| print_enabled_config(${GTSAM_USE_QUATERNIONS}             "Quaternions as default Rot3     ") | ||||
| print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS}   "Runtime consistency checking    ") | ||||
| print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER}     "Build with Memory Sanitizer     ") | ||||
| print_enabled_config(${GTSAM_ROT3_EXPMAP}                 "Rot3 retract is full ExpMap     ") | ||||
| print_enabled_config(${GTSAM_POSE3_EXPMAP}                "Pose3 retract is full ExpMap    ") | ||||
| print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43}  "Allow features deprecated in GTSAM 4.3") | ||||
|  |  | |||
|  | @ -149,6 +149,9 @@ TEST(StdOptionalSerialization, SerializTestOptionalStructPointerPointer) { | |||
|   // Check that it worked
 | ||||
|   EXPECT(opt2.has_value()); | ||||
|   EXPECT(**opt2 == TestOptionalStruct(42)); | ||||
| 
 | ||||
|   delete (*opt); | ||||
|   delete (*opt2); | ||||
| } | ||||
| 
 | ||||
| int main() { | ||||
|  |  | |||
|  | @ -0,0 +1,554 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /**
 | ||||
|  * @file TableFactor.cpp | ||||
|  * @brief discrete factor | ||||
|  * @date May 4, 2023 | ||||
|  * @author Yoonwoo Kim | ||||
|  */ | ||||
| 
 | ||||
| #include <gtsam/base/FastSet.h> | ||||
| #include <gtsam/discrete/DecisionTreeFactor.h> | ||||
| #include <gtsam/discrete/TableFactor.h> | ||||
| #include <gtsam/hybrid/HybridValues.h> | ||||
| 
 | ||||
| #include <boost/format.hpp> | ||||
| #include <utility> | ||||
| 
 | ||||
| using namespace std; | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor::TableFactor() {} | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor::TableFactor(const DiscreteKeys& dkeys, | ||||
|                          const TableFactor& potentials) | ||||
|     : DiscreteFactor(dkeys.indices()), | ||||
|       cardinalities_(potentials.cardinalities_) { | ||||
|   sparse_table_ = potentials.sparse_table_; | ||||
|   denominators_ = potentials.denominators_; | ||||
|   sorted_dkeys_ = discreteKeys(); | ||||
|   sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor::TableFactor(const DiscreteKeys& dkeys, | ||||
|                          const Eigen::SparseVector<double>& table) | ||||
|     : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) { | ||||
|   sparse_table_ = table; | ||||
|   double denom = table.size(); | ||||
|   for (const DiscreteKey& dkey : dkeys) { | ||||
|     cardinalities_.insert(dkey); | ||||
|     denom /= dkey.second; | ||||
|     denominators_.insert(std::pair<Key, double>(dkey.first, denom)); | ||||
|   } | ||||
|   sorted_dkeys_ = discreteKeys(); | ||||
|   sort(sorted_dkeys_.begin(), sorted_dkeys_.end()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| Eigen::SparseVector<double> TableFactor::Convert( | ||||
|     const std::vector<double>& table) { | ||||
|   Eigen::SparseVector<double> sparse_table(table.size()); | ||||
|   // Count number of nonzero elements in table and reserving the space.
 | ||||
|   const uint64_t nnz = std::count_if(table.begin(), table.end(), | ||||
|                                      [](uint64_t i) { return i != 0; }); | ||||
|   sparse_table.reserve(nnz); | ||||
|   for (uint64_t i = 0; i < table.size(); i++) { | ||||
|     if (table[i] != 0) sparse_table.insert(i) = table[i]; | ||||
|   } | ||||
|   sparse_table.pruned(); | ||||
|   sparse_table.data().squeeze(); | ||||
|   return sparse_table; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) { | ||||
|   // Convert string to doubles.
 | ||||
|   std::vector<double> ys; | ||||
|   std::istringstream iss(table); | ||||
|   std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(), | ||||
|             std::back_inserter(ys)); | ||||
|   return Convert(ys); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| bool TableFactor::equals(const DiscreteFactor& other, double tol) const { | ||||
|   if (!dynamic_cast<const TableFactor*>(&other)) { | ||||
|     return false; | ||||
|   } else { | ||||
|     const auto& f(static_cast<const TableFactor&>(other)); | ||||
|     return sparse_table_.isApprox(f.sparse_table_, tol); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double TableFactor::operator()(const DiscreteValues& values) const { | ||||
|   // a b c d => D * (C * (B * (a) + b) + c) + d
 | ||||
|   uint64_t idx = 0, card = 1; | ||||
|   for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) { | ||||
|     if (values.find(it->first) != values.end()) { | ||||
|       idx += card * values.at(it->first); | ||||
|     } | ||||
|     card *= it->second; | ||||
|   } | ||||
|   return sparse_table_.coeff(idx); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double TableFactor::findValue(const DiscreteValues& values) const { | ||||
|   // a b c d => D * (C * (B * (a) + b) + c) + d
 | ||||
|   uint64_t idx = 0, card = 1; | ||||
|   for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { | ||||
|     if (values.find(*it) != values.end()) { | ||||
|       idx += card * values.at(*it); | ||||
|     } | ||||
|     card *= cardinality(*it); | ||||
|   } | ||||
|   return sparse_table_.coeff(idx); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double TableFactor::error(const DiscreteValues& values) const { | ||||
|   return -log(evaluate(values)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double TableFactor::error(const HybridValues& values) const { | ||||
|   return error(values.discrete()); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { | ||||
|   return toDecisionTreeFactor() * f; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { | ||||
|   DiscreteKeys dkeys = discreteKeys(); | ||||
|   std::vector<double> table; | ||||
|   for (auto i = 0; i < sparse_table_.size(); i++) { | ||||
|     table.push_back(sparse_table_.coeff(i)); | ||||
|   } | ||||
|   DecisionTreeFactor f(dkeys, table); | ||||
|   return f; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor TableFactor::choose(const DiscreteValues parent_assign, | ||||
|                                 DiscreteKeys parent_keys) const { | ||||
|   if (parent_keys.empty()) return *this; | ||||
| 
 | ||||
|   // Unique representation of parent values.
 | ||||
|   uint64_t unique = 0; | ||||
|   uint64_t card = 1; | ||||
|   for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) { | ||||
|     if (parent_assign.find(*it) != parent_assign.end()) { | ||||
|       unique += parent_assign.at(*it) * card; | ||||
|       card *= cardinality(*it); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Find child DiscreteKeys
 | ||||
|   DiscreteKeys child_dkeys; | ||||
|   std::sort(parent_keys.begin(), parent_keys.end()); | ||||
|   std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), | ||||
|                       parent_keys.begin(), parent_keys.end(), | ||||
|                       std::back_inserter(child_dkeys)); | ||||
| 
 | ||||
|   // Create child sparse table to populate.
 | ||||
|   uint64_t child_card = 1; | ||||
|   for (const DiscreteKey& child_dkey : child_dkeys) | ||||
|     child_card *= child_dkey.second; | ||||
|   Eigen::SparseVector<double> child_sparse_table_(child_card); | ||||
|   child_sparse_table_.reserve(child_card); | ||||
| 
 | ||||
|   // Populate child sparse table.
 | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     // Create unique representation of parent keys
 | ||||
|     uint64_t parent_unique = uniqueRep(parent_keys, it.index()); | ||||
|     // Populate the table
 | ||||
|     if (parent_unique == unique) { | ||||
|       uint64_t idx = uniqueRep(child_dkeys, it.index()); | ||||
|       child_sparse_table_.insert(idx) = it.value(); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   child_sparse_table_.pruned(); | ||||
|   child_sparse_table_.data().squeeze(); | ||||
|   return TableFactor(child_dkeys, child_sparse_table_); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| double TableFactor::safe_div(const double& a, const double& b) { | ||||
|   // The use for safe_div is when we divide the product factor by the sum
 | ||||
|   // factor. If the product or sum is zero, we accord zero probability to the
 | ||||
|   // event.
 | ||||
|   return (a == 0 || b == 0) ? 0 : (a / b); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| void TableFactor::print(const string& s, const KeyFormatter& formatter) const { | ||||
|   cout << s; | ||||
|   cout << " f["; | ||||
|   for (auto&& key : keys()) | ||||
|     cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key); | ||||
|   cout << " ]" << endl; | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     DiscreteValues assignment = findAssignments(it.index()); | ||||
|     for (auto&& kv : assignment) { | ||||
|       cout << "(" << formatter(kv.first) << ", " << kv.second << ")"; | ||||
|     } | ||||
|     cout << " | " << it.value() << " | " << it.index() << endl; | ||||
|   } | ||||
|   cout << "number of nnzs: " << sparse_table_.nonZeros() << endl; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor TableFactor::apply(const TableFactor& f, Binary op) const { | ||||
|   if (keys_.empty() && sparse_table_.nonZeros() == 0) | ||||
|     return f; | ||||
|   else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0) | ||||
|     return *this; | ||||
|   // 1. Identify keys for contract and free modes.
 | ||||
|   DiscreteKeys contract_dkeys = contractDkeys(f); | ||||
|   DiscreteKeys f_free_dkeys = f.freeDkeys(*this); | ||||
|   DiscreteKeys union_dkeys = unionDkeys(f); | ||||
|   // 2. Create hash table for input factor f
 | ||||
|   unordered_map<uint64_t, AssignValList> map_f = | ||||
|       f.createMap(contract_dkeys, f_free_dkeys); | ||||
|   // 3. Initialize multiplied factor.
 | ||||
|   uint64_t card = 1; | ||||
|   for (auto u_dkey : union_dkeys) card *= u_dkey.second; | ||||
|   Eigen::SparseVector<double> mult_sparse_table(card); | ||||
|   mult_sparse_table.reserve(card); | ||||
|   // 3. Multiply.
 | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     uint64_t contract_unique = uniqueRep(contract_dkeys, it.index()); | ||||
|     if (map_f.find(contract_unique) == map_f.end()) continue; | ||||
|     for (auto assignVal : map_f[contract_unique]) { | ||||
|       uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index()); | ||||
|       mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second); | ||||
|     } | ||||
|   } | ||||
|   // 4. Free unused memory.
 | ||||
|   mult_sparse_table.pruned(); | ||||
|   mult_sparse_table.data().squeeze(); | ||||
|   // 5. Create union keys and return.
 | ||||
|   return TableFactor(union_dkeys, mult_sparse_table); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const { | ||||
|   // Find contract modes.
 | ||||
|   DiscreteKeys contract; | ||||
|   set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(), | ||||
|                    f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), | ||||
|                    back_inserter(contract)); | ||||
|   return contract; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const { | ||||
|   // Find free modes.
 | ||||
|   DiscreteKeys free; | ||||
|   set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(), | ||||
|                  f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(), | ||||
|                  back_inserter(free)); | ||||
|   return free; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const { | ||||
|   // Find union modes.
 | ||||
|   DiscreteKeys union_dkeys; | ||||
|   set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(), | ||||
|             f.sorted_dkeys_.end(), back_inserter(union_dkeys)); | ||||
|   return union_dkeys; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys, | ||||
|                                const DiscreteValues& f_free, | ||||
|                                const uint64_t idx) const { | ||||
|   uint64_t union_idx = 0, card = 1; | ||||
|   for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) { | ||||
|     if (f_free.find(it->first) == f_free.end()) { | ||||
|       union_idx += keyValueForIndex(it->first, idx) * card; | ||||
|     } else { | ||||
|       union_idx += f_free.at(it->first) * card; | ||||
|     } | ||||
|     card *= it->second; | ||||
|   } | ||||
|   return union_idx; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap( | ||||
|     const DiscreteKeys& contract, const DiscreteKeys& free) const { | ||||
|   // 1. Initialize map.
 | ||||
|   unordered_map<uint64_t, AssignValList> map_f; | ||||
|   // 2. Iterate over nonzero elements.
 | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     // 3. Create unique representation of contract modes.
 | ||||
|     uint64_t unique_rep = uniqueRep(contract, it.index()); | ||||
|     // 4. Create assignment for free modes.
 | ||||
|     DiscreteValues free_assignments; | ||||
|     for (auto& key : free) | ||||
|       free_assignments[key.first] = keyValueForIndex(key.first, it.index()); | ||||
|     // 5. Populate map.
 | ||||
|     if (map_f.find(unique_rep) == map_f.end()) { | ||||
|       map_f[unique_rep] = {make_pair(free_assignments, it.value())}; | ||||
|     } else { | ||||
|       map_f[unique_rep].push_back(make_pair(free_assignments, it.value())); | ||||
|     } | ||||
|   } | ||||
|   return map_f; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys, | ||||
|                                 const uint64_t idx) const { | ||||
|   if (dkeys.empty()) return 0; | ||||
|   uint64_t unique_rep = 0, card = 1; | ||||
|   for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) { | ||||
|     unique_rep += keyValueForIndex(it->first, idx) * card; | ||||
|     card *= it->second; | ||||
|   } | ||||
|   return unique_rep; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const { | ||||
|   if (assignments.empty()) return 0; | ||||
|   uint64_t unique_rep = 0, card = 1; | ||||
|   for (auto it = assignments.rbegin(); it != assignments.rend(); it++) { | ||||
|     unique_rep += it->second * card; | ||||
|     card *= cardinalities_.at(it->first); | ||||
|   } | ||||
|   return unique_rep; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DiscreteValues TableFactor::findAssignments(const uint64_t idx) const { | ||||
|   DiscreteValues assignment; | ||||
|   for (Key key : keys_) { | ||||
|     assignment[key] = keyValueForIndex(key, idx); | ||||
|   } | ||||
|   return assignment; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals, | ||||
|                                              Binary op) const { | ||||
|   if (nrFrontals > size()) { | ||||
|     throw invalid_argument( | ||||
|         "TableFactor::combine: invalid number of frontal " | ||||
|         "keys " + | ||||
|         to_string(nrFrontals) + ", nr.keys=" + std::to_string(size())); | ||||
|   } | ||||
|   // Find remaining keys.
 | ||||
|   DiscreteKeys remain_dkeys; | ||||
|   uint64_t card = 1; | ||||
|   for (auto i = nrFrontals; i < keys_.size(); i++) { | ||||
|     remain_dkeys.push_back(discreteKey(i)); | ||||
|     card *= cardinality(keys_[i]); | ||||
|   } | ||||
|   // Create combined table.
 | ||||
|   Eigen::SparseVector<double> combined_table(card); | ||||
|   combined_table.reserve(sparse_table_.nonZeros()); | ||||
|   // Populate combined table.
 | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     uint64_t idx = uniqueRep(remain_dkeys, it.index()); | ||||
|     double new_val = op(combined_table.coeff(idx), it.value()); | ||||
|     combined_table.coeffRef(idx) = new_val; | ||||
|   } | ||||
|   // Free unused memory.
 | ||||
|   combined_table.pruned(); | ||||
|   combined_table.data().squeeze(); | ||||
|   return std::make_shared<TableFactor>(remain_dkeys, combined_table); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys, | ||||
|                                              Binary op) const { | ||||
|   if (frontalKeys.size() > size()) { | ||||
|     throw invalid_argument( | ||||
|         "TableFactor::combine: invalid number of frontal " | ||||
|         "keys " + | ||||
|         std::to_string(frontalKeys.size()) + | ||||
|         ", nr.keys=" + std::to_string(size())); | ||||
|   } | ||||
|   // Find remaining keys.
 | ||||
|   DiscreteKeys remain_dkeys; | ||||
|   uint64_t card = 1; | ||||
|   for (Key key : keys_) { | ||||
|     if (std::find(frontalKeys.begin(), frontalKeys.end(), key) == | ||||
|         frontalKeys.end()) { | ||||
|       remain_dkeys.emplace_back(key, cardinality(key)); | ||||
|       card *= cardinality(key); | ||||
|     } | ||||
|   } | ||||
|   // Create combined table.
 | ||||
|   Eigen::SparseVector<double> combined_table(card); | ||||
|   combined_table.reserve(sparse_table_.nonZeros()); | ||||
|   // Populate combined table.
 | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     uint64_t idx = uniqueRep(remain_dkeys, it.index()); | ||||
|     double new_val = op(combined_table.coeff(idx), it.value()); | ||||
|     combined_table.coeffRef(idx) = new_val; | ||||
|   } | ||||
|   // Free unused memory.
 | ||||
|   combined_table.pruned(); | ||||
|   combined_table.data().squeeze(); | ||||
|   return std::make_shared<TableFactor>(remain_dkeys, combined_table); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const { | ||||
|   // http://phrogz.net/lazy-cartesian-product
 | ||||
|   return (index / denominators_.at(target_key)) % cardinality(target_key); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const { | ||||
|   // Get all possible assignments
 | ||||
|   std::vector<std::pair<Key, size_t>> pairs = discreteKeys(); | ||||
|   // Reverse to make cartesian product output a more natural ordering.
 | ||||
|   std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend()); | ||||
|   const auto assignments = DiscreteValues::CartesianProduct(rpairs); | ||||
|   // Construct unordered_map with values
 | ||||
|   std::vector<std::pair<DiscreteValues, double>> result; | ||||
|   for (const auto& assignment : assignments) { | ||||
|     result.emplace_back(assignment, operator()(assignment)); | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| DiscreteKeys TableFactor::discreteKeys() const { | ||||
|   DiscreteKeys result; | ||||
|   for (auto&& key : keys()) { | ||||
|     DiscreteKey dkey(key, cardinality(key)); | ||||
|     if (std::find(result.begin(), result.end(), dkey) == result.end()) { | ||||
|       result.push_back(dkey); | ||||
|     } | ||||
|   } | ||||
|   return result; | ||||
| } | ||||
| 
 | ||||
| // Print out header.
 | ||||
| /* ************************************************************************ */ | ||||
| string TableFactor::markdown(const KeyFormatter& keyFormatter, | ||||
|                              const Names& names) const { | ||||
|   stringstream ss; | ||||
| 
 | ||||
|   // Print out header.
 | ||||
|   ss << "|"; | ||||
|   for (auto& key : keys()) { | ||||
|     ss << keyFormatter(key) << "|"; | ||||
|   } | ||||
|   ss << "value|\n"; | ||||
| 
 | ||||
|   // Print out separator with alignment hints.
 | ||||
|   ss << "|"; | ||||
|   for (size_t j = 0; j < size(); j++) ss << ":-:|"; | ||||
|   ss << ":-:|\n"; | ||||
| 
 | ||||
|   // Print out all rows.
 | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     DiscreteValues assignment = findAssignments(it.index()); | ||||
|     ss << "|"; | ||||
|     for (auto& key : keys()) { | ||||
|       size_t index = assignment.at(key); | ||||
|       ss << DiscreteValues::Translate(names, key, index) << "|"; | ||||
|     } | ||||
|     ss << it.value() << "|\n"; | ||||
|   } | ||||
|   return ss.str(); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| string TableFactor::html(const KeyFormatter& keyFormatter, | ||||
|                          const Names& names) const { | ||||
|   stringstream ss; | ||||
| 
 | ||||
|   // Print out preamble.
 | ||||
|   ss << "<div>\n<table class='TableFactor'>\n  <thead>\n"; | ||||
| 
 | ||||
|   // Print out header row.
 | ||||
|   ss << "    <tr>"; | ||||
|   for (auto& key : keys()) { | ||||
|     ss << "<th>" << keyFormatter(key) << "</th>"; | ||||
|   } | ||||
|   ss << "<th>value</th></tr>\n"; | ||||
| 
 | ||||
|   // Finish header and start body.
 | ||||
|   ss << "  </thead>\n  <tbody>\n"; | ||||
| 
 | ||||
|   // Print out all rows.
 | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     DiscreteValues assignment = findAssignments(it.index()); | ||||
|     ss << "    <tr>"; | ||||
|     for (auto& key : keys()) { | ||||
|       size_t index = assignment.at(key); | ||||
|       ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>"; | ||||
|     } | ||||
|     ss << "<td>" << it.value() << "</td>";  // value
 | ||||
|     ss << "</tr>\n"; | ||||
|   } | ||||
|   ss << "  </tbody>\n</table>\n</div>"; | ||||
|   return ss.str(); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| TableFactor TableFactor::prune(size_t maxNrAssignments) const { | ||||
|   const size_t N = maxNrAssignments; | ||||
| 
 | ||||
|   // Get the probabilities in the TableFactor so we can threshold.
 | ||||
|   vector<pair<Eigen::Index, double>> probabilities; | ||||
| 
 | ||||
|   // Store non-zero probabilities along with their indices in a vector.
 | ||||
|   for (SparseIt it(sparse_table_); it; ++it) { | ||||
|     probabilities.emplace_back(it.index(), it.value()); | ||||
|   } | ||||
| 
 | ||||
|   // The number of probabilities can be lower than max_leaves.
 | ||||
|   if (probabilities.size() <= N) return *this; | ||||
| 
 | ||||
|   // Sort the vector in descending order based on the element values.
 | ||||
|   sort(probabilities.begin(), probabilities.end(), | ||||
|        [](const std::pair<Eigen::Index, double>& a, | ||||
|           const std::pair<Eigen::Index, double>& b) { | ||||
|          return a.second > b.second; | ||||
|        }); | ||||
| 
 | ||||
|   // Keep the largest N probabilities in the vector.
 | ||||
|   if (probabilities.size() > N) probabilities.resize(N); | ||||
| 
 | ||||
|   // Create pruned sparse vector.
 | ||||
|   Eigen::SparseVector<double> pruned_vec(sparse_table_.size()); | ||||
|   pruned_vec.reserve(probabilities.size()); | ||||
| 
 | ||||
|   // Populate pruned sparse vector.
 | ||||
|   for (const auto& prob : probabilities) { | ||||
|     pruned_vec.insert(prob.first) = prob.second; | ||||
|   } | ||||
| 
 | ||||
|   // Create pruned decision tree factor and return.
 | ||||
|   return TableFactor(this->discreteKeys(), pruned_vec); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************ */ | ||||
| }  // namespace gtsam
 | ||||
|  | @ -0,0 +1,340 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /**
 | ||||
|  * @file TableFactor.h | ||||
|  * @date May 4, 2023 | ||||
|  * @author Yoonwoo Kim | ||||
|  */ | ||||
| 
 | ||||
| #pragma once | ||||
| 
 | ||||
| #include <gtsam/discrete/DiscreteFactor.h> | ||||
| #include <gtsam/discrete/DiscreteKey.h> | ||||
| #include <gtsam/inference/Ordering.h> | ||||
| 
 | ||||
| #include <Eigen/Sparse> | ||||
| #include <algorithm> | ||||
| #include <map> | ||||
| #include <memory> | ||||
| #include <stdexcept> | ||||
| #include <string> | ||||
| #include <utility> | ||||
| #include <vector> | ||||
| 
 | ||||
| namespace gtsam { | ||||
| 
 | ||||
| class HybridValues; | ||||
| 
 | ||||
| /**
 | ||||
|  * A discrete probabilistic factor optimized for sparsity. | ||||
|  * Uses sparse_table_ to store only the nonzero probabilities. | ||||
|  * Computes the assigned value for the key using the ordering which the | ||||
|  * nonzero probabilties are stored in. (lazy cartesian product) | ||||
|  * | ||||
|  * @ingroup discrete | ||||
|  */ | ||||
| class GTSAM_EXPORT TableFactor : public DiscreteFactor { | ||||
|  protected: | ||||
|   /// Map of Keys and their cardinalities.
 | ||||
|   std::map<Key, size_t> cardinalities_; | ||||
|   /// SparseVector of nonzero probabilities.
 | ||||
|   Eigen::SparseVector<double> sparse_table_; | ||||
| 
 | ||||
|  private: | ||||
|   /// Map of Keys and their denominators used in keyValueForIndex.
 | ||||
|   std::map<Key, size_t> denominators_; | ||||
|   /// Sorted DiscreteKeys to use internally.
 | ||||
|   DiscreteKeys sorted_dkeys_; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Uses lazy cartesian product to find nth entry in the cartesian | ||||
|    * product of arrays in O(1)  | ||||
|    * Example)  | ||||
|    *   v0 | v1 | val  | ||||
|    *    0 |  0 |  10  | ||||
|    *    0 |  1 |  21 | ||||
|    *    1 |  0 |  32 | ||||
|    *    1 |  1 |  43 | ||||
|    *   keyValueForIndex(v1, 2) = 0 | ||||
|    * @param target_key nth entry's key to find out its assigned value | ||||
|    * @param index nth entry in the sparse vector | ||||
|    * @return TableFactor | ||||
|    */ | ||||
|   size_t keyValueForIndex(Key target_key, uint64_t index) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Return ith key in keys_ as a DiscreteKey | ||||
|    * @param i ith key in keys_ | ||||
|    * @return DiscreteKey | ||||
|    * */  | ||||
|   DiscreteKey discreteKey(size_t i) const { | ||||
|     return DiscreteKey(keys_[i], cardinalities_.at(keys_[i])); | ||||
|   } | ||||
| 
 | ||||
|   /// Convert probability table given as doubles to SparseVector.
 | ||||
|   /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} 
 | ||||
|   static Eigen::SparseVector<double> Convert(const std::vector<double>& table); | ||||
| 
 | ||||
|   /// Convert probability table given as string to SparseVector.
 | ||||
|   static Eigen::SparseVector<double> Convert(const std::string& table); | ||||
| 
 | ||||
|  public: | ||||
|   // typedefs needed to play nice with gtsam
 | ||||
|   typedef TableFactor This; | ||||
|   typedef DiscreteFactor Base;  ///< Typedef to base class
 | ||||
|   typedef std::shared_ptr<TableFactor> shared_ptr; | ||||
|   typedef Eigen::SparseVector<double>::InnerIterator SparseIt; | ||||
|   typedef std::vector<std::pair<DiscreteValues, double>> AssignValList; | ||||
|   using Binary = std::function<double(const double, const double)>; | ||||
| 
 | ||||
|  public: | ||||
|   /** The Real ring with addition and multiplication */ | ||||
|   struct Ring { | ||||
|     static inline double zero() { return 0.0; } | ||||
|     static inline double one() { return 1.0; } | ||||
|     static inline double add(const double& a, const double& b) { return a + b; } | ||||
|     static inline double max(const double& a, const double& b) { | ||||
|       return std::max(a, b); | ||||
|     } | ||||
|     static inline double mul(const double& a, const double& b) { return a * b; } | ||||
|     static inline double div(const double& a, const double& b) { | ||||
|       return (a == 0 || b == 0) ? 0 : (a / b); | ||||
|     } | ||||
|     static inline double id(const double& x) { return x; } | ||||
|   }; | ||||
| 
 | ||||
|   /// @name Standard Constructors
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /** Default constructor for I/O */ | ||||
|   TableFactor(); | ||||
| 
 | ||||
|   /** Constructor from DiscreteKeys and TableFactor */ | ||||
|   TableFactor(const DiscreteKeys& keys, const TableFactor& potentials); | ||||
| 
 | ||||
|   /** Constructor from sparse_table */ | ||||
|   TableFactor(const DiscreteKeys& keys, | ||||
|               const Eigen::SparseVector<double>& table); | ||||
| 
 | ||||
|   /** Constructor from doubles */ | ||||
|   TableFactor(const DiscreteKeys& keys, const std::vector<double>& table) | ||||
|       : TableFactor(keys, Convert(table)) {} | ||||
| 
 | ||||
|   /** Constructor from string */ | ||||
|   TableFactor(const DiscreteKeys& keys, const std::string& table) | ||||
|       : TableFactor(keys, Convert(table)) {} | ||||
| 
 | ||||
|   /// Single-key specialization
 | ||||
|   template <class SOURCE> | ||||
|   TableFactor(const DiscreteKey& key, SOURCE table) | ||||
|       : TableFactor(DiscreteKeys{key}, table) {} | ||||
| 
 | ||||
|   /// Single-key specialization, with vector of doubles.
 | ||||
|   TableFactor(const DiscreteKey& key, const std::vector<double>& row) | ||||
|       : TableFactor(DiscreteKeys{key}, row) {} | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Testable
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /// equality
 | ||||
|   bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; | ||||
| 
 | ||||
|   // print
 | ||||
|   void print( | ||||
|       const std::string& s = "TableFactor:\n", | ||||
|       const KeyFormatter& formatter = DefaultKeyFormatter) const override; | ||||
| 
 | ||||
|   // /// @}
 | ||||
|   // /// @name Standard Interface
 | ||||
|   // /// @{
 | ||||
| 
 | ||||
|   /// Calculate probability for given values `x`,
 | ||||
|   /// is just look up in TableFactor.
 | ||||
|   double evaluate(const DiscreteValues& values) const { | ||||
|     return operator()(values); | ||||
|   } | ||||
| 
 | ||||
|   /// Evaluate probability distribution, sugar.
 | ||||
|   double operator()(const DiscreteValues& values) const override; | ||||
| 
 | ||||
|   /// Calculate error for DiscreteValues `x`, is -log(probability).
 | ||||
|   double error(const DiscreteValues& values) const; | ||||
| 
 | ||||
|   /// multiply two TableFactors
 | ||||
|   TableFactor operator*(const TableFactor& f) const { | ||||
|     return apply(f, Ring::mul); | ||||
|   }; | ||||
| 
 | ||||
|   /// multiple with DecisionTreeFactor
 | ||||
|   DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; | ||||
| 
 | ||||
|   static double safe_div(const double& a, const double& b); | ||||
| 
 | ||||
|   size_t cardinality(Key j) const { return cardinalities_.at(j); } | ||||
| 
 | ||||
|   /// divide by factor f (safely)
 | ||||
|   TableFactor operator/(const TableFactor& f) const { | ||||
|     return apply(f, safe_div); | ||||
|   } | ||||
| 
 | ||||
|   /// Convert into a decisiontree
 | ||||
|   DecisionTreeFactor toDecisionTreeFactor() const override; | ||||
| 
 | ||||
|   /// Create a TableFactor that is a subset of this TableFactor
 | ||||
|   TableFactor choose(const DiscreteValues assignments, | ||||
|                      DiscreteKeys parent_keys) const; | ||||
| 
 | ||||
|   /// Create new factor by summing all values with the same separator values
 | ||||
|   shared_ptr sum(size_t nrFrontals) const { | ||||
|     return combine(nrFrontals, Ring::add); | ||||
|   } | ||||
| 
 | ||||
|   /// Create new factor by summing all values with the same separator values
 | ||||
|   shared_ptr sum(const Ordering& keys) const { | ||||
|     return combine(keys, Ring::add); | ||||
|   } | ||||
| 
 | ||||
|   /// Create new factor by maximizing over all values with the same separator.
 | ||||
|   shared_ptr max(size_t nrFrontals) const { | ||||
|     return combine(nrFrontals, Ring::max); | ||||
|   } | ||||
| 
 | ||||
|   /// Create new factor by maximizing over all values with the same separator.
 | ||||
|   shared_ptr max(const Ordering& keys) const { | ||||
|     return combine(keys, Ring::max); | ||||
|   } | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Advanced Interface
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /**
 | ||||
|    * Apply binary operator (*this) "op" f | ||||
|    * @param f the second argument for op | ||||
|    * @param op a binary operator that operates on TableFactor | ||||
|    */ | ||||
|   TableFactor apply(const TableFactor& f, Binary op) const; | ||||
| 
 | ||||
|   /// Return keys in contract mode.
 | ||||
|   DiscreteKeys contractDkeys(const TableFactor& f) const; | ||||
| 
 | ||||
|   /// Return keys in free mode.
 | ||||
|   DiscreteKeys freeDkeys(const TableFactor& f) const; | ||||
| 
 | ||||
|   /// Return union of DiscreteKeys in two factors.
 | ||||
|   DiscreteKeys unionDkeys(const TableFactor& f) const; | ||||
| 
 | ||||
|   /// Create unique representation of union modes.
 | ||||
|   uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign, | ||||
|                     const uint64_t idx) const; | ||||
| 
 | ||||
|   /// Create a hash map of input factor with assignment of contract modes as
 | ||||
|   /// keys and vector of hashed assignment of free modes and value as values.
 | ||||
|   std::unordered_map<uint64_t, AssignValList> createMap( | ||||
|       const DiscreteKeys& contract, const DiscreteKeys& free) const; | ||||
| 
 | ||||
|   /// Create unique representation
 | ||||
|   uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const; | ||||
| 
 | ||||
|   /// Create unique representation with DiscreteValues
 | ||||
|   uint64_t uniqueRep(const DiscreteValues& assignments) const; | ||||
| 
 | ||||
|   /// Find DiscreteValues for corresponding index.
 | ||||
|   DiscreteValues findAssignments(const uint64_t idx) const; | ||||
| 
 | ||||
|   /// Find value for corresponding DiscreteValues.
 | ||||
|   double findValue(const DiscreteValues& values) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * Combine frontal variables using binary operator "op" | ||||
|    * @param nrFrontals nr. of frontal to combine variables in this factor | ||||
|    * @param op a binary operator that operates on TableFactor | ||||
|    * @return shared pointer to newly created TableFactor | ||||
|    */ | ||||
|   shared_ptr combine(size_t nrFrontals, Binary op) const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * Combine frontal variables in an Ordering using binary operator "op" | ||||
|    * @param nrFrontals nr. of frontal to combine variables in this factor | ||||
|    * @param op a binary operator that operates on TableFactor | ||||
|    * @return shared pointer to newly created TableFactor | ||||
|    */ | ||||
|   shared_ptr combine(const Ordering& keys, Binary op) const; | ||||
| 
 | ||||
|   /// Enumerate all values into a map from values to double.
 | ||||
|   std::vector<std::pair<DiscreteValues, double>> enumerate() const; | ||||
| 
 | ||||
|   /// Return all the discrete keys associated with this factor.
 | ||||
|   DiscreteKeys discreteKeys() const; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Prune the decision tree of discrete variables. | ||||
|    * | ||||
|    * Pruning will set the values to be "pruned" to 0 indicating a 0 | ||||
|    * probability. An assignment is pruned if it is not in the top | ||||
|    * `maxNrAssignments` values. | ||||
|    * | ||||
|    * A violation can occur if there are more | ||||
|    * duplicate values than `maxNrAssignments`. A violation here is the need to | ||||
|    * un-prune the decision tree (e.g. all assignment values are 1.0). We could | ||||
|    * have another case where some subset of duplicates exist (e.g. for a tree | ||||
|    * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is | ||||
|    * not a violation since the for `maxNrAssignments=5` the top values are (1, | ||||
|    * 0.8). | ||||
|    * | ||||
|    * @param maxNrAssignments The maximum number of assignments to keep. | ||||
|    * @return TableFactor | ||||
|    */ | ||||
|   TableFactor prune(size_t maxNrAssignments) const; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name Wrapper support
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Render as markdown table | ||||
|    * | ||||
|    * @param keyFormatter GTSAM-style Key formatter. | ||||
|    * @param names optional, category names corresponding to choices. | ||||
|    * @return std::string a markdown string. | ||||
|    */ | ||||
|   std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter, | ||||
|                        const Names& names = {}) const override; | ||||
| 
 | ||||
|   /**
 | ||||
|    * @brief Render as html table | ||||
|    * | ||||
|    * @param keyFormatter GTSAM-style Key formatter. | ||||
|    * @param names optional, category names corresponding to choices. | ||||
|    * @return std::string a html string. | ||||
|    */ | ||||
|   std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter, | ||||
|                    const Names& names = {}) const override; | ||||
| 
 | ||||
|   /// @}
 | ||||
|   /// @name HybridValues methods.
 | ||||
|   /// @{
 | ||||
| 
 | ||||
|   /**
 | ||||
|    * Calculate error for HybridValues `x`, is -log(probability) | ||||
|    * Simply dispatches to DiscreteValues version. | ||||
|    */ | ||||
|   double error(const HybridValues& values) const override; | ||||
| 
 | ||||
|   /// @}
 | ||||
| }; | ||||
| 
 | ||||
| // traits
 | ||||
| template <> | ||||
| struct traits<TableFactor> : public Testable<TableFactor> {}; | ||||
| }  // namespace gtsam
 | ||||
|  | @ -0,0 +1,360 @@ | |||
| /* ----------------------------------------------------------------------------
 | ||||
| 
 | ||||
|  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||
|  * Atlanta, Georgia 30332-0415 | ||||
|  * All Rights Reserved | ||||
|  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||
| 
 | ||||
|  * See LICENSE for the license information | ||||
| 
 | ||||
|  * -------------------------------------------------------------------------- */ | ||||
| 
 | ||||
| /*
 | ||||
|  * testTableFactor.cpp | ||||
|  * | ||||
|  *  @date Feb 15, 2023 | ||||
|  *  @author Yoonwoo Kim | ||||
|  */ | ||||
| 
 | ||||
| #include <CppUnitLite/TestHarness.h> | ||||
| #include <gtsam/base/Testable.h> | ||||
| #include <gtsam/base/serializationTestHelpers.h> | ||||
| #include <gtsam/discrete/DiscreteDistribution.h> | ||||
| #include <gtsam/discrete/Signature.h> | ||||
| #include <gtsam/discrete/TableFactor.h> | ||||
| 
 | ||||
| #include <chrono> | ||||
| #include <random> | ||||
| 
 | ||||
| using namespace std; | ||||
| using namespace gtsam; | ||||
| 
 | ||||
| vector<double> genArr(double dropout, size_t size) { | ||||
|   random_device rd; | ||||
|   mt19937 g(rd()); | ||||
|   vector<double> dropoutmask(size);  // Chance of 0
 | ||||
| 
 | ||||
|   uniform_int_distribution<> dist(1, 9); | ||||
|   auto gen = [&dist, &g]() { return dist(g); }; | ||||
|   generate(dropoutmask.begin(), dropoutmask.end(), gen); | ||||
| 
 | ||||
|   fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0); | ||||
|   shuffle(dropoutmask.begin(), dropoutmask.end(), g); | ||||
| 
 | ||||
|   return dropoutmask; | ||||
| } | ||||
| 
 | ||||
| map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime( | ||||
|     DiscreteKeys keys1, DiscreteKeys keys2, size_t size) { | ||||
|   vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times; | ||||
| 
 | ||||
|   for (auto dropout : dropouts) { | ||||
|     vector<double> arr1 = genArr(dropout, size); | ||||
|     vector<double> arr2 = genArr(dropout, size); | ||||
|     TableFactor f1(keys1, arr1); | ||||
|     TableFactor f2(keys2, arr2); | ||||
|     DecisionTreeFactor f1_dt(keys1, arr1); | ||||
|     DecisionTreeFactor f2_dt(keys2, arr2); | ||||
| 
 | ||||
|     // measure time TableFactor
 | ||||
|     auto tb_start = chrono::high_resolution_clock::now(); | ||||
|     TableFactor actual = f1 * f2; | ||||
|     auto tb_end = chrono::high_resolution_clock::now(); | ||||
|     auto tb_time_diff = | ||||
|         chrono::duration_cast<chrono::microseconds>(tb_end - tb_start); | ||||
| 
 | ||||
|     // measure time DT
 | ||||
|     auto dt_start = chrono::high_resolution_clock::now(); | ||||
|     DecisionTreeFactor actual_dt = f1_dt * f2_dt; | ||||
|     auto dt_end = chrono::high_resolution_clock::now(); | ||||
|     auto dt_time_diff = | ||||
|         chrono::duration_cast<chrono::microseconds>(dt_end - dt_start); | ||||
| 
 | ||||
|     bool flag = true; | ||||
|     for (auto assignmentVal : actual_dt.enumerate()) { | ||||
|       flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first); | ||||
|       if (flag) { | ||||
|         std::cout << "something is wrong: " << std::endl; | ||||
|         assignmentVal.first.print(); | ||||
|         std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl; | ||||
|         std::cout << "tb: " << actual(assignmentVal.first) << std::endl; | ||||
|         break; | ||||
|       } | ||||
|     } | ||||
|     if (flag) break; | ||||
|     measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff); | ||||
|   } | ||||
|   return measured_times; | ||||
| } | ||||
| 
 | ||||
| void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>> | ||||
|                    measured_time) { | ||||
|   for (auto&& kv : measured_time) { | ||||
|     cout << "dropout: " << kv.first | ||||
|          << " | TableFactor time: " << kv.second.first.count() | ||||
|          << " | DecisionTreeFactor time: " << kv.second.second.count() << endl; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check constructors for TableFactor.
 | ||||
| TEST(TableFactor, constructors) { | ||||
|   // Declare a bunch of keys
 | ||||
|   DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5); | ||||
| 
 | ||||
|   // Create factors
 | ||||
|   TableFactor f_zeros(A, {0, 0, 0, 0, 1}); | ||||
|   TableFactor f1(X, {2, 8}); | ||||
|   TableFactor f2(X & Y, "2 5 3 6 4 7"); | ||||
|   TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75"); | ||||
|   EXPECT_LONGS_EQUAL(1, f1.size()); | ||||
|   EXPECT_LONGS_EQUAL(2, f2.size()); | ||||
|   EXPECT_LONGS_EQUAL(3, f3.size()); | ||||
| 
 | ||||
|   DiscreteValues values; | ||||
|   values[0] = 1;  // x
 | ||||
|   values[1] = 2;  // y
 | ||||
|   values[2] = 1;  // z
 | ||||
|   values[3] = 4;  // a
 | ||||
|   EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9); | ||||
|   EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9); | ||||
|   EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9); | ||||
|   EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9); | ||||
| 
 | ||||
|   // Assert that error = -log(value)
 | ||||
|   EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check multiplication between two TableFactors.
 | ||||
| TEST(TableFactor, multiplication) { | ||||
|   DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); | ||||
| 
 | ||||
|   // Multiply with a DiscreteDistribution, i.e., Bayes Law!
 | ||||
|   DiscreteDistribution prior(v1 % "1/3"); | ||||
|   TableFactor f1(v0 & v1, "1 2 3 4"); | ||||
|   DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); | ||||
|   CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) * | ||||
|                                    f1.toDecisionTreeFactor())); | ||||
|   CHECK(assert_equal(expected, f1 * prior)); | ||||
| 
 | ||||
|   // Multiply two factors
 | ||||
|   TableFactor f2(v1 & v2, "5 6 7 8"); | ||||
|   TableFactor actual = f1 * f2; | ||||
|   TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); | ||||
|   CHECK(assert_equal(expected2, actual)); | ||||
| 
 | ||||
|   DiscreteKey A(0, 3), B(1, 2), C(2, 2); | ||||
|   TableFactor f_zeros1(A & C, "0 0 0 2 0 3"); | ||||
|   TableFactor f_zeros2(B & C, "4 0 0 5"); | ||||
|   TableFactor actual_zeros = f_zeros1 * f_zeros2; | ||||
|   TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15"); | ||||
|   CHECK(assert_equal(expected3, actual_zeros)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Benchmark which compares runtime of multiplication of two TableFactors
 | ||||
| // and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
 | ||||
| TEST(TableFactor, benchmark) { | ||||
|   DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3), | ||||
|       H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3); | ||||
| 
 | ||||
|   // 100
 | ||||
|   DiscreteKeys one_1 = {A, B, C, D}; | ||||
|   DiscreteKeys one_2 = {C, D, E, F}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 = | ||||
|       measureTime(one_1, one_2, 100); | ||||
|   printTime(time_map_1); | ||||
|   // 200
 | ||||
|   DiscreteKeys two_1 = {A, B, C, D, F}; | ||||
|   DiscreteKeys two_2 = {B, C, D, E, F}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 = | ||||
|       measureTime(two_1, two_2, 200); | ||||
|   printTime(time_map_2); | ||||
|   // 300
 | ||||
|   DiscreteKeys three_1 = {A, B, C, D, G}; | ||||
|   DiscreteKeys three_2 = {C, D, E, F, G}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 = | ||||
|       measureTime(three_1, three_2, 300); | ||||
|   printTime(time_map_3); | ||||
|   // 400
 | ||||
|   DiscreteKeys four_1 = {A, B, C, D, F, H}; | ||||
|   DiscreteKeys four_2 = {B, C, D, E, F, H}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 = | ||||
|       measureTime(four_1, four_2, 400); | ||||
|   printTime(time_map_4); | ||||
|   // 500
 | ||||
|   DiscreteKeys five_1 = {A, B, C, D, I}; | ||||
|   DiscreteKeys five_2 = {C, D, E, F, I}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 = | ||||
|       measureTime(five_1, five_2, 500); | ||||
|   printTime(time_map_5); | ||||
|   // 600
 | ||||
|   DiscreteKeys six_1 = {A, B, C, D, F, G}; | ||||
|   DiscreteKeys six_2 = {B, C, D, E, F, G}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 = | ||||
|       measureTime(six_1, six_2, 600); | ||||
|   printTime(time_map_6); | ||||
|   // 700
 | ||||
|   DiscreteKeys seven_1 = {A, B, C, D, J}; | ||||
|   DiscreteKeys seven_2 = {C, D, E, F, J}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 = | ||||
|       measureTime(seven_1, seven_2, 700); | ||||
|   printTime(time_map_7); | ||||
|   // 800
 | ||||
|   DiscreteKeys eight_1 = {A, B, C, D, F, H, K}; | ||||
|   DiscreteKeys eight_2 = {B, C, D, E, F, H, K}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 = | ||||
|       measureTime(eight_1, eight_2, 800); | ||||
|   printTime(time_map_8); | ||||
|   // 900
 | ||||
|   DiscreteKeys nine_1 = {A, B, C, D, G, L}; | ||||
|   DiscreteKeys nine_2 = {C, D, E, F, G, L}; | ||||
|   map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 = | ||||
|       measureTime(nine_1, nine_2, 900); | ||||
|   printTime(time_map_9); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check sum and max over frontals.
 | ||||
| TEST(TableFactor, sum_max) { | ||||
|   DiscreteKey v0(0, 3), v1(1, 2); | ||||
|   TableFactor f1(v0 & v1, "1 2  3 4  5 6"); | ||||
| 
 | ||||
|   TableFactor expected(v1, "9 12"); | ||||
|   TableFactor::shared_ptr actual = f1.sum(1); | ||||
|   CHECK(assert_equal(expected, *actual, 1e-5)); | ||||
| 
 | ||||
|   TableFactor expected2(v1, "5 6"); | ||||
|   TableFactor::shared_ptr actual2 = f1.max(1); | ||||
|   CHECK(assert_equal(expected2, *actual2)); | ||||
| 
 | ||||
|   TableFactor f2(v1 & v0, "1 2  3 4  5 6"); | ||||
|   TableFactor::shared_ptr actual22 = f2.sum(1); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check enumerate yields the correct list of assignment/value pairs.
 | ||||
| TEST(TableFactor, enumerate) { | ||||
|   DiscreteKey A(12, 3), B(5, 2); | ||||
|   TableFactor f(A & B, "1 2  3 4  5 6"); | ||||
|   auto actual = f.enumerate(); | ||||
|   std::vector<std::pair<DiscreteValues, double>> expected; | ||||
|   DiscreteValues values; | ||||
|   for (size_t a : {0, 1, 2}) { | ||||
|     for (size_t b : {0, 1}) { | ||||
|       values[12] = a; | ||||
|       values[5] = b; | ||||
|       expected.emplace_back(values, f(values)); | ||||
|     } | ||||
|   } | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check pruning of the decision tree works as expected.
 | ||||
| TEST(TableFactor, Prune) { | ||||
|   DiscreteKey A(1, 2), B(2, 2), C(3, 2); | ||||
|   TableFactor f(A & B & C, "1 5 3 7 2 6 4 8"); | ||||
| 
 | ||||
|   // Only keep the leaves with the top 5 values.
 | ||||
|   size_t maxNrAssignments = 5; | ||||
|   auto pruned5 = f.prune(maxNrAssignments); | ||||
| 
 | ||||
|   // Pruned leaves should be 0
 | ||||
|   TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); | ||||
|   EXPECT(assert_equal(expected, pruned5)); | ||||
| 
 | ||||
|   // Check for more extreme pruning where we only keep the top 2 leaves
 | ||||
|   maxNrAssignments = 2; | ||||
|   auto pruned2 = f.prune(maxNrAssignments); | ||||
|   TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); | ||||
|   EXPECT(assert_equal(expected2, pruned2)); | ||||
| 
 | ||||
|   DiscreteKey D(4, 2); | ||||
|   TableFactor factor( | ||||
|       D & C & B & A, | ||||
|       "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 " | ||||
|       "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0"); | ||||
| 
 | ||||
|   TableFactor expected3(D & C & B & A, | ||||
|                         "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 " | ||||
|                         "0.999952870000 1.0 1.0 1.0 1.0"); | ||||
|   maxNrAssignments = 5; | ||||
|   auto pruned3 = factor.prune(maxNrAssignments); | ||||
|   EXPECT(assert_equal(expected3, pruned3)); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check markdown representation looks as expected.
 | ||||
| TEST(TableFactor, markdown) { | ||||
|   DiscreteKey A(12, 3), B(5, 2); | ||||
|   TableFactor f(A & B, "1 2  3 4  5 6"); | ||||
|   string expected = | ||||
|       "|A|B|value|\n" | ||||
|       "|:-:|:-:|:-:|\n" | ||||
|       "|0|0|1|\n" | ||||
|       "|0|1|2|\n" | ||||
|       "|1|0|3|\n" | ||||
|       "|1|1|4|\n" | ||||
|       "|2|0|5|\n" | ||||
|       "|2|1|6|\n"; | ||||
|   auto formatter = [](Key key) { return key == 12 ? "A" : "B"; }; | ||||
|   string actual = f.markdown(formatter); | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check markdown representation with a value formatter.
 | ||||
| TEST(TableFactor, markdownWithValueFormatter) { | ||||
|   DiscreteKey A(12, 3), B(5, 2); | ||||
|   TableFactor f(A & B, "1 2  3 4  5 6"); | ||||
|   string expected = | ||||
|       "|A|B|value|\n" | ||||
|       "|:-:|:-:|:-:|\n" | ||||
|       "|Zero|-|1|\n" | ||||
|       "|Zero|+|2|\n" | ||||
|       "|One|-|3|\n" | ||||
|       "|One|+|4|\n" | ||||
|       "|Two|-|5|\n" | ||||
|       "|Two|+|6|\n"; | ||||
|   auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; | ||||
|   TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; | ||||
|   string actual = f.markdown(keyFormatter, names); | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| // Check html representation with a value formatter.
 | ||||
| TEST(TableFactor, htmlWithValueFormatter) { | ||||
|   DiscreteKey A(12, 3), B(5, 2); | ||||
|   TableFactor f(A & B, "1 2  3 4  5 6"); | ||||
|   string expected = | ||||
|       "<div>\n" | ||||
|       "<table class='TableFactor'>\n" | ||||
|       "  <thead>\n" | ||||
|       "    <tr><th>A</th><th>B</th><th>value</th></tr>\n" | ||||
|       "  </thead>\n" | ||||
|       "  <tbody>\n" | ||||
|       "    <tr><th>Zero</th><th>-</th><td>1</td></tr>\n" | ||||
|       "    <tr><th>Zero</th><th>+</th><td>2</td></tr>\n" | ||||
|       "    <tr><th>One</th><th>-</th><td>3</td></tr>\n" | ||||
|       "    <tr><th>One</th><th>+</th><td>4</td></tr>\n" | ||||
|       "    <tr><th>Two</th><th>-</th><td>5</td></tr>\n" | ||||
|       "    <tr><th>Two</th><th>+</th><td>6</td></tr>\n" | ||||
|       "  </tbody>\n" | ||||
|       "</table>\n" | ||||
|       "</div>"; | ||||
|   auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; }; | ||||
|   TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}}; | ||||
|   string actual = f.html(keyFormatter, names); | ||||
|   EXPECT(actual == expected); | ||||
| } | ||||
| 
 | ||||
| /* ************************************************************************* */ | ||||
| int main() { | ||||
|   TestResult tr; | ||||
|   return TestRegistry::runAllTests(tr); | ||||
| } | ||||
| /* ************************************************************************* */ | ||||
|  | @ -105,7 +105,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc | |||
|       set(mexModuleExt mexglx) | ||||
|     endif() | ||||
|   elseif(APPLE) | ||||
|     set(mexModuleExt mexmaci64) | ||||
|     check_cxx_compiler_flag("-arch arm64" arm64Supported) | ||||
|     if (arm64Supported) | ||||
|       set(mexModuleExt mexmaca64) | ||||
|     else() | ||||
|       set(mexModuleExt mexmaci64) | ||||
|     endif() | ||||
|   elseif(MSVC) | ||||
|     if(CMAKE_CL_64) | ||||
|       set(mexModuleExt mexw64) | ||||
|  | @ -299,7 +304,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc | |||
|       APPEND | ||||
|       PROPERTY COMPILE_FLAGS "/bigobj") | ||||
|   elseif(APPLE) | ||||
|     set(mxLibPath "${MATLAB_ROOT}/bin/maci64") | ||||
|     check_cxx_compiler_flag("-arch arm64" arm64Supported) | ||||
|     if (arm64Supported) | ||||
|       set(mxLibPath "${MATLAB_ROOT}/bin/maca64") | ||||
|     else() | ||||
|       set(mxLibPath "${MATLAB_ROOT}/bin/maci64") | ||||
|     endif() | ||||
|     target_link_libraries( | ||||
|       ${moduleName}_matlab_wrapper "${mxLibPath}/libmex.dylib" | ||||
|       "${mxLibPath}/libmx.dylib" "${mxLibPath}/libmat.dylib") | ||||
|  | @ -367,7 +377,12 @@ function(check_conflicting_libraries_internal libraries) | |||
|   if(UNIX) | ||||
|     # Set path for matlab's built-in libraries | ||||
|     if(APPLE) | ||||
|       set(mxLibPath "${MATLAB_ROOT}/bin/maci64") | ||||
|       check_cxx_compiler_flag("-arch arm64" arm64Supported) | ||||
|       if (arm64Supported) | ||||
|         set(mxLibPath "${MATLAB_ROOT}/bin/maca64") | ||||
|       else() | ||||
|         set(mxLibPath "${MATLAB_ROOT}/bin/maci64") | ||||
|       endif() | ||||
|     else() | ||||
|       if(CMAKE_CL_64) | ||||
|         set(mxLibPath "${MATLAB_ROOT}/bin/glnxa64") | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue