use a TableFactor as the underlying data representation for DiscreteTableConditional since it provides a clean abstraction

release/4.3a0
Varun Agrawal 2024-12-31 00:10:37 -05:00
parent d18f23c47b
commit 4ff70141f8
2 changed files with 11 additions and 13 deletions

View File

@ -41,23 +41,21 @@ namespace gtsam {
DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals,
const TableFactor& f)
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
sparse_table_((f / (*f.sum(nrFrontals))).sparseTable()) {
// sparse_table_ = sparse_table_.prune();
}
table_(f / (*f.sum(nrFrontals))) {}
/* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(
size_t nrFrontals, const DiscreteKeys& keys,
const Eigen::SparseVector<double>& potentials)
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
sparse_table_(potentials) {}
table_(TableFactor(keys, potentials)) {}
/* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
const TableFactor& marginal)
: BaseConditional(joint.size() - marginal.size(),
joint.discreteKeys() & marginal.discreteKeys(), ADT()),
sparse_table_((joint / marginal).sparseTable()) {}
table_(joint / marginal) {}
/* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
@ -71,8 +69,7 @@ DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
/* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const Signature& signature)
: BaseConditional(1, DecisionTreeFactor()),
sparse_table_(TableFactor(signature.discreteKeys(), signature.cpt())
.sparseTable()) {}
table_(TableFactor(signature.discreteKeys(), signature.cpt())) {}
/* ************************************************************************** */
DiscreteTableConditional DiscreteTableConditional::operator*(
@ -108,9 +105,7 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
// Finally, add parents to keys, in order
for (auto&& dk : parents) discreteKeys.push_back(dk);
TableFactor a(this->discreteKeys(), this->sparse_table_),
b(other.discreteKeys(), other.sparse_table_);
TableFactor product = a * other;
TableFactor product = this->table_ * other.table();
return DiscreteTableConditional(newFrontals.size(), product);
}
@ -128,7 +123,7 @@ void DiscreteTableConditional::print(const string& s,
}
}
cout << "):\n";
// BaseFactor::print("", formatter);
table_.print("", formatter);
cout << endl;
}

View File

@ -29,13 +29,16 @@
namespace gtsam {
/**
* Discrete Conditional Density which uses a SparseTable as the internal
* Discrete Conditional Density which uses a SparseVector as the internal
* representation, similar to the TableFactor.
*
* @ingroup discrete
*/
class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
Eigen::SparseVector<double> sparse_table_;
private:
TableFactor table_;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
public:
// typedefs needed to play nice with gtsam