/* ----------------------------------------------------------------------------

 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)

 * See LICENSE for the license information

 * -------------------------------------------------------------------------- */
|
|
|
|
|
|
|
|
|
|
/**
 * @file DiscreteFactor.h
 * @date Feb 14, 2011
 * @author Duy-Nguyen Ta
 * @author Frank Dellaert
 */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
2021-12-14 02:46:53 +08:00
|
|
|
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/base/Testable.h>

#include <map>
#include <memory>
#include <string>
#include <vector>
|
2012-04-16 06:35:28 +08:00
|
|
|
namespace gtsam {
|
|
|
|
|
|
2013-10-11 05:59:49 +08:00
|
|
|
// Forward declarations: this header only needs these types to be declared,
// not fully defined, so their headers are not included here.
class DecisionTreeFactor;   // used as parameter/return type below
class DiscreteConditional;
class HybridValues;         // used by the error(const HybridValues&) overload
|
|
|
|
|
|
2013-10-11 05:59:49 +08:00
|
|
|
/**
|
|
|
|
|
* Base class for discrete probabilistic factors
|
|
|
|
|
* The most general one is the derived DecisionTreeFactor
|
2022-07-27 04:55:28 +08:00
|
|
|
*
|
|
|
|
|
* @ingroup discrete
|
2013-10-11 05:59:49 +08:00
|
|
|
*/
|
|
|
|
|
class GTSAM_EXPORT DiscreteFactor: public Factor {
|
2023-06-29 06:24:12 +08:00
|
|
|
public:
|
2013-10-11 05:59:49 +08:00
|
|
|
// typedefs needed to play nice with gtsam
|
2023-06-29 06:24:12 +08:00
|
|
|
typedef DiscreteFactor This; ///< This class
|
|
|
|
|
typedef std::shared_ptr<DiscreteFactor>
|
|
|
|
|
shared_ptr; ///< shared_ptr to this class
|
|
|
|
|
typedef Factor Base; ///< Our base class
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2023-06-29 06:24:12 +08:00
|
|
|
using Values = DiscreteValues; ///< backwards compatibility
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2023-07-08 23:33:01 +08:00
|
|
|
protected:
|
|
|
|
|
/// Map of Keys and their cardinalities.
|
|
|
|
|
std::map<Key, size_t> cardinalities_;
|
|
|
|
|
|
2023-06-29 06:24:12 +08:00
|
|
|
public:
|
2013-10-11 05:59:49 +08:00
|
|
|
/// @name Standard Constructors
|
|
|
|
|
/// @{
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2013-10-11 05:59:49 +08:00
|
|
|
/** Default constructor creates empty factor */
|
|
|
|
|
DiscreteFactor() {}
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2023-07-08 23:33:01 +08:00
|
|
|
/**
|
|
|
|
|
* Construct from container of keys and map of cardinalities.
|
|
|
|
|
* This constructor is used internally from derived factor constructors,
|
|
|
|
|
* either from a container of keys or from a boost::assign::list_of.
|
|
|
|
|
*/
|
|
|
|
|
template <typename CONTAINER>
|
|
|
|
|
DiscreteFactor(const CONTAINER& keys,
|
|
|
|
|
const std::map<Key, size_t> cardinalities = {})
|
|
|
|
|
: Base(keys), cardinalities_(cardinalities) {}
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2013-10-11 05:59:49 +08:00
|
|
|
/// @}
|
|
|
|
|
/// @name Testable
|
|
|
|
|
/// @{
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2021-04-30 07:43:27 +08:00
|
|
|
/// equals
|
2013-10-11 05:59:49 +08:00
|
|
|
virtual bool equals(const DiscreteFactor& lf, double tol = 1e-9) const = 0;
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2021-04-30 07:43:27 +08:00
|
|
|
/// print
|
2021-05-01 00:58:52 +08:00
|
|
|
void print(
|
2021-04-30 07:43:27 +08:00
|
|
|
const std::string& s = "DiscreteFactor\n",
|
|
|
|
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
|
|
|
|
|
Base::print(s, formatter);
|
2013-10-11 05:59:49 +08:00
|
|
|
}
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2013-10-11 05:59:49 +08:00
|
|
|
/// @}
|
|
|
|
|
/// @name Standard Interface
|
|
|
|
|
/// @{
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2023-07-08 23:33:01 +08:00
|
|
|
/// Return all the discrete keys associated with this factor.
|
|
|
|
|
DiscreteKeys discreteKeys() const;
|
|
|
|
|
|
|
|
|
|
std::map<Key, size_t> cardinalities() const { return cardinalities_; }
|
|
|
|
|
|
|
|
|
|
size_t cardinality(Key j) const { return cardinalities_.at(j); }
|
|
|
|
|
|
2013-10-11 05:59:49 +08:00
|
|
|
/// Find value for given assignment of values to variables
|
2021-12-14 02:46:53 +08:00
|
|
|
virtual double operator()(const DiscreteValues&) const = 0;
|
2012-10-02 22:40:07 +08:00
|
|
|
|
2023-01-09 08:37:39 +08:00
|
|
|
/// Error is just -log(value)
|
|
|
|
|
double error(const DiscreteValues& values) const;
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* The Factor::error simply extracts the \class DiscreteValues from the
|
|
|
|
|
* \class HybridValues and calculates the error.
|
|
|
|
|
*/
|
|
|
|
|
double error(const HybridValues& c) const override;
|
|
|
|
|
|
2013-10-11 05:59:49 +08:00
|
|
|
/// Multiply in a DecisionTreeFactor and return the result as DecisionTreeFactor
|
|
|
|
|
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0;
|
2013-08-06 06:30:50 +08:00
|
|
|
|
2013-10-11 05:59:49 +08:00
|
|
|
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;
|
2013-08-06 06:30:50 +08:00
|
|
|
|
2021-12-25 02:27:02 +08:00
|
|
|
/// @}
|
|
|
|
|
/// @name Wrapper support
|
|
|
|
|
/// @{
|
|
|
|
|
|
2022-01-04 00:13:32 +08:00
|
|
|
/// Translation table from values to strings.
|
2022-01-10 03:46:23 +08:00
|
|
|
using Names = DiscreteValues::Names;
|
2022-01-03 10:34:22 +08:00
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* @brief Render as markdown table
|
2022-01-10 03:46:23 +08:00
|
|
|
*
|
2022-01-03 10:34:22 +08:00
|
|
|
* @param keyFormatter GTSAM-style Key formatter.
|
|
|
|
|
* @param names optional, category names corresponding to choices.
|
|
|
|
|
* @return std::string a markdown string.
|
|
|
|
|
*/
|
2021-12-25 23:46:49 +08:00
|
|
|
virtual std::string markdown(
|
2022-01-03 10:34:22 +08:00
|
|
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
|
|
|
|
const Names& names = {}) const = 0;
|
2021-12-25 02:27:02 +08:00
|
|
|
|
2022-01-09 21:19:44 +08:00
|
|
|
/**
|
|
|
|
|
* @brief Render as html table
|
|
|
|
|
*
|
|
|
|
|
* @param keyFormatter GTSAM-style Key formatter.
|
|
|
|
|
* @param names optional, category names corresponding to choices.
|
|
|
|
|
* @return std::string a html string.
|
|
|
|
|
*/
|
|
|
|
|
virtual std::string html(
|
|
|
|
|
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
|
|
|
|
|
const Names& names = {}) const = 0;
|
|
|
|
|
|
2013-10-11 05:59:49 +08:00
|
|
|
/// @}
|
2023-06-29 06:24:12 +08:00
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
|
|
|
|
/** Serialization function */
|
|
|
|
|
friend class boost::serialization::access;
|
|
|
|
|
template <class ARCHIVE>
|
|
|
|
|
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
|
|
|
|
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
|
2023-07-08 23:33:01 +08:00
|
|
|
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
|
2023-06-29 06:24:12 +08:00
|
|
|
}
|
|
|
|
|
#endif
|
2013-10-11 05:59:49 +08:00
|
|
|
};
|
2012-04-16 06:35:28 +08:00
|
|
|
// DiscreteFactor
|
|
|
|
|
|
2014-12-22 05:02:06 +08:00
|
|
|
// traits
|
2014-12-26 23:47:51 +08:00
|
|
|
template<> struct traits<DiscreteFactor> : public Testable<DiscreteFactor> {};
|
2014-12-22 05:02:06 +08:00
|
|
|
|
2022-01-23 00:06:33 +08:00
|
|
|
|
|
|
|
|
/**
 * @brief Normalize a set of log probabilities.
 *
 * Given unnormalized log probabilities L_i, the normalized probabilities are
 *
 *   p_i = exp(L_i) / ( sum_j exp(L_j) ).
 *
 * Evaluating this directly is numerically fragile: it overflows or underflows
 * when all of the L_i are very large or very small. For any constant Z,
 *
 *   p_i = exp(Z) exp(L_i - Z) / ( exp(Z) sum_j exp(L_j - Z) )
 *       = exp(L_i - Z) / ( sum_j exp(L_j - Z) ).
 *
 * So the largest (finite) log probability, Z = max_j L_j, is subtracted from
 * each entry before exponentiating and normalizing.
 *
 * @param logProbs unnormalized log probabilities L_i.
 * @return the normalized probabilities p_i, where p_i corresponds to L_i.
 */
std::vector<double> expNormalize(const std::vector<double>& logProbs);
|
|
|
|
|
|
|
|
|
|
|
2012-04-16 06:35:28 +08:00
|
|
|
}// namespace gtsam
|