Now using new WeightedSampler class

release/4.3a0
Frank Dellaert 2019-06-15 18:43:35 -04:00 committed by Frank Dellaert
parent f8af4a465d
commit 50e484a3c6
1 changed files with 13 additions and 78 deletions

View File

@ -17,6 +17,7 @@
#include <gtsam/base/DSFVector.h>
#include <gtsam/base/FastMap.h>
#include <gtsam/base/WeightedSampler.h>
#include <gtsam/inference/Ordering.h>
#include <gtsam/inference/VariableIndex.h>
#include <gtsam/linear/Errors.h>
@ -35,8 +36,11 @@
#include <iostream>
#include <numeric> // accumulate
#include <queue>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
using std::cout;
@ -68,81 +72,11 @@ static vector<size_t> sort_idx(const Container &src) {
return idx;
}
/*****************************************************************************/
static vector<size_t> iidSampler(const vector<double> &weight, const size_t n) {
/* compute the sum of the weights */
const double sum = std::accumulate(weight.begin(), weight.end(), 0.0);
if (sum == 0.0) {
throw std::runtime_error("Weight vector has no non-zero weights");
}
/* make a normalized and accumulated version of the weight vector */
const size_t m = weight.size();
vector<double> cdf;
cdf.reserve(m);
for (size_t i = 0; i < m; ++i) {
cdf.push_back(weight[i] / sum);
}
vector<double> acc(m);
std::partial_sum(cdf.begin(), cdf.end(), acc.begin());
/* iid sample n times */
vector<size_t> samples;
samples.reserve(n);
const double denominator = (double)RAND_MAX;
for (size_t i = 0; i < n; ++i) {
const double value = rand() / denominator;
/* binary search the interval containing "value" */
const auto it = std::lower_bound(acc.begin(), acc.end(), value);
const size_t index = it - acc.begin();
samples.push_back(index);
}
return samples;
}
/*****************************************************************************/
static vector<size_t> UniqueSampler(const vector<double> &weight,
const size_t n) {
const size_t m = weight.size();
if (n > m) throw std::invalid_argument("UniqueSampler: invalid input size");
vector<size_t> results;
size_t count = 0;
vector<bool> touched(m, false);
while (count < n) {
vector<size_t> localIndices;
localIndices.reserve(n - count);
vector<double> localWeights;
localWeights.reserve(n - count);
/* collect data */
for (size_t i = 0; i < m; ++i) {
if (!touched[i]) {
localIndices.push_back(i);
localWeights.push_back(weight[i]);
}
}
/* sampling and cache results */
vector<size_t> samples = iidSampler(localWeights, n - count);
for (const size_t &index : samples) {
if (touched[index] == false) {
touched[index] = true;
results.push_back(index);
if (++count >= n) break;
}
}
}
return results;
}
/****************************************************************************/
Subgraph::Subgraph(const vector<size_t> &indices) {
edges_.reserve(indices.size());
for (const size_t &index : indices) {
const Edge e {index,1.0};
const Edge e{index, 1.0};
edges_.push_back(e);
}
}
@ -423,12 +357,13 @@ vector<size_t> SubgraphBuilder::kruskal(const GaussianFactorGraph &gfg,
/****************************************************************/
vector<size_t> SubgraphBuilder::sample(const vector<double> &weights,
const size_t t) const {
return UniqueSampler(weights, t);
std::mt19937 rng(42); // TODO(frank): allow us to use a different seed
WeightedSampler<std::mt19937> sampler(&rng);
return sampler.sampleWithoutReplacement(t, weights);
}
/****************************************************************/
Subgraph SubgraphBuilder::operator()(
const GaussianFactorGraph &gfg) const {
Subgraph SubgraphBuilder::operator()(const GaussianFactorGraph &gfg) const {
const auto &p = parameters_;
const auto inverse_ordering = Ordering::Natural(gfg);
const FastMap<Key, size_t> forward_ordering = inverse_ordering.invert();
@ -518,15 +453,15 @@ GaussianFactorGraph::shared_ptr buildFactorSubgraph(
subgraphFactors->reserve(subgraph.size());
for (const auto &e : subgraph) {
const auto factor = gfg[e.index];
subgraphFactors->push_back(clone ? factor->clone(): factor);
subgraphFactors->push_back(clone ? factor->clone() : factor);
}
return subgraphFactors;
}
/**************************************************************************************************/
std::pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr> //
splitFactorGraph(const GaussianFactorGraph &factorGraph, const Subgraph &subgraph) {
std::pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr> //
splitFactorGraph(const GaussianFactorGraph &factorGraph,
const Subgraph &subgraph) {
// Get the subgraph by calling cheaper method
auto subgraphFactors = buildFactorSubgraph(factorGraph, subgraph, false);