use denominators to compute the correct index in ComputeSparseTable

release/4.3a0
Varun Agrawal 2025-01-02 15:18:22 -05:00
parent 3fca55acc3
commit ff93c8be29
1 changed files with 12 additions and 6 deletions

View File

@ -89,6 +89,14 @@ static Eigen::SparseVector<double> ComputeSparseTable(
KeySet allKeys(dt.keys().begin(), dt.keys().end()); KeySet allKeys(dt.keys().begin(), dt.keys().end());
// Compute denominators to be used in computing sparse table indices
std::map<Key, size_t> denominators;
double denom = sparseTable.size();
for (const DiscreteKey& dkey : dkeys) {
denom /= dkey.second;
denominators.insert(std::pair<Key, double>(dkey.first, denom));
}
/** /**
* @brief Functor which is called by the DecisionTree for each leaf. * @brief Functor which is called by the DecisionTree for each leaf.
* For each leaf value, we use the corresponding assignment to compute a * For each leaf value, we use the corresponding assignment to compute a
@ -127,12 +135,10 @@ static Eigen::SparseVector<double> ComputeSparseTable(
// Generate index and add to the sparse vector. // Generate index and add to the sparse vector.
Eigen::Index idx = 0; Eigen::Index idx = 0;
size_t previousCardinality = 1;
// We go in reverse since a DecisionTree has the highest label first // We go in reverse since a DecisionTree has the highest label first
for (auto&& it = updatedAssignment.rbegin(); for (auto&& it = updatedAssignment.rbegin();
it != updatedAssignment.rend(); it++) { it != updatedAssignment.rend(); it++) {
idx += previousCardinality * it->second; idx += it->second * denominators.at(it->first);
previousCardinality *= dt.cardinality(it->first);
} }
sparseTable.coeffRef(idx) = p; sparseTable.coeffRef(idx) = p;
} }
@ -252,9 +258,9 @@ DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const {
DecisionTreeFactor TableFactor::toDecisionTreeFactor() const { DecisionTreeFactor TableFactor::toDecisionTreeFactor() const {
DiscreteKeys dkeys = discreteKeys(); DiscreteKeys dkeys = discreteKeys();
std::vector<double> table; std::vector<double> table(sparse_table_.size(), 0.0);
for (auto i = 0; i < sparse_table_.size(); i++) { for (SparseIt it(sparse_table_); it; ++it) {
table.push_back(sparse_table_.coeff(i)); table[it.index()] = it.value();
} }
AlgebraicDecisionTree<Key> tree(dkeys, table); AlgebraicDecisionTree<Key> tree(dkeys, table);