use a TableFactor as the underlying data representation for DiscreteTableConditional since it provides a clean abstraction

release/4.3a0
Varun Agrawal 2024-12-31 00:10:37 -05:00
parent d18f23c47b
commit 4ff70141f8
2 changed files with 11 additions and 13 deletions

View File

@ -41,23 +41,21 @@ namespace gtsam {
DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals, DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals,
const TableFactor& f) const TableFactor& f)
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())), : BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
sparse_table_((f / (*f.sum(nrFrontals))).sparseTable()) { table_(f / (*f.sum(nrFrontals))) {}
// sparse_table_ = sparse_table_.prune();
}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional( DiscreteTableConditional::DiscreteTableConditional(
size_t nrFrontals, const DiscreteKeys& keys, size_t nrFrontals, const DiscreteKeys& keys,
const Eigen::SparseVector<double>& potentials) const Eigen::SparseVector<double>& potentials)
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())), : BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
sparse_table_(potentials) {} table_(TableFactor(keys, potentials)) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint, DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
const TableFactor& marginal) const TableFactor& marginal)
: BaseConditional(joint.size() - marginal.size(), : BaseConditional(joint.size() - marginal.size(),
joint.discreteKeys() & marginal.discreteKeys(), ADT()), joint.discreteKeys() & marginal.discreteKeys(), ADT()),
sparse_table_((joint / marginal).sparseTable()) {} table_(joint / marginal) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint, DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
@ -71,8 +69,7 @@ DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional::DiscreteTableConditional(const Signature& signature) DiscreteTableConditional::DiscreteTableConditional(const Signature& signature)
: BaseConditional(1, DecisionTreeFactor()), : BaseConditional(1, DecisionTreeFactor()),
sparse_table_(TableFactor(signature.discreteKeys(), signature.cpt()) table_(TableFactor(signature.discreteKeys(), signature.cpt())) {}
.sparseTable()) {}
/* ************************************************************************** */ /* ************************************************************************** */
DiscreteTableConditional DiscreteTableConditional::operator*( DiscreteTableConditional DiscreteTableConditional::operator*(
@ -108,9 +105,7 @@ DiscreteTableConditional DiscreteTableConditional::operator*(
// Finally, add parents to keys, in order // Finally, add parents to keys, in order
for (auto&& dk : parents) discreteKeys.push_back(dk); for (auto&& dk : parents) discreteKeys.push_back(dk);
TableFactor a(this->discreteKeys(), this->sparse_table_), TableFactor product = this->table_ * other.table();
b(other.discreteKeys(), other.sparse_table_);
TableFactor product = a * other;
return DiscreteTableConditional(newFrontals.size(), product); return DiscreteTableConditional(newFrontals.size(), product);
} }
@ -128,7 +123,7 @@ void DiscreteTableConditional::print(const string& s,
} }
} }
cout << "):\n"; cout << "):\n";
// BaseFactor::print("", formatter); table_.print("", formatter);
cout << endl; cout << endl;
} }

View File

@ -29,13 +29,16 @@
namespace gtsam { namespace gtsam {
/** /**
* Discrete Conditional Density which uses a SparseTable as the internal * Discrete Conditional Density which uses a SparseVector as the internal
* representation, similar to the TableFactor. * representation, similar to the TableFactor.
* *
* @ingroup discrete * @ingroup discrete
*/ */
class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional { class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
Eigen::SparseVector<double> sparse_table_; private:
TableFactor table_;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
public: public:
// typedefs needed to play nice with gtsam // typedefs needed to play nice with gtsam