Merge branch 'develop' into decisiontree-improvements
						commit
						d74e41a1c3
					
				| 
						 | 
					@ -19,7 +19,8 @@ option(GTSAM_FORCE_STATIC_LIB               "Force gtsam to be a static library,
 | 
				
			||||||
option(GTSAM_USE_QUATERNIONS                "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF)
 | 
					option(GTSAM_USE_QUATERNIONS                "Enable/Disable using an internal Quaternion representation for rotations instead of rotation matrices. If enable, Rot3::EXPMAP is enforced by default." OFF)
 | 
				
			||||||
option(GTSAM_POSE3_EXPMAP                   "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
 | 
					option(GTSAM_POSE3_EXPMAP                   "Enable/Disable using Pose3::EXPMAP as the default mode. If disabled, Pose3::FIRST_ORDER will be used." ON)
 | 
				
			||||||
option(GTSAM_ROT3_EXPMAP                    "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
 | 
					option(GTSAM_ROT3_EXPMAP                    "Ignore if GTSAM_USE_QUATERNIONS is OFF (Rot3::EXPMAP by default). Otherwise, enable Rot3::EXPMAP, or if disabled, use Rot3::CAYLEY." ON)
 | 
				
			||||||
option(GTSAM_ENABLE_CONSISTENCY_CHECKS      "Enable/Disable expensive consistency checks"       OFF)
 | 
					option(GTSAM_ENABLE_CONSISTENCY_CHECKS      "Enable/Disable expensive consistency checks" OFF)
 | 
				
			||||||
 | 
					option(GTSAM_ENABLE_MEMORY_SANITIZER        "Enable/Disable memory sanitizer" OFF)
 | 
				
			||||||
option(GTSAM_WITH_TBB                       "Use Intel Threaded Building Blocks (TBB) if available" ON)
 | 
					option(GTSAM_WITH_TBB                       "Use Intel Threaded Building Blocks (TBB) if available" ON)
 | 
				
			||||||
option(GTSAM_WITH_EIGEN_MKL                 "Eigen will use Intel MKL if available" OFF)
 | 
					option(GTSAM_WITH_EIGEN_MKL                 "Eigen will use Intel MKL if available" OFF)
 | 
				
			||||||
option(GTSAM_WITH_EIGEN_MKL_OPENMP          "Eigen, when using Intel MKL, will also use OpenMP for multithreading if available" OFF)
 | 
					option(GTSAM_WITH_EIGEN_MKL_OPENMP          "Eigen, when using Intel MKL, will also use OpenMP for multithreading if available" OFF)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -50,3 +50,10 @@ if(GTSAM_ENABLE_CONSISTENCY_CHECKS)
 | 
				
			||||||
  # This should be made PUBLIC if GTSAM_EXTRA_CONSISTENCY_CHECKS is someday used in a public .h
 | 
					  # This should be made PUBLIC if GTSAM_EXTRA_CONSISTENCY_CHECKS is someday used in a public .h
 | 
				
			||||||
  list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE GTSAM_EXTRA_CONSISTENCY_CHECKS)
 | 
					  list_append_cache(GTSAM_COMPILE_DEFINITIONS_PRIVATE GTSAM_EXTRA_CONSISTENCY_CHECKS)
 | 
				
			||||||
endif()
 | 
					endif()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if(GTSAM_ENABLE_MEMORY_SANITIZER)
 | 
				
			||||||
 | 
					  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address  -fsanitize=leak -g")
 | 
				
			||||||
 | 
					  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address  -fsanitize=leak -g")
 | 
				
			||||||
 | 
					  set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fsanitize=address  -fsanitize=leak")
 | 
				
			||||||
 | 
					  set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -fsanitize=address  -fsanitize=leak")
 | 
				
			||||||
 | 
					endif()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -87,6 +87,7 @@ print_config("CPack Generator" "${CPACK_GENERATOR}")
 | 
				
			||||||
message(STATUS "GTSAM flags                                               ")
 | 
					message(STATUS "GTSAM flags                                               ")
 | 
				
			||||||
print_enabled_config(${GTSAM_USE_QUATERNIONS}             "Quaternions as default Rot3     ")
 | 
					print_enabled_config(${GTSAM_USE_QUATERNIONS}             "Quaternions as default Rot3     ")
 | 
				
			||||||
print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS}   "Runtime consistency checking    ")
 | 
					print_enabled_config(${GTSAM_ENABLE_CONSISTENCY_CHECKS}   "Runtime consistency checking    ")
 | 
				
			||||||
 | 
					print_enabled_config(${GTSAM_ENABLE_MEMORY_SANITIZER}     "Build with Memory Sanitizer     ")
 | 
				
			||||||
print_enabled_config(${GTSAM_ROT3_EXPMAP}                 "Rot3 retract is full ExpMap     ")
 | 
					print_enabled_config(${GTSAM_ROT3_EXPMAP}                 "Rot3 retract is full ExpMap     ")
 | 
				
			||||||
print_enabled_config(${GTSAM_POSE3_EXPMAP}                "Pose3 retract is full ExpMap    ")
 | 
					print_enabled_config(${GTSAM_POSE3_EXPMAP}                "Pose3 retract is full ExpMap    ")
 | 
				
			||||||
print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43}  "Allow features deprecated in GTSAM 4.3")
 | 
					print_enabled_config(${GTSAM_ALLOW_DEPRECATED_SINCE_V43}  "Allow features deprecated in GTSAM 4.3")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -149,6 +149,9 @@ TEST(StdOptionalSerialization, SerializTestOptionalStructPointerPointer) {
 | 
				
			||||||
  // Check that it worked
 | 
					  // Check that it worked
 | 
				
			||||||
  EXPECT(opt2.has_value());
 | 
					  EXPECT(opt2.has_value());
 | 
				
			||||||
  EXPECT(**opt2 == TestOptionalStruct(42));
 | 
					  EXPECT(**opt2 == TestOptionalStruct(42));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  delete (*opt);
 | 
				
			||||||
 | 
					  delete (*opt2);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
int main() {
 | 
					int main() {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,554 @@
 | 
				
			||||||
 | 
					/* ----------------------------------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | 
				
			||||||
 | 
					 * Atlanta, Georgia 30332-0415
 | 
				
			||||||
 | 
					 * All Rights Reserved
 | 
				
			||||||
 | 
					 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 * See LICENSE for the license information
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 * -------------------------------------------------------------------------- */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * @file TableFactor.cpp
 | 
				
			||||||
 | 
					 * @brief discrete factor
 | 
				
			||||||
 | 
					 * @date May 4, 2023
 | 
				
			||||||
 | 
					 * @author Yoonwoo Kim
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <gtsam/base/FastSet.h>
 | 
				
			||||||
 | 
					#include <gtsam/discrete/DecisionTreeFactor.h>
 | 
				
			||||||
 | 
					#include <gtsam/discrete/TableFactor.h>
 | 
				
			||||||
 | 
					#include <gtsam/hybrid/HybridValues.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <boost/format.hpp>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace std;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace gtsam {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					TableFactor::TableFactor() {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					TableFactor::TableFactor(const DiscreteKeys& dkeys,
 | 
				
			||||||
 | 
					                         const TableFactor& potentials)
 | 
				
			||||||
 | 
					    : DiscreteFactor(dkeys.indices()),
 | 
				
			||||||
 | 
					      cardinalities_(potentials.cardinalities_) {
 | 
				
			||||||
 | 
					  sparse_table_ = potentials.sparse_table_;
 | 
				
			||||||
 | 
					  denominators_ = potentials.denominators_;
 | 
				
			||||||
 | 
					  sorted_dkeys_ = discreteKeys();
 | 
				
			||||||
 | 
					  sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					TableFactor::TableFactor(const DiscreteKeys& dkeys,
 | 
				
			||||||
 | 
					                         const Eigen::SparseVector<double>& table)
 | 
				
			||||||
 | 
					    : DiscreteFactor(dkeys.indices()), sparse_table_(table.size()) {
 | 
				
			||||||
 | 
					  sparse_table_ = table;
 | 
				
			||||||
 | 
					  double denom = table.size();
 | 
				
			||||||
 | 
					  for (const DiscreteKey& dkey : dkeys) {
 | 
				
			||||||
 | 
					    cardinalities_.insert(dkey);
 | 
				
			||||||
 | 
					    denom /= dkey.second;
 | 
				
			||||||
 | 
					    denominators_.insert(std::pair<Key, double>(dkey.first, denom));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  sorted_dkeys_ = discreteKeys();
 | 
				
			||||||
 | 
					  sort(sorted_dkeys_.begin(), sorted_dkeys_.end());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					Eigen::SparseVector<double> TableFactor::Convert(
 | 
				
			||||||
 | 
					    const std::vector<double>& table) {
 | 
				
			||||||
 | 
					  Eigen::SparseVector<double> sparse_table(table.size());
 | 
				
			||||||
 | 
					  // Count number of nonzero elements in table and reserving the space.
 | 
				
			||||||
 | 
					  const uint64_t nnz = std::count_if(table.begin(), table.end(),
 | 
				
			||||||
 | 
					                                     [](uint64_t i) { return i != 0; });
 | 
				
			||||||
 | 
					  sparse_table.reserve(nnz);
 | 
				
			||||||
 | 
					  for (uint64_t i = 0; i < table.size(); i++) {
 | 
				
			||||||
 | 
					    if (table[i] != 0) sparse_table.insert(i) = table[i];
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  sparse_table.pruned();
 | 
				
			||||||
 | 
					  sparse_table.data().squeeze();
 | 
				
			||||||
 | 
					  return sparse_table;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					Eigen::SparseVector<double> TableFactor::Convert(const std::string& table) {
 | 
				
			||||||
 | 
					  // Convert string to doubles.
 | 
				
			||||||
 | 
					  std::vector<double> ys;
 | 
				
			||||||
 | 
					  std::istringstream iss(table);
 | 
				
			||||||
 | 
					  std::copy(std::istream_iterator<double>(iss), std::istream_iterator<double>(),
 | 
				
			||||||
 | 
					            std::back_inserter(ys));
 | 
				
			||||||
 | 
					  return Convert(ys);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					bool TableFactor::equals(const DiscreteFactor& other, double tol) const {
 | 
				
			||||||
 | 
					  if (!dynamic_cast<const TableFactor*>(&other)) {
 | 
				
			||||||
 | 
					    return false;
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    const auto& f(static_cast<const TableFactor&>(other));
 | 
				
			||||||
 | 
					    return sparse_table_.isApprox(f.sparse_table_, tol);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					double TableFactor::operator()(const DiscreteValues& values) const {
 | 
				
			||||||
 | 
					  // a b c d => D * (C * (B * (a) + b) + c) + d
 | 
				
			||||||
 | 
					  uint64_t idx = 0, card = 1;
 | 
				
			||||||
 | 
					  for (auto it = sorted_dkeys_.rbegin(); it != sorted_dkeys_.rend(); ++it) {
 | 
				
			||||||
 | 
					    if (values.find(it->first) != values.end()) {
 | 
				
			||||||
 | 
					      idx += card * values.at(it->first);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    card *= it->second;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return sparse_table_.coeff(idx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					double TableFactor::findValue(const DiscreteValues& values) const {
 | 
				
			||||||
 | 
					  // a b c d => D * (C * (B * (a) + b) + c) + d
 | 
				
			||||||
 | 
					  uint64_t idx = 0, card = 1;
 | 
				
			||||||
 | 
					  for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
 | 
				
			||||||
 | 
					    if (values.find(*it) != values.end()) {
 | 
				
			||||||
 | 
					      idx += card * values.at(*it);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    card *= cardinality(*it);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return sparse_table_.coeff(idx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					double TableFactor::error(const DiscreteValues& values) const {
 | 
				
			||||||
 | 
					  return -log(evaluate(values));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					double TableFactor::error(const HybridValues& values) const {
 | 
				
			||||||
 | 
					  return error(values.discrete());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
 | 
				
			||||||
 | 
					  return toDecisionTreeFactor() * f;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
 | 
				
			||||||
 | 
					  DiscreteKeys dkeys = discreteKeys();
 | 
				
			||||||
 | 
					  std::vector<double> table;
 | 
				
			||||||
 | 
					  for (auto i = 0; i < sparse_table_.size(); i++) {
 | 
				
			||||||
 | 
					    table.push_back(sparse_table_.coeff(i));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  DecisionTreeFactor f(dkeys, table);
 | 
				
			||||||
 | 
					  return f;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					TableFactor TableFactor::choose(const DiscreteValues parent_assign,
 | 
				
			||||||
 | 
					                                DiscreteKeys parent_keys) const {
 | 
				
			||||||
 | 
					  if (parent_keys.empty()) return *this;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Unique representation of parent values.
 | 
				
			||||||
 | 
					  uint64_t unique = 0;
 | 
				
			||||||
 | 
					  uint64_t card = 1;
 | 
				
			||||||
 | 
					  for (auto it = keys_.rbegin(); it != keys_.rend(); ++it) {
 | 
				
			||||||
 | 
					    if (parent_assign.find(*it) != parent_assign.end()) {
 | 
				
			||||||
 | 
					      unique += parent_assign.at(*it) * card;
 | 
				
			||||||
 | 
					      card *= cardinality(*it);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Find child DiscreteKeys
 | 
				
			||||||
 | 
					  DiscreteKeys child_dkeys;
 | 
				
			||||||
 | 
					  std::sort(parent_keys.begin(), parent_keys.end());
 | 
				
			||||||
 | 
					  std::set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
 | 
				
			||||||
 | 
					                      parent_keys.begin(), parent_keys.end(),
 | 
				
			||||||
 | 
					                      std::back_inserter(child_dkeys));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Create child sparse table to populate.
 | 
				
			||||||
 | 
					  uint64_t child_card = 1;
 | 
				
			||||||
 | 
					  for (const DiscreteKey& child_dkey : child_dkeys)
 | 
				
			||||||
 | 
					    child_card *= child_dkey.second;
 | 
				
			||||||
 | 
					  Eigen::SparseVector<double> child_sparse_table_(child_card);
 | 
				
			||||||
 | 
					  child_sparse_table_.reserve(child_card);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Populate child sparse table.
 | 
				
			||||||
 | 
					  for (SparseIt it(sparse_table_); it; ++it) {
 | 
				
			||||||
 | 
					    // Create unique representation of parent keys
 | 
				
			||||||
 | 
					    uint64_t parent_unique = uniqueRep(parent_keys, it.index());
 | 
				
			||||||
 | 
					    // Populate the table
 | 
				
			||||||
 | 
					    if (parent_unique == unique) {
 | 
				
			||||||
 | 
					      uint64_t idx = uniqueRep(child_dkeys, it.index());
 | 
				
			||||||
 | 
					      child_sparse_table_.insert(idx) = it.value();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  child_sparse_table_.pruned();
 | 
				
			||||||
 | 
					  child_sparse_table_.data().squeeze();
 | 
				
			||||||
 | 
					  return TableFactor(child_dkeys, child_sparse_table_);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					double TableFactor::safe_div(const double& a, const double& b) {
 | 
				
			||||||
 | 
					  // The use for safe_div is when we divide the product factor by the sum
 | 
				
			||||||
 | 
					  // factor. If the product or sum is zero, we accord zero probability to the
 | 
				
			||||||
 | 
					  // event.
 | 
				
			||||||
 | 
					  return (a == 0 || b == 0) ? 0 : (a / b);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
 | 
				
			||||||
 | 
					  cout << s;
 | 
				
			||||||
 | 
					  cout << " f[";
 | 
				
			||||||
 | 
					  for (auto&& key : keys())
 | 
				
			||||||
 | 
					    cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
 | 
				
			||||||
 | 
					  cout << " ]" << endl;
 | 
				
			||||||
 | 
					  for (SparseIt it(sparse_table_); it; ++it) {
 | 
				
			||||||
 | 
					    DiscreteValues assignment = findAssignments(it.index());
 | 
				
			||||||
 | 
					    for (auto&& kv : assignment) {
 | 
				
			||||||
 | 
					      cout << "(" << formatter(kv.first) << ", " << kv.second << ")";
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    cout << " | " << it.value() << " | " << it.index() << endl;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
 | 
				
			||||||
 | 
					  if (keys_.empty() && sparse_table_.nonZeros() == 0)
 | 
				
			||||||
 | 
					    return f;
 | 
				
			||||||
 | 
					  else if (f.keys_.empty() && f.sparse_table_.nonZeros() == 0)
 | 
				
			||||||
 | 
					    return *this;
 | 
				
			||||||
 | 
					  // 1. Identify keys for contract and free modes.
 | 
				
			||||||
 | 
					  DiscreteKeys contract_dkeys = contractDkeys(f);
 | 
				
			||||||
 | 
					  DiscreteKeys f_free_dkeys = f.freeDkeys(*this);
 | 
				
			||||||
 | 
					  DiscreteKeys union_dkeys = unionDkeys(f);
 | 
				
			||||||
 | 
					  // 2. Create hash table for input factor f
 | 
				
			||||||
 | 
					  unordered_map<uint64_t, AssignValList> map_f =
 | 
				
			||||||
 | 
					      f.createMap(contract_dkeys, f_free_dkeys);
 | 
				
			||||||
 | 
					  // 3. Initialize multiplied factor.
 | 
				
			||||||
 | 
					  uint64_t card = 1;
 | 
				
			||||||
 | 
					  for (auto u_dkey : union_dkeys) card *= u_dkey.second;
 | 
				
			||||||
 | 
					  Eigen::SparseVector<double> mult_sparse_table(card);
 | 
				
			||||||
 | 
					  mult_sparse_table.reserve(card);
 | 
				
			||||||
 | 
					  // 3. Multiply.
 | 
				
			||||||
 | 
					  for (SparseIt it(sparse_table_); it; ++it) {
 | 
				
			||||||
 | 
					    uint64_t contract_unique = uniqueRep(contract_dkeys, it.index());
 | 
				
			||||||
 | 
					    if (map_f.find(contract_unique) == map_f.end()) continue;
 | 
				
			||||||
 | 
					    for (auto assignVal : map_f[contract_unique]) {
 | 
				
			||||||
 | 
					      uint64_t union_idx = unionRep(union_dkeys, assignVal.first, it.index());
 | 
				
			||||||
 | 
					      mult_sparse_table.insert(union_idx) = op(it.value(), assignVal.second);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // 4. Free unused memory.
 | 
				
			||||||
 | 
					  mult_sparse_table.pruned();
 | 
				
			||||||
 | 
					  mult_sparse_table.data().squeeze();
 | 
				
			||||||
 | 
					  // 5. Create union keys and return.
 | 
				
			||||||
 | 
					  return TableFactor(union_dkeys, mult_sparse_table);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					DiscreteKeys TableFactor::contractDkeys(const TableFactor& f) const {
 | 
				
			||||||
 | 
					  // Find contract modes.
 | 
				
			||||||
 | 
					  DiscreteKeys contract;
 | 
				
			||||||
 | 
					  set_intersection(sorted_dkeys_.begin(), sorted_dkeys_.end(),
 | 
				
			||||||
 | 
					                   f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
 | 
				
			||||||
 | 
					                   back_inserter(contract));
 | 
				
			||||||
 | 
					  return contract;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					DiscreteKeys TableFactor::freeDkeys(const TableFactor& f) const {
 | 
				
			||||||
 | 
					  // Find free modes.
 | 
				
			||||||
 | 
					  DiscreteKeys free;
 | 
				
			||||||
 | 
					  set_difference(sorted_dkeys_.begin(), sorted_dkeys_.end(),
 | 
				
			||||||
 | 
					                 f.sorted_dkeys_.begin(), f.sorted_dkeys_.end(),
 | 
				
			||||||
 | 
					                 back_inserter(free));
 | 
				
			||||||
 | 
					  return free;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					DiscreteKeys TableFactor::unionDkeys(const TableFactor& f) const {
 | 
				
			||||||
 | 
					  // Find union modes.
 | 
				
			||||||
 | 
					  DiscreteKeys union_dkeys;
 | 
				
			||||||
 | 
					  set_union(sorted_dkeys_.begin(), sorted_dkeys_.end(), f.sorted_dkeys_.begin(),
 | 
				
			||||||
 | 
					            f.sorted_dkeys_.end(), back_inserter(union_dkeys));
 | 
				
			||||||
 | 
					  return union_dkeys;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					uint64_t TableFactor::unionRep(const DiscreteKeys& union_keys,
 | 
				
			||||||
 | 
					                               const DiscreteValues& f_free,
 | 
				
			||||||
 | 
					                               const uint64_t idx) const {
 | 
				
			||||||
 | 
					  uint64_t union_idx = 0, card = 1;
 | 
				
			||||||
 | 
					  for (auto it = union_keys.rbegin(); it != union_keys.rend(); it++) {
 | 
				
			||||||
 | 
					    if (f_free.find(it->first) == f_free.end()) {
 | 
				
			||||||
 | 
					      union_idx += keyValueForIndex(it->first, idx) * card;
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      union_idx += f_free.at(it->first) * card;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    card *= it->second;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return union_idx;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					unordered_map<uint64_t, TableFactor::AssignValList> TableFactor::createMap(
 | 
				
			||||||
 | 
					    const DiscreteKeys& contract, const DiscreteKeys& free) const {
 | 
				
			||||||
 | 
					  // 1. Initialize map.
 | 
				
			||||||
 | 
					  unordered_map<uint64_t, AssignValList> map_f;
 | 
				
			||||||
 | 
					  // 2. Iterate over nonzero elements.
 | 
				
			||||||
 | 
					  for (SparseIt it(sparse_table_); it; ++it) {
 | 
				
			||||||
 | 
					    // 3. Create unique representation of contract modes.
 | 
				
			||||||
 | 
					    uint64_t unique_rep = uniqueRep(contract, it.index());
 | 
				
			||||||
 | 
					    // 4. Create assignment for free modes.
 | 
				
			||||||
 | 
					    DiscreteValues free_assignments;
 | 
				
			||||||
 | 
					    for (auto& key : free)
 | 
				
			||||||
 | 
					      free_assignments[key.first] = keyValueForIndex(key.first, it.index());
 | 
				
			||||||
 | 
					    // 5. Populate map.
 | 
				
			||||||
 | 
					    if (map_f.find(unique_rep) == map_f.end()) {
 | 
				
			||||||
 | 
					      map_f[unique_rep] = {make_pair(free_assignments, it.value())};
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      map_f[unique_rep].push_back(make_pair(free_assignments, it.value()));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return map_f;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					uint64_t TableFactor::uniqueRep(const DiscreteKeys& dkeys,
 | 
				
			||||||
 | 
					                                const uint64_t idx) const {
 | 
				
			||||||
 | 
					  if (dkeys.empty()) return 0;
 | 
				
			||||||
 | 
					  uint64_t unique_rep = 0, card = 1;
 | 
				
			||||||
 | 
					  for (auto it = dkeys.rbegin(); it != dkeys.rend(); it++) {
 | 
				
			||||||
 | 
					    unique_rep += keyValueForIndex(it->first, idx) * card;
 | 
				
			||||||
 | 
					    card *= it->second;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return unique_rep;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					uint64_t TableFactor::uniqueRep(const DiscreteValues& assignments) const {
 | 
				
			||||||
 | 
					  if (assignments.empty()) return 0;
 | 
				
			||||||
 | 
					  uint64_t unique_rep = 0, card = 1;
 | 
				
			||||||
 | 
					  for (auto it = assignments.rbegin(); it != assignments.rend(); it++) {
 | 
				
			||||||
 | 
					    unique_rep += it->second * card;
 | 
				
			||||||
 | 
					    card *= cardinalities_.at(it->first);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return unique_rep;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					DiscreteValues TableFactor::findAssignments(const uint64_t idx) const {
 | 
				
			||||||
 | 
					  DiscreteValues assignment;
 | 
				
			||||||
 | 
					  for (Key key : keys_) {
 | 
				
			||||||
 | 
					    assignment[key] = keyValueForIndex(key, idx);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return assignment;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					TableFactor::shared_ptr TableFactor::combine(size_t nrFrontals,
 | 
				
			||||||
 | 
					                                             Binary op) const {
 | 
				
			||||||
 | 
					  if (nrFrontals > size()) {
 | 
				
			||||||
 | 
					    throw invalid_argument(
 | 
				
			||||||
 | 
					        "TableFactor::combine: invalid number of frontal "
 | 
				
			||||||
 | 
					        "keys " +
 | 
				
			||||||
 | 
					        to_string(nrFrontals) + ", nr.keys=" + std::to_string(size()));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // Find remaining keys.
 | 
				
			||||||
 | 
					  DiscreteKeys remain_dkeys;
 | 
				
			||||||
 | 
					  uint64_t card = 1;
 | 
				
			||||||
 | 
					  for (auto i = nrFrontals; i < keys_.size(); i++) {
 | 
				
			||||||
 | 
					    remain_dkeys.push_back(discreteKey(i));
 | 
				
			||||||
 | 
					    card *= cardinality(keys_[i]);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // Create combined table.
 | 
				
			||||||
 | 
					  Eigen::SparseVector<double> combined_table(card);
 | 
				
			||||||
 | 
					  combined_table.reserve(sparse_table_.nonZeros());
 | 
				
			||||||
 | 
					  // Populate combined table.
 | 
				
			||||||
 | 
					  for (SparseIt it(sparse_table_); it; ++it) {
 | 
				
			||||||
 | 
					    uint64_t idx = uniqueRep(remain_dkeys, it.index());
 | 
				
			||||||
 | 
					    double new_val = op(combined_table.coeff(idx), it.value());
 | 
				
			||||||
 | 
					    combined_table.coeffRef(idx) = new_val;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // Free unused memory.
 | 
				
			||||||
 | 
					  combined_table.pruned();
 | 
				
			||||||
 | 
					  combined_table.data().squeeze();
 | 
				
			||||||
 | 
					  return std::make_shared<TableFactor>(remain_dkeys, combined_table);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					TableFactor::shared_ptr TableFactor::combine(const Ordering& frontalKeys,
 | 
				
			||||||
 | 
					                                             Binary op) const {
 | 
				
			||||||
 | 
					  if (frontalKeys.size() > size()) {
 | 
				
			||||||
 | 
					    throw invalid_argument(
 | 
				
			||||||
 | 
					        "TableFactor::combine: invalid number of frontal "
 | 
				
			||||||
 | 
					        "keys " +
 | 
				
			||||||
 | 
					        std::to_string(frontalKeys.size()) +
 | 
				
			||||||
 | 
					        ", nr.keys=" + std::to_string(size()));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // Find remaining keys.
 | 
				
			||||||
 | 
					  DiscreteKeys remain_dkeys;
 | 
				
			||||||
 | 
					  uint64_t card = 1;
 | 
				
			||||||
 | 
					  for (Key key : keys_) {
 | 
				
			||||||
 | 
					    if (std::find(frontalKeys.begin(), frontalKeys.end(), key) ==
 | 
				
			||||||
 | 
					        frontalKeys.end()) {
 | 
				
			||||||
 | 
					      remain_dkeys.emplace_back(key, cardinality(key));
 | 
				
			||||||
 | 
					      card *= cardinality(key);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // Create combined table.
 | 
				
			||||||
 | 
					  Eigen::SparseVector<double> combined_table(card);
 | 
				
			||||||
 | 
					  combined_table.reserve(sparse_table_.nonZeros());
 | 
				
			||||||
 | 
					  // Populate combined table.
 | 
				
			||||||
 | 
					  for (SparseIt it(sparse_table_); it; ++it) {
 | 
				
			||||||
 | 
					    uint64_t idx = uniqueRep(remain_dkeys, it.index());
 | 
				
			||||||
 | 
					    double new_val = op(combined_table.coeff(idx), it.value());
 | 
				
			||||||
 | 
					    combined_table.coeffRef(idx) = new_val;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // Free unused memory.
 | 
				
			||||||
 | 
					  combined_table.pruned();
 | 
				
			||||||
 | 
					  combined_table.data().squeeze();
 | 
				
			||||||
 | 
					  return std::make_shared<TableFactor>(remain_dkeys, combined_table);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					size_t TableFactor::keyValueForIndex(Key target_key, uint64_t index) const {
 | 
				
			||||||
 | 
					  // http://phrogz.net/lazy-cartesian-product
 | 
				
			||||||
 | 
					  return (index / denominators_.at(target_key)) % cardinality(target_key);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					std::vector<std::pair<DiscreteValues, double>> TableFactor::enumerate() const {
 | 
				
			||||||
 | 
					  // Get all possible assignments
 | 
				
			||||||
 | 
					  std::vector<std::pair<Key, size_t>> pairs = discreteKeys();
 | 
				
			||||||
 | 
					  // Reverse to make cartesian product output a more natural ordering.
 | 
				
			||||||
 | 
					  std::vector<std::pair<Key, size_t>> rpairs(pairs.rbegin(), pairs.rend());
 | 
				
			||||||
 | 
					  const auto assignments = DiscreteValues::CartesianProduct(rpairs);
 | 
				
			||||||
 | 
					  // Construct unordered_map with values
 | 
				
			||||||
 | 
					  std::vector<std::pair<DiscreteValues, double>> result;
 | 
				
			||||||
 | 
					  for (const auto& assignment : assignments) {
 | 
				
			||||||
 | 
					    result.emplace_back(assignment, operator()(assignment));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					DiscreteKeys TableFactor::discreteKeys() const {
 | 
				
			||||||
 | 
					  DiscreteKeys result;
 | 
				
			||||||
 | 
					  for (auto&& key : keys()) {
 | 
				
			||||||
 | 
					    DiscreteKey dkey(key, cardinality(key));
 | 
				
			||||||
 | 
					    if (std::find(result.begin(), result.end(), dkey) == result.end()) {
 | 
				
			||||||
 | 
					      result.push_back(dkey);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Print out header.
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					string TableFactor::markdown(const KeyFormatter& keyFormatter,
 | 
				
			||||||
 | 
					                             const Names& names) const {
 | 
				
			||||||
 | 
					  stringstream ss;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Print out header.
 | 
				
			||||||
 | 
					  ss << "|";
 | 
				
			||||||
 | 
					  for (auto& key : keys()) {
 | 
				
			||||||
 | 
					    ss << keyFormatter(key) << "|";
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  ss << "value|\n";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Print out separator with alignment hints.
 | 
				
			||||||
 | 
					  ss << "|";
 | 
				
			||||||
 | 
					  for (size_t j = 0; j < size(); j++) ss << ":-:|";
 | 
				
			||||||
 | 
					  ss << ":-:|\n";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Print out all rows.
 | 
				
			||||||
 | 
					  for (SparseIt it(sparse_table_); it; ++it) {
 | 
				
			||||||
 | 
					    DiscreteValues assignment = findAssignments(it.index());
 | 
				
			||||||
 | 
					    ss << "|";
 | 
				
			||||||
 | 
					    for (auto& key : keys()) {
 | 
				
			||||||
 | 
					      size_t index = assignment.at(key);
 | 
				
			||||||
 | 
					      ss << DiscreteValues::Translate(names, key, index) << "|";
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    ss << it.value() << "|\n";
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return ss.str();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					string TableFactor::html(const KeyFormatter& keyFormatter,
 | 
				
			||||||
 | 
					                         const Names& names) const {
 | 
				
			||||||
 | 
					  stringstream ss;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Print out preamble.
 | 
				
			||||||
 | 
					  ss << "<div>\n<table class='TableFactor'>\n  <thead>\n";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Print out header row.
 | 
				
			||||||
 | 
					  ss << "    <tr>";
 | 
				
			||||||
 | 
					  for (auto& key : keys()) {
 | 
				
			||||||
 | 
					    ss << "<th>" << keyFormatter(key) << "</th>";
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  ss << "<th>value</th></tr>\n";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Finish header and start body.
 | 
				
			||||||
 | 
					  ss << "  </thead>\n  <tbody>\n";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Print out all rows.
 | 
				
			||||||
 | 
					  for (SparseIt it(sparse_table_); it; ++it) {
 | 
				
			||||||
 | 
					    DiscreteValues assignment = findAssignments(it.index());
 | 
				
			||||||
 | 
					    ss << "    <tr>";
 | 
				
			||||||
 | 
					    for (auto& key : keys()) {
 | 
				
			||||||
 | 
					      size_t index = assignment.at(key);
 | 
				
			||||||
 | 
					      ss << "<th>" << DiscreteValues::Translate(names, key, index) << "</th>";
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    ss << "<td>" << it.value() << "</td>";  // value
 | 
				
			||||||
 | 
					    ss << "</tr>\n";
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  ss << "  </tbody>\n</table>\n</div>";
 | 
				
			||||||
 | 
					  return ss.str();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					TableFactor TableFactor::prune(size_t maxNrAssignments) const {
 | 
				
			||||||
 | 
					  const size_t N = maxNrAssignments;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Get the probabilities in the TableFactor so we can threshold.
 | 
				
			||||||
 | 
					  vector<pair<Eigen::Index, double>> probabilities;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Store non-zero probabilities along with their indices in a vector.
 | 
				
			||||||
 | 
					  for (SparseIt it(sparse_table_); it; ++it) {
 | 
				
			||||||
 | 
					    probabilities.emplace_back(it.index(), it.value());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // The number of probabilities can be lower than max_leaves.
 | 
				
			||||||
 | 
					  if (probabilities.size() <= N) return *this;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Sort the vector in descending order based on the element values.
 | 
				
			||||||
 | 
					  sort(probabilities.begin(), probabilities.end(),
 | 
				
			||||||
 | 
					       [](const std::pair<Eigen::Index, double>& a,
 | 
				
			||||||
 | 
					          const std::pair<Eigen::Index, double>& b) {
 | 
				
			||||||
 | 
					         return a.second > b.second;
 | 
				
			||||||
 | 
					       });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Keep the largest N probabilities in the vector.
 | 
				
			||||||
 | 
					  if (probabilities.size() > N) probabilities.resize(N);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Create pruned sparse vector.
 | 
				
			||||||
 | 
					  Eigen::SparseVector<double> pruned_vec(sparse_table_.size());
 | 
				
			||||||
 | 
					  pruned_vec.reserve(probabilities.size());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Populate pruned sparse vector.
 | 
				
			||||||
 | 
					  for (const auto& prob : probabilities) {
 | 
				
			||||||
 | 
					    pruned_vec.insert(prob.first) = prob.second;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Create pruned decision tree factor and return.
 | 
				
			||||||
 | 
					  return TableFactor(this->discreteKeys(), pruned_vec);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************ */
 | 
				
			||||||
 | 
					}  // namespace gtsam
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,340 @@
 | 
				
			||||||
 | 
					/* ----------------------------------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | 
				
			||||||
 | 
					 * Atlanta, Georgia 30332-0415
 | 
				
			||||||
 | 
					 * All Rights Reserved
 | 
				
			||||||
 | 
					 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 * See LICENSE for the license information
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 * -------------------------------------------------------------------------- */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * @file TableFactor.h
 | 
				
			||||||
 | 
					 * @date May 4, 2023
 | 
				
			||||||
 | 
					 * @author Yoonwoo Kim
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <gtsam/discrete/DiscreteFactor.h>
 | 
				
			||||||
 | 
					#include <gtsam/discrete/DiscreteKey.h>
 | 
				
			||||||
 | 
					#include <gtsam/inference/Ordering.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <Eigen/Sparse>
 | 
				
			||||||
 | 
					#include <algorithm>
 | 
				
			||||||
 | 
					#include <map>
 | 
				
			||||||
 | 
					#include <memory>
 | 
				
			||||||
 | 
					#include <stdexcept>
 | 
				
			||||||
 | 
					#include <string>
 | 
				
			||||||
 | 
					#include <utility>
 | 
				
			||||||
 | 
					#include <vector>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					namespace gtsam {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HybridValues;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/**
 | 
				
			||||||
 | 
					 * A discrete probabilistic factor optimized for sparsity.
 | 
				
			||||||
 | 
					 * Uses sparse_table_ to store only the nonzero probabilities.
 | 
				
			||||||
 | 
					 * Computes the assigned value for the key using the ordering which the
 | 
				
			||||||
 | 
					 * nonzero probabilties are stored in. (lazy cartesian product)
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 * @ingroup discrete
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					class GTSAM_EXPORT TableFactor : public DiscreteFactor {
 | 
				
			||||||
 | 
					 protected:
 | 
				
			||||||
 | 
					  /// Map of Keys and their cardinalities.
 | 
				
			||||||
 | 
					  std::map<Key, size_t> cardinalities_;
 | 
				
			||||||
 | 
					  /// SparseVector of nonzero probabilities.
 | 
				
			||||||
 | 
					  Eigen::SparseVector<double> sparse_table_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 private:
 | 
				
			||||||
 | 
					  /// Map of Keys and their denominators used in keyValueForIndex.
 | 
				
			||||||
 | 
					  std::map<Key, size_t> denominators_;
 | 
				
			||||||
 | 
					  /// Sorted DiscreteKeys to use internally.
 | 
				
			||||||
 | 
					  DiscreteKeys sorted_dkeys_;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * @brief Uses lazy cartesian product to find nth entry in the cartesian
 | 
				
			||||||
 | 
					   * product of arrays in O(1) 
 | 
				
			||||||
 | 
					   * Example) 
 | 
				
			||||||
 | 
					   *   v0 | v1 | val 
 | 
				
			||||||
 | 
					   *    0 |  0 |  10 
 | 
				
			||||||
 | 
					   *    0 |  1 |  21
 | 
				
			||||||
 | 
					   *    1 |  0 |  32
 | 
				
			||||||
 | 
					   *    1 |  1 |  43
 | 
				
			||||||
 | 
					   *   keyValueForIndex(v1, 2) = 0
 | 
				
			||||||
 | 
					   * @param target_key nth entry's key to find out its assigned value
 | 
				
			||||||
 | 
					   * @param index nth entry in the sparse vector
 | 
				
			||||||
 | 
					   * @return TableFactor
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  size_t keyValueForIndex(Key target_key, uint64_t index) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * @brief Return ith key in keys_ as a DiscreteKey
 | 
				
			||||||
 | 
					   * @param i ith key in keys_
 | 
				
			||||||
 | 
					   * @return DiscreteKey
 | 
				
			||||||
 | 
					   * */ 
 | 
				
			||||||
 | 
					  DiscreteKey discreteKey(size_t i) const {
 | 
				
			||||||
 | 
					    return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Convert probability table given as doubles to SparseVector.
 | 
				
			||||||
 | 
					  /// Example) {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5} 
 | 
				
			||||||
 | 
					  static Eigen::SparseVector<double> Convert(const std::vector<double>& table);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Convert probability table given as string to SparseVector.
 | 
				
			||||||
 | 
					  static Eigen::SparseVector<double> Convert(const std::string& table);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  // typedefs needed to play nice with gtsam
 | 
				
			||||||
 | 
					  typedef TableFactor This;
 | 
				
			||||||
 | 
					  typedef DiscreteFactor Base;  ///< Typedef to base class
 | 
				
			||||||
 | 
					  typedef std::shared_ptr<TableFactor> shared_ptr;
 | 
				
			||||||
 | 
					  typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
 | 
				
			||||||
 | 
					  typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
 | 
				
			||||||
 | 
					  using Binary = std::function<double(const double, const double)>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 public:
 | 
				
			||||||
 | 
					  /** The Real ring with addition and multiplication */
 | 
				
			||||||
 | 
					  struct Ring {
 | 
				
			||||||
 | 
					    static inline double zero() { return 0.0; }
 | 
				
			||||||
 | 
					    static inline double one() { return 1.0; }
 | 
				
			||||||
 | 
					    static inline double add(const double& a, const double& b) { return a + b; }
 | 
				
			||||||
 | 
					    static inline double max(const double& a, const double& b) {
 | 
				
			||||||
 | 
					      return std::max(a, b);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    static inline double mul(const double& a, const double& b) { return a * b; }
 | 
				
			||||||
 | 
					    static inline double div(const double& a, const double& b) {
 | 
				
			||||||
 | 
					      return (a == 0 || b == 0) ? 0 : (a / b);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    static inline double id(const double& x) { return x; }
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// @name Standard Constructors
 | 
				
			||||||
 | 
					  /// @{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Default constructor for I/O */
 | 
				
			||||||
 | 
					  TableFactor();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Constructor from DiscreteKeys and TableFactor */
 | 
				
			||||||
 | 
					  TableFactor(const DiscreteKeys& keys, const TableFactor& potentials);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Constructor from sparse_table */
 | 
				
			||||||
 | 
					  TableFactor(const DiscreteKeys& keys,
 | 
				
			||||||
 | 
					              const Eigen::SparseVector<double>& table);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Constructor from doubles */
 | 
				
			||||||
 | 
					  TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
 | 
				
			||||||
 | 
					      : TableFactor(keys, Convert(table)) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /** Constructor from string */
 | 
				
			||||||
 | 
					  TableFactor(const DiscreteKeys& keys, const std::string& table)
 | 
				
			||||||
 | 
					      : TableFactor(keys, Convert(table)) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Single-key specialization
 | 
				
			||||||
 | 
					  template <class SOURCE>
 | 
				
			||||||
 | 
					  TableFactor(const DiscreteKey& key, SOURCE table)
 | 
				
			||||||
 | 
					      : TableFactor(DiscreteKeys{key}, table) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Single-key specialization, with vector of doubles.
 | 
				
			||||||
 | 
					  TableFactor(const DiscreteKey& key, const std::vector<double>& row)
 | 
				
			||||||
 | 
					      : TableFactor(DiscreteKeys{key}, row) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// @}
 | 
				
			||||||
 | 
					  /// @name Testable
 | 
				
			||||||
 | 
					  /// @{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// equality
 | 
				
			||||||
 | 
					  bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // print
 | 
				
			||||||
 | 
					  void print(
 | 
				
			||||||
 | 
					      const std::string& s = "TableFactor:\n",
 | 
				
			||||||
 | 
					      const KeyFormatter& formatter = DefaultKeyFormatter) const override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // /// @}
 | 
				
			||||||
 | 
					  // /// @name Standard Interface
 | 
				
			||||||
 | 
					  // /// @{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Calculate probability for given values `x`,
 | 
				
			||||||
 | 
					  /// is just look up in TableFactor.
 | 
				
			||||||
 | 
					  double evaluate(const DiscreteValues& values) const {
 | 
				
			||||||
 | 
					    return operator()(values);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Evaluate probability distribution, sugar.
 | 
				
			||||||
 | 
					  double operator()(const DiscreteValues& values) const override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Calculate error for DiscreteValues `x`, is -log(probability).
 | 
				
			||||||
 | 
					  double error(const DiscreteValues& values) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// multiply two TableFactors
 | 
				
			||||||
 | 
					  TableFactor operator*(const TableFactor& f) const {
 | 
				
			||||||
 | 
					    return apply(f, Ring::mul);
 | 
				
			||||||
 | 
					  };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// multiple with DecisionTreeFactor
 | 
				
			||||||
 | 
					  DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static double safe_div(const double& a, const double& b);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  size_t cardinality(Key j) const { return cardinalities_.at(j); }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// divide by factor f (safely)
 | 
				
			||||||
 | 
					  TableFactor operator/(const TableFactor& f) const {
 | 
				
			||||||
 | 
					    return apply(f, safe_div);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Convert into a decisiontree
 | 
				
			||||||
 | 
					  DecisionTreeFactor toDecisionTreeFactor() const override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Create a TableFactor that is a subset of this TableFactor
 | 
				
			||||||
 | 
					  TableFactor choose(const DiscreteValues assignments,
 | 
				
			||||||
 | 
					                     DiscreteKeys parent_keys) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Create new factor by summing all values with the same separator values
 | 
				
			||||||
 | 
					  shared_ptr sum(size_t nrFrontals) const {
 | 
				
			||||||
 | 
					    return combine(nrFrontals, Ring::add);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Create new factor by summing all values with the same separator values
 | 
				
			||||||
 | 
					  shared_ptr sum(const Ordering& keys) const {
 | 
				
			||||||
 | 
					    return combine(keys, Ring::add);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Create new factor by maximizing over all values with the same separator.
 | 
				
			||||||
 | 
					  shared_ptr max(size_t nrFrontals) const {
 | 
				
			||||||
 | 
					    return combine(nrFrontals, Ring::max);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Create new factor by maximizing over all values with the same separator.
 | 
				
			||||||
 | 
					  shared_ptr max(const Ordering& keys) const {
 | 
				
			||||||
 | 
					    return combine(keys, Ring::max);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// @}
 | 
				
			||||||
 | 
					  /// @name Advanced Interface
 | 
				
			||||||
 | 
					  /// @{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Apply binary operator (*this) "op" f
 | 
				
			||||||
 | 
					   * @param f the second argument for op
 | 
				
			||||||
 | 
					   * @param op a binary operator that operates on TableFactor
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  TableFactor apply(const TableFactor& f, Binary op) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Return keys in contract mode.
 | 
				
			||||||
 | 
					  DiscreteKeys contractDkeys(const TableFactor& f) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Return keys in free mode.
 | 
				
			||||||
 | 
					  DiscreteKeys freeDkeys(const TableFactor& f) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Return union of DiscreteKeys in two factors.
 | 
				
			||||||
 | 
					  DiscreteKeys unionDkeys(const TableFactor& f) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Create unique representation of union modes.
 | 
				
			||||||
 | 
					  uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign,
 | 
				
			||||||
 | 
					                    const uint64_t idx) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Create a hash map of input factor with assignment of contract modes as
 | 
				
			||||||
 | 
					  /// keys and vector of hashed assignment of free modes and value as values.
 | 
				
			||||||
 | 
					  std::unordered_map<uint64_t, AssignValList> createMap(
 | 
				
			||||||
 | 
					      const DiscreteKeys& contract, const DiscreteKeys& free) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Create unique representation
 | 
				
			||||||
 | 
					  uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Create unique representation with DiscreteValues
 | 
				
			||||||
 | 
					  uint64_t uniqueRep(const DiscreteValues& assignments) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Find DiscreteValues for corresponding index.
 | 
				
			||||||
 | 
					  DiscreteValues findAssignments(const uint64_t idx) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Find value for corresponding DiscreteValues.
 | 
				
			||||||
 | 
					  double findValue(const DiscreteValues& values) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Combine frontal variables using binary operator "op"
 | 
				
			||||||
 | 
					   * @param nrFrontals nr. of frontal to combine variables in this factor
 | 
				
			||||||
 | 
					   * @param op a binary operator that operates on TableFactor
 | 
				
			||||||
 | 
					   * @return shared pointer to newly created TableFactor
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  shared_ptr combine(size_t nrFrontals, Binary op) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Combine frontal variables in an Ordering using binary operator "op"
 | 
				
			||||||
 | 
					   * @param nrFrontals nr. of frontal to combine variables in this factor
 | 
				
			||||||
 | 
					   * @param op a binary operator that operates on TableFactor
 | 
				
			||||||
 | 
					   * @return shared pointer to newly created TableFactor
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  shared_ptr combine(const Ordering& keys, Binary op) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Enumerate all values into a map from values to double.
 | 
				
			||||||
 | 
					  std::vector<std::pair<DiscreteValues, double>> enumerate() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Return all the discrete keys associated with this factor.
 | 
				
			||||||
 | 
					  DiscreteKeys discreteKeys() const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * @brief Prune the decision tree of discrete variables.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * Pruning will set the values to be "pruned" to 0 indicating a 0
 | 
				
			||||||
 | 
					   * probability. An assignment is pruned if it is not in the top
 | 
				
			||||||
 | 
					   * `maxNrAssignments` values.
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * A violation can occur if there are more
 | 
				
			||||||
 | 
					   * duplicate values than `maxNrAssignments`. A violation here is the need to
 | 
				
			||||||
 | 
					   * un-prune the decision tree (e.g. all assignment values are 1.0). We could
 | 
				
			||||||
 | 
					   * have another case where some subset of duplicates exist (e.g. for a tree
 | 
				
			||||||
 | 
					   * with 8 assignments we have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is
 | 
				
			||||||
 | 
					   * not a violation since the for `maxNrAssignments=5` the top values are (1,
 | 
				
			||||||
 | 
					   * 0.8).
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param maxNrAssignments The maximum number of assignments to keep.
 | 
				
			||||||
 | 
					   * @return TableFactor
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  TableFactor prune(size_t maxNrAssignments) const;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// @}
 | 
				
			||||||
 | 
					  /// @name Wrapper support
 | 
				
			||||||
 | 
					  /// @{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * @brief Render as markdown table
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param keyFormatter GTSAM-style Key formatter.
 | 
				
			||||||
 | 
					   * @param names optional, category names corresponding to choices.
 | 
				
			||||||
 | 
					   * @return std::string a markdown string.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
 | 
				
			||||||
 | 
					                       const Names& names = {}) const override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * @brief Render as html table
 | 
				
			||||||
 | 
					   *
 | 
				
			||||||
 | 
					   * @param keyFormatter GTSAM-style Key formatter.
 | 
				
			||||||
 | 
					   * @param names optional, category names corresponding to choices.
 | 
				
			||||||
 | 
					   * @return std::string a html string.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
 | 
				
			||||||
 | 
					                   const Names& names = {}) const override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// @}
 | 
				
			||||||
 | 
					  /// @name HybridValues methods.
 | 
				
			||||||
 | 
					  /// @{
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /**
 | 
				
			||||||
 | 
					   * Calculate error for HybridValues `x`, is -log(probability)
 | 
				
			||||||
 | 
					   * Simply dispatches to DiscreteValues version.
 | 
				
			||||||
 | 
					   */
 | 
				
			||||||
 | 
					  double error(const HybridValues& values) const override;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// @}
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// traits
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					struct traits<TableFactor> : public Testable<TableFactor> {};
 | 
				
			||||||
 | 
					}  // namespace gtsam
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,360 @@
 | 
				
			||||||
 | 
					/* ----------------------------------------------------------------------------
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | 
				
			||||||
 | 
					 * Atlanta, Georgia 30332-0415
 | 
				
			||||||
 | 
					 * All Rights Reserved
 | 
				
			||||||
 | 
					 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 * See LICENSE for the license information
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 * -------------------------------------------------------------------------- */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/*
 | 
				
			||||||
 | 
					 * testTableFactor.cpp
 | 
				
			||||||
 | 
					 *
 | 
				
			||||||
 | 
					 *  @date Feb 15, 2023
 | 
				
			||||||
 | 
					 *  @author Yoonwoo Kim
 | 
				
			||||||
 | 
					 */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <CppUnitLite/TestHarness.h>
 | 
				
			||||||
 | 
					#include <gtsam/base/Testable.h>
 | 
				
			||||||
 | 
					#include <gtsam/base/serializationTestHelpers.h>
 | 
				
			||||||
 | 
					#include <gtsam/discrete/DiscreteDistribution.h>
 | 
				
			||||||
 | 
					#include <gtsam/discrete/Signature.h>
 | 
				
			||||||
 | 
					#include <gtsam/discrete/TableFactor.h>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <chrono>
 | 
				
			||||||
 | 
					#include <random>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace std;
 | 
				
			||||||
 | 
					using namespace gtsam;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					vector<double> genArr(double dropout, size_t size) {
 | 
				
			||||||
 | 
					  random_device rd;
 | 
				
			||||||
 | 
					  mt19937 g(rd());
 | 
				
			||||||
 | 
					  vector<double> dropoutmask(size);  // Chance of 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  uniform_int_distribution<> dist(1, 9);
 | 
				
			||||||
 | 
					  auto gen = [&dist, &g]() { return dist(g); };
 | 
				
			||||||
 | 
					  generate(dropoutmask.begin(), dropoutmask.end(), gen);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  fill_n(dropoutmask.begin(), dropoutmask.size() * (dropout), 0);
 | 
				
			||||||
 | 
					  shuffle(dropoutmask.begin(), dropoutmask.end(), g);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return dropoutmask;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					map<double, pair<chrono::microseconds, chrono::microseconds>> measureTime(
 | 
				
			||||||
 | 
					    DiscreteKeys keys1, DiscreteKeys keys2, size_t size) {
 | 
				
			||||||
 | 
					  vector<double> dropouts = {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> measured_times;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for (auto dropout : dropouts) {
 | 
				
			||||||
 | 
					    vector<double> arr1 = genArr(dropout, size);
 | 
				
			||||||
 | 
					    vector<double> arr2 = genArr(dropout, size);
 | 
				
			||||||
 | 
					    TableFactor f1(keys1, arr1);
 | 
				
			||||||
 | 
					    TableFactor f2(keys2, arr2);
 | 
				
			||||||
 | 
					    DecisionTreeFactor f1_dt(keys1, arr1);
 | 
				
			||||||
 | 
					    DecisionTreeFactor f2_dt(keys2, arr2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // measure time TableFactor
 | 
				
			||||||
 | 
					    auto tb_start = chrono::high_resolution_clock::now();
 | 
				
			||||||
 | 
					    TableFactor actual = f1 * f2;
 | 
				
			||||||
 | 
					    auto tb_end = chrono::high_resolution_clock::now();
 | 
				
			||||||
 | 
					    auto tb_time_diff =
 | 
				
			||||||
 | 
					        chrono::duration_cast<chrono::microseconds>(tb_end - tb_start);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // measure time DT
 | 
				
			||||||
 | 
					    auto dt_start = chrono::high_resolution_clock::now();
 | 
				
			||||||
 | 
					    DecisionTreeFactor actual_dt = f1_dt * f2_dt;
 | 
				
			||||||
 | 
					    auto dt_end = chrono::high_resolution_clock::now();
 | 
				
			||||||
 | 
					    auto dt_time_diff =
 | 
				
			||||||
 | 
					        chrono::duration_cast<chrono::microseconds>(dt_end - dt_start);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bool flag = true;
 | 
				
			||||||
 | 
					    for (auto assignmentVal : actual_dt.enumerate()) {
 | 
				
			||||||
 | 
					      flag = actual_dt(assignmentVal.first) != actual(assignmentVal.first);
 | 
				
			||||||
 | 
					      if (flag) {
 | 
				
			||||||
 | 
					        std::cout << "something is wrong: " << std::endl;
 | 
				
			||||||
 | 
					        assignmentVal.first.print();
 | 
				
			||||||
 | 
					        std::cout << "dt: " << actual_dt(assignmentVal.first) << std::endl;
 | 
				
			||||||
 | 
					        std::cout << "tb: " << actual(assignmentVal.first) << std::endl;
 | 
				
			||||||
 | 
					        break;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (flag) break;
 | 
				
			||||||
 | 
					    measured_times[dropout] = make_pair(tb_time_diff, dt_time_diff);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return measured_times;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
 | 
				
			||||||
 | 
					                   measured_time) {
 | 
				
			||||||
 | 
					  for (auto&& kv : measured_time) {
 | 
				
			||||||
 | 
					    cout << "dropout: " << kv.first
 | 
				
			||||||
 | 
					         << " | TableFactor time: " << kv.second.first.count()
 | 
				
			||||||
 | 
					         << " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Check constructors for TableFactor.
 | 
				
			||||||
 | 
					TEST(TableFactor, constructors) {
 | 
				
			||||||
 | 
					  // Declare a bunch of keys
 | 
				
			||||||
 | 
					  DiscreteKey X(0, 2), Y(1, 3), Z(2, 2), A(3, 5);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Create factors
 | 
				
			||||||
 | 
					  TableFactor f_zeros(A, {0, 0, 0, 0, 1});
 | 
				
			||||||
 | 
					  TableFactor f1(X, {2, 8});
 | 
				
			||||||
 | 
					  TableFactor f2(X & Y, "2 5 3 6 4 7");
 | 
				
			||||||
 | 
					  TableFactor f3(X & Y & Z, "2 5 3 6 4 7 25 55 35 65 45 75");
 | 
				
			||||||
 | 
					  EXPECT_LONGS_EQUAL(1, f1.size());
 | 
				
			||||||
 | 
					  EXPECT_LONGS_EQUAL(2, f2.size());
 | 
				
			||||||
 | 
					  EXPECT_LONGS_EQUAL(3, f3.size());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  DiscreteValues values;
 | 
				
			||||||
 | 
					  values[0] = 1;  // x
 | 
				
			||||||
 | 
					  values[1] = 2;  // y
 | 
				
			||||||
 | 
					  values[2] = 1;  // z
 | 
				
			||||||
 | 
					  values[3] = 4;  // a
 | 
				
			||||||
 | 
					  EXPECT_DOUBLES_EQUAL(1, f_zeros(values), 1e-9);
 | 
				
			||||||
 | 
					  EXPECT_DOUBLES_EQUAL(8, f1(values), 1e-9);
 | 
				
			||||||
 | 
					  EXPECT_DOUBLES_EQUAL(7, f2(values), 1e-9);
 | 
				
			||||||
 | 
					  EXPECT_DOUBLES_EQUAL(75, f3(values), 1e-9);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Assert that error = -log(value)
 | 
				
			||||||
 | 
					  EXPECT_DOUBLES_EQUAL(-log(f1(values)), f1.error(values), 1e-9);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Check multiplication between two TableFactors.
 | 
				
			||||||
 | 
					TEST(TableFactor, multiplication) {
 | 
				
			||||||
 | 
					  DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Multiply with a DiscreteDistribution, i.e., Bayes Law!
 | 
				
			||||||
 | 
					  DiscreteDistribution prior(v1 % "1/3");
 | 
				
			||||||
 | 
					  TableFactor f1(v0 & v1, "1 2 3 4");
 | 
				
			||||||
 | 
					  DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3");
 | 
				
			||||||
 | 
					  CHECK(assert_equal(expected, static_cast<DecisionTreeFactor>(prior) *
 | 
				
			||||||
 | 
					                                   f1.toDecisionTreeFactor()));
 | 
				
			||||||
 | 
					  CHECK(assert_equal(expected, f1 * prior));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Multiply two factors
 | 
				
			||||||
 | 
					  TableFactor f2(v1 & v2, "5 6 7 8");
 | 
				
			||||||
 | 
					  TableFactor actual = f1 * f2;
 | 
				
			||||||
 | 
					  TableFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32");
 | 
				
			||||||
 | 
					  CHECK(assert_equal(expected2, actual));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  DiscreteKey A(0, 3), B(1, 2), C(2, 2);
 | 
				
			||||||
 | 
					  TableFactor f_zeros1(A & C, "0 0 0 2 0 3");
 | 
				
			||||||
 | 
					  TableFactor f_zeros2(B & C, "4 0 0 5");
 | 
				
			||||||
 | 
					  TableFactor actual_zeros = f_zeros1 * f_zeros2;
 | 
				
			||||||
 | 
					  TableFactor expected3(A & B & C, "0 0 0 0 0 0 0 10 0 0 0 15");
 | 
				
			||||||
 | 
					  CHECK(assert_equal(expected3, actual_zeros));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Benchmark which compares runtime of multiplication of two TableFactors
 | 
				
			||||||
 | 
					// and two DecisionTreeFactors given sparsity from dense to 90% sparsity.
 | 
				
			||||||
 | 
					TEST(TableFactor, benchmark) {
 | 
				
			||||||
 | 
					  DiscreteKey A(0, 5), B(1, 2), C(2, 5), D(3, 2), E(4, 5), F(5, 2), G(6, 3),
 | 
				
			||||||
 | 
					      H(7, 2), I(8, 5), J(9, 7), K(10, 2), L(11, 3);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // 100
 | 
				
			||||||
 | 
					  DiscreteKeys one_1 = {A, B, C, D};
 | 
				
			||||||
 | 
					  DiscreteKeys one_2 = {C, D, E, F};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_1 =
 | 
				
			||||||
 | 
					      measureTime(one_1, one_2, 100);
 | 
				
			||||||
 | 
					  printTime(time_map_1);
 | 
				
			||||||
 | 
					  // 200
 | 
				
			||||||
 | 
					  DiscreteKeys two_1 = {A, B, C, D, F};
 | 
				
			||||||
 | 
					  DiscreteKeys two_2 = {B, C, D, E, F};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_2 =
 | 
				
			||||||
 | 
					      measureTime(two_1, two_2, 200);
 | 
				
			||||||
 | 
					  printTime(time_map_2);
 | 
				
			||||||
 | 
					  // 300
 | 
				
			||||||
 | 
					  DiscreteKeys three_1 = {A, B, C, D, G};
 | 
				
			||||||
 | 
					  DiscreteKeys three_2 = {C, D, E, F, G};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_3 =
 | 
				
			||||||
 | 
					      measureTime(three_1, three_2, 300);
 | 
				
			||||||
 | 
					  printTime(time_map_3);
 | 
				
			||||||
 | 
					  // 400
 | 
				
			||||||
 | 
					  DiscreteKeys four_1 = {A, B, C, D, F, H};
 | 
				
			||||||
 | 
					  DiscreteKeys four_2 = {B, C, D, E, F, H};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_4 =
 | 
				
			||||||
 | 
					      measureTime(four_1, four_2, 400);
 | 
				
			||||||
 | 
					  printTime(time_map_4);
 | 
				
			||||||
 | 
					  // 500
 | 
				
			||||||
 | 
					  DiscreteKeys five_1 = {A, B, C, D, I};
 | 
				
			||||||
 | 
					  DiscreteKeys five_2 = {C, D, E, F, I};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_5 =
 | 
				
			||||||
 | 
					      measureTime(five_1, five_2, 500);
 | 
				
			||||||
 | 
					  printTime(time_map_5);
 | 
				
			||||||
 | 
					  // 600
 | 
				
			||||||
 | 
					  DiscreteKeys six_1 = {A, B, C, D, F, G};
 | 
				
			||||||
 | 
					  DiscreteKeys six_2 = {B, C, D, E, F, G};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_6 =
 | 
				
			||||||
 | 
					      measureTime(six_1, six_2, 600);
 | 
				
			||||||
 | 
					  printTime(time_map_6);
 | 
				
			||||||
 | 
					  // 700
 | 
				
			||||||
 | 
					  DiscreteKeys seven_1 = {A, B, C, D, J};
 | 
				
			||||||
 | 
					  DiscreteKeys seven_2 = {C, D, E, F, J};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_7 =
 | 
				
			||||||
 | 
					      measureTime(seven_1, seven_2, 700);
 | 
				
			||||||
 | 
					  printTime(time_map_7);
 | 
				
			||||||
 | 
					  // 800
 | 
				
			||||||
 | 
					  DiscreteKeys eight_1 = {A, B, C, D, F, H, K};
 | 
				
			||||||
 | 
					  DiscreteKeys eight_2 = {B, C, D, E, F, H, K};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_8 =
 | 
				
			||||||
 | 
					      measureTime(eight_1, eight_2, 800);
 | 
				
			||||||
 | 
					  printTime(time_map_8);
 | 
				
			||||||
 | 
					  // 900
 | 
				
			||||||
 | 
					  DiscreteKeys nine_1 = {A, B, C, D, G, L};
 | 
				
			||||||
 | 
					  DiscreteKeys nine_2 = {C, D, E, F, G, L};
 | 
				
			||||||
 | 
					  map<double, pair<chrono::microseconds, chrono::microseconds>> time_map_9 =
 | 
				
			||||||
 | 
					      measureTime(nine_1, nine_2, 900);
 | 
				
			||||||
 | 
					  printTime(time_map_9);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Check sum and max over frontals.
 | 
				
			||||||
 | 
					TEST(TableFactor, sum_max) {
 | 
				
			||||||
 | 
					  DiscreteKey v0(0, 3), v1(1, 2);
 | 
				
			||||||
 | 
					  TableFactor f1(v0 & v1, "1 2  3 4  5 6");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  TableFactor expected(v1, "9 12");
 | 
				
			||||||
 | 
					  TableFactor::shared_ptr actual = f1.sum(1);
 | 
				
			||||||
 | 
					  CHECK(assert_equal(expected, *actual, 1e-5));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  TableFactor expected2(v1, "5 6");
 | 
				
			||||||
 | 
					  TableFactor::shared_ptr actual2 = f1.max(1);
 | 
				
			||||||
 | 
					  CHECK(assert_equal(expected2, *actual2));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  TableFactor f2(v1 & v0, "1 2  3 4  5 6");
 | 
				
			||||||
 | 
					  TableFactor::shared_ptr actual22 = f2.sum(1);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Check enumerate yields the correct list of assignment/value pairs.
 | 
				
			||||||
 | 
					TEST(TableFactor, enumerate) {
 | 
				
			||||||
 | 
					  DiscreteKey A(12, 3), B(5, 2);
 | 
				
			||||||
 | 
					  TableFactor f(A & B, "1 2  3 4  5 6");
 | 
				
			||||||
 | 
					  auto actual = f.enumerate();
 | 
				
			||||||
 | 
					  std::vector<std::pair<DiscreteValues, double>> expected;
 | 
				
			||||||
 | 
					  DiscreteValues values;
 | 
				
			||||||
 | 
					  for (size_t a : {0, 1, 2}) {
 | 
				
			||||||
 | 
					    for (size_t b : {0, 1}) {
 | 
				
			||||||
 | 
					      values[12] = a;
 | 
				
			||||||
 | 
					      values[5] = b;
 | 
				
			||||||
 | 
					      expected.emplace_back(values, f(values));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  EXPECT(actual == expected);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Check pruning of the decision tree works as expected.
 | 
				
			||||||
 | 
					TEST(TableFactor, Prune) {
 | 
				
			||||||
 | 
					  DiscreteKey A(1, 2), B(2, 2), C(3, 2);
 | 
				
			||||||
 | 
					  TableFactor f(A & B & C, "1 5 3 7 2 6 4 8");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Only keep the leaves with the top 5 values.
 | 
				
			||||||
 | 
					  size_t maxNrAssignments = 5;
 | 
				
			||||||
 | 
					  auto pruned5 = f.prune(maxNrAssignments);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Pruned leaves should be 0
 | 
				
			||||||
 | 
					  TableFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expected, pruned5));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Check for more extreme pruning where we only keep the top 2 leaves
 | 
				
			||||||
 | 
					  maxNrAssignments = 2;
 | 
				
			||||||
 | 
					  auto pruned2 = f.prune(maxNrAssignments);
 | 
				
			||||||
 | 
					  TableFactor expected2(A & B & C, "0 0 0 7 0 0 0 8");
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expected2, pruned2));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  DiscreteKey D(4, 2);
 | 
				
			||||||
 | 
					  TableFactor factor(
 | 
				
			||||||
 | 
					      D & C & B & A,
 | 
				
			||||||
 | 
					      "0.0 0.0 0.0 0.60658897 0.61241912 0.61241969 0.61247685 0.61247742 0.0 "
 | 
				
			||||||
 | 
					      "0.0 0.0 0.99995287 1.0 1.0 1.0 1.0");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  TableFactor expected3(D & C & B & A,
 | 
				
			||||||
 | 
					                        "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 "
 | 
				
			||||||
 | 
					                        "0.999952870000 1.0 1.0 1.0 1.0");
 | 
				
			||||||
 | 
					  maxNrAssignments = 5;
 | 
				
			||||||
 | 
					  auto pruned3 = factor.prune(maxNrAssignments);
 | 
				
			||||||
 | 
					  EXPECT(assert_equal(expected3, pruned3));
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Check markdown representation looks as expected.
 | 
				
			||||||
 | 
					TEST(TableFactor, markdown) {
 | 
				
			||||||
 | 
					  DiscreteKey A(12, 3), B(5, 2);
 | 
				
			||||||
 | 
					  TableFactor f(A & B, "1 2  3 4  5 6");
 | 
				
			||||||
 | 
					  string expected =
 | 
				
			||||||
 | 
					      "|A|B|value|\n"
 | 
				
			||||||
 | 
					      "|:-:|:-:|:-:|\n"
 | 
				
			||||||
 | 
					      "|0|0|1|\n"
 | 
				
			||||||
 | 
					      "|0|1|2|\n"
 | 
				
			||||||
 | 
					      "|1|0|3|\n"
 | 
				
			||||||
 | 
					      "|1|1|4|\n"
 | 
				
			||||||
 | 
					      "|2|0|5|\n"
 | 
				
			||||||
 | 
					      "|2|1|6|\n";
 | 
				
			||||||
 | 
					  auto formatter = [](Key key) { return key == 12 ? "A" : "B"; };
 | 
				
			||||||
 | 
					  string actual = f.markdown(formatter);
 | 
				
			||||||
 | 
					  EXPECT(actual == expected);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Check markdown representation with a value formatter.
 | 
				
			||||||
 | 
					TEST(TableFactor, markdownWithValueFormatter) {
 | 
				
			||||||
 | 
					  DiscreteKey A(12, 3), B(5, 2);
 | 
				
			||||||
 | 
					  TableFactor f(A & B, "1 2  3 4  5 6");
 | 
				
			||||||
 | 
					  string expected =
 | 
				
			||||||
 | 
					      "|A|B|value|\n"
 | 
				
			||||||
 | 
					      "|:-:|:-:|:-:|\n"
 | 
				
			||||||
 | 
					      "|Zero|-|1|\n"
 | 
				
			||||||
 | 
					      "|Zero|+|2|\n"
 | 
				
			||||||
 | 
					      "|One|-|3|\n"
 | 
				
			||||||
 | 
					      "|One|+|4|\n"
 | 
				
			||||||
 | 
					      "|Two|-|5|\n"
 | 
				
			||||||
 | 
					      "|Two|+|6|\n";
 | 
				
			||||||
 | 
					  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
 | 
				
			||||||
 | 
					  TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
 | 
				
			||||||
 | 
					  string actual = f.markdown(keyFormatter, names);
 | 
				
			||||||
 | 
					  EXPECT(actual == expected);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					// Check html representation with a value formatter.
 | 
				
			||||||
 | 
					TEST(TableFactor, htmlWithValueFormatter) {
 | 
				
			||||||
 | 
					  DiscreteKey A(12, 3), B(5, 2);
 | 
				
			||||||
 | 
					  TableFactor f(A & B, "1 2  3 4  5 6");
 | 
				
			||||||
 | 
					  string expected =
 | 
				
			||||||
 | 
					      "<div>\n"
 | 
				
			||||||
 | 
					      "<table class='TableFactor'>\n"
 | 
				
			||||||
 | 
					      "  <thead>\n"
 | 
				
			||||||
 | 
					      "    <tr><th>A</th><th>B</th><th>value</th></tr>\n"
 | 
				
			||||||
 | 
					      "  </thead>\n"
 | 
				
			||||||
 | 
					      "  <tbody>\n"
 | 
				
			||||||
 | 
					      "    <tr><th>Zero</th><th>-</th><td>1</td></tr>\n"
 | 
				
			||||||
 | 
					      "    <tr><th>Zero</th><th>+</th><td>2</td></tr>\n"
 | 
				
			||||||
 | 
					      "    <tr><th>One</th><th>-</th><td>3</td></tr>\n"
 | 
				
			||||||
 | 
					      "    <tr><th>One</th><th>+</th><td>4</td></tr>\n"
 | 
				
			||||||
 | 
					      "    <tr><th>Two</th><th>-</th><td>5</td></tr>\n"
 | 
				
			||||||
 | 
					      "    <tr><th>Two</th><th>+</th><td>6</td></tr>\n"
 | 
				
			||||||
 | 
					      "  </tbody>\n"
 | 
				
			||||||
 | 
					      "</table>\n"
 | 
				
			||||||
 | 
					      "</div>";
 | 
				
			||||||
 | 
					  auto keyFormatter = [](Key key) { return key == 12 ? "A" : "B"; };
 | 
				
			||||||
 | 
					  TableFactor::Names names{{12, {"Zero", "One", "Two"}}, {5, {"-", "+"}}};
 | 
				
			||||||
 | 
					  string actual = f.html(keyFormatter, names);
 | 
				
			||||||
 | 
					  EXPECT(actual == expected);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
 | 
					int main() {
 | 
				
			||||||
 | 
					  TestResult tr;
 | 
				
			||||||
 | 
					  return TestRegistry::runAllTests(tr);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					/* ************************************************************************* */
 | 
				
			||||||
| 
						 | 
					@ -105,7 +105,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc
 | 
				
			||||||
      set(mexModuleExt mexglx)
 | 
					      set(mexModuleExt mexglx)
 | 
				
			||||||
    endif()
 | 
					    endif()
 | 
				
			||||||
  elseif(APPLE)
 | 
					  elseif(APPLE)
 | 
				
			||||||
    set(mexModuleExt mexmaci64)
 | 
					    check_cxx_compiler_flag("-arch arm64" arm64Supported)
 | 
				
			||||||
 | 
					    if (arm64Supported)
 | 
				
			||||||
 | 
					      set(mexModuleExt mexmaca64)
 | 
				
			||||||
 | 
					    else()
 | 
				
			||||||
 | 
					      set(mexModuleExt mexmaci64)
 | 
				
			||||||
 | 
					    endif()
 | 
				
			||||||
  elseif(MSVC)
 | 
					  elseif(MSVC)
 | 
				
			||||||
    if(CMAKE_CL_64)
 | 
					    if(CMAKE_CL_64)
 | 
				
			||||||
      set(mexModuleExt mexw64)
 | 
					      set(mexModuleExt mexw64)
 | 
				
			||||||
| 
						 | 
					@ -299,7 +304,12 @@ function(wrap_library_internal interfaceHeader moduleName linkLibraries extraInc
 | 
				
			||||||
      APPEND
 | 
					      APPEND
 | 
				
			||||||
      PROPERTY COMPILE_FLAGS "/bigobj")
 | 
					      PROPERTY COMPILE_FLAGS "/bigobj")
 | 
				
			||||||
  elseif(APPLE)
 | 
					  elseif(APPLE)
 | 
				
			||||||
    set(mxLibPath "${MATLAB_ROOT}/bin/maci64")
 | 
					    check_cxx_compiler_flag("-arch arm64" arm64Supported)
 | 
				
			||||||
 | 
					    if (arm64Supported)
 | 
				
			||||||
 | 
					      set(mxLibPath "${MATLAB_ROOT}/bin/maca64")
 | 
				
			||||||
 | 
					    else()
 | 
				
			||||||
 | 
					      set(mxLibPath "${MATLAB_ROOT}/bin/maci64")
 | 
				
			||||||
 | 
					    endif()
 | 
				
			||||||
    target_link_libraries(
 | 
					    target_link_libraries(
 | 
				
			||||||
      ${moduleName}_matlab_wrapper "${mxLibPath}/libmex.dylib"
 | 
					      ${moduleName}_matlab_wrapper "${mxLibPath}/libmex.dylib"
 | 
				
			||||||
      "${mxLibPath}/libmx.dylib" "${mxLibPath}/libmat.dylib")
 | 
					      "${mxLibPath}/libmx.dylib" "${mxLibPath}/libmat.dylib")
 | 
				
			||||||
| 
						 | 
					@ -367,7 +377,12 @@ function(check_conflicting_libraries_internal libraries)
 | 
				
			||||||
  if(UNIX)
 | 
					  if(UNIX)
 | 
				
			||||||
    # Set path for matlab's built-in libraries
 | 
					    # Set path for matlab's built-in libraries
 | 
				
			||||||
    if(APPLE)
 | 
					    if(APPLE)
 | 
				
			||||||
      set(mxLibPath "${MATLAB_ROOT}/bin/maci64")
 | 
					      check_cxx_compiler_flag("-arch arm64" arm64Supported)
 | 
				
			||||||
 | 
					      if (arm64Supported)
 | 
				
			||||||
 | 
					        set(mxLibPath "${MATLAB_ROOT}/bin/maca64")
 | 
				
			||||||
 | 
					      else()
 | 
				
			||||||
 | 
					        set(mxLibPath "${MATLAB_ROOT}/bin/maci64")
 | 
				
			||||||
 | 
					      endif()
 | 
				
			||||||
    else()
 | 
					    else()
 | 
				
			||||||
      if(CMAKE_CL_64)
 | 
					      if(CMAKE_CL_64)
 | 
				
			||||||
        set(mxLibPath "${MATLAB_ROOT}/bin/glnxa64")
 | 
					        set(mxLibPath "${MATLAB_ROOT}/bin/glnxa64")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue