use a TableFactor as the underlying data representation for DiscreteTableConditional since it provides a clean abstraction
parent
d18f23c47b
commit
4ff70141f8
|
|
@ -41,23 +41,21 @@ namespace gtsam {
|
||||||
DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals,
|
DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals,
|
||||||
const TableFactor& f)
|
const TableFactor& f)
|
||||||
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
|
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
|
||||||
sparse_table_((f / (*f.sum(nrFrontals))).sparseTable()) {
|
table_(f / (*f.sum(nrFrontals))) {}
|
||||||
// sparse_table_ = sparse_table_.prune();
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional::DiscreteTableConditional(
|
DiscreteTableConditional::DiscreteTableConditional(
|
||||||
size_t nrFrontals, const DiscreteKeys& keys,
|
size_t nrFrontals, const DiscreteKeys& keys,
|
||||||
const Eigen::SparseVector<double>& potentials)
|
const Eigen::SparseVector<double>& potentials)
|
||||||
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
|
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
|
||||||
sparse_table_(potentials) {}
|
table_(TableFactor(keys, potentials)) {}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
||||||
const TableFactor& marginal)
|
const TableFactor& marginal)
|
||||||
: BaseConditional(joint.size() - marginal.size(),
|
: BaseConditional(joint.size() - marginal.size(),
|
||||||
joint.discreteKeys() & marginal.discreteKeys(), ADT()),
|
joint.discreteKeys() & marginal.discreteKeys(), ADT()),
|
||||||
sparse_table_((joint / marginal).sparseTable()) {}
|
table_(joint / marginal) {}
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
||||||
|
|
@ -71,8 +69,7 @@ DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional::DiscreteTableConditional(const Signature& signature)
|
DiscreteTableConditional::DiscreteTableConditional(const Signature& signature)
|
||||||
: BaseConditional(1, DecisionTreeFactor()),
|
: BaseConditional(1, DecisionTreeFactor()),
|
||||||
sparse_table_(TableFactor(signature.discreteKeys(), signature.cpt())
|
table_(TableFactor(signature.discreteKeys(), signature.cpt())) {}
|
||||||
.sparseTable()) {}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteTableConditional DiscreteTableConditional::operator*(
|
DiscreteTableConditional DiscreteTableConditional::operator*(
|
||||||
|
|
@ -108,9 +105,7 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
|
||||||
// Finally, add parents to keys, in order
|
// Finally, add parents to keys, in order
|
||||||
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
||||||
|
|
||||||
TableFactor a(this->discreteKeys(), this->sparse_table_),
|
TableFactor product = this->table_ * other.table();
|
||||||
b(other.discreteKeys(), other.sparse_table_);
|
|
||||||
TableFactor product = a * other;
|
|
||||||
return DiscreteTableConditional(newFrontals.size(), product);
|
return DiscreteTableConditional(newFrontals.size(), product);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -128,7 +123,7 @@ void DiscreteTableConditional::print(const string& s,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cout << "):\n";
|
cout << "):\n";
|
||||||
// BaseFactor::print("", formatter);
|
table_.print("", formatter);
|
||||||
cout << endl;
|
cout << endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,13 +29,16 @@
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Discrete Conditional Density which uses a SparseTable as the internal
|
* Discrete Conditional Density which uses a SparseVector as the internal
|
||||||
* representation, similar to the TableFactor.
|
* representation, similar to the TableFactor.
|
||||||
*
|
*
|
||||||
* @ingroup discrete
|
* @ingroup discrete
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
||||||
Eigen::SparseVector<double> sparse_table_;
|
private:
|
||||||
|
TableFactor table_;
|
||||||
|
|
||||||
|
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// typedefs needed to play nice with gtsam
|
// typedefs needed to play nice with gtsam
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue