Now using new WeightedSampler class
parent f8af4a465d
commit 50e484a3c6
@@ -17,6 +17,7 @@
 #include <gtsam/base/DSFVector.h>
 #include <gtsam/base/FastMap.h>
+#include <gtsam/base/WeightedSampler.h>
 #include <gtsam/inference/Ordering.h>
 #include <gtsam/inference/VariableIndex.h>
 #include <gtsam/linear/Errors.h>
@@ -35,8 +36,11 @@
 #include <iostream>
 #include <numeric>  // accumulate
 #include <queue>
+#include <random>
 #include <set>
 #include <stdexcept>
 #include <string>
 #include <utility>
 #include <vector>

 using std::cout;
@@ -68,81 +72,11 @@ static vector<size_t> sort_idx(const Container &src) {
   return idx;
 }

-/*****************************************************************************/
-static vector<size_t> iidSampler(const vector<double> &weight, const size_t n) {
-  /* compute the sum of the weights */
-  const double sum = std::accumulate(weight.begin(), weight.end(), 0.0);
-  if (sum == 0.0) {
-    throw std::runtime_error("Weight vector has no non-zero weights");
-  }
-
-  /* make a normalized and accumulated version of the weight vector */
-  const size_t m = weight.size();
-  vector<double> cdf;
-  cdf.reserve(m);
-  for (size_t i = 0; i < m; ++i) {
-    cdf.push_back(weight[i] / sum);
-  }
-
-  vector<double> acc(m);
-  std::partial_sum(cdf.begin(), cdf.end(), acc.begin());
-
-  /* iid sample n times */
-  vector<size_t> samples;
-  samples.reserve(n);
-  const double denominator = (double)RAND_MAX;
-  for (size_t i = 0; i < n; ++i) {
-    const double value = rand() / denominator;
-    /* binary search the interval containing "value" */
-    const auto it = std::lower_bound(acc.begin(), acc.end(), value);
-    const size_t index = it - acc.begin();
-    samples.push_back(index);
-  }
-  return samples;
-}
-
-/*****************************************************************************/
-static vector<size_t> UniqueSampler(const vector<double> &weight,
-                                    const size_t n) {
-  const size_t m = weight.size();
-  if (n > m) throw std::invalid_argument("UniqueSampler: invalid input size");
-
-  vector<size_t> results;
-
-  size_t count = 0;
-  vector<bool> touched(m, false);
-  while (count < n) {
-    vector<size_t> localIndices;
-    localIndices.reserve(n - count);
-    vector<double> localWeights;
-    localWeights.reserve(n - count);
-
-    /* collect data */
-    for (size_t i = 0; i < m; ++i) {
-      if (!touched[i]) {
-        localIndices.push_back(i);
-        localWeights.push_back(weight[i]);
-      }
-    }
-
-    /* sampling and cache results */
-    vector<size_t> samples = iidSampler(localWeights, n - count);
-    for (const size_t &index : samples) {
-      if (touched[index] == false) {
-        touched[index] = true;
-        results.push_back(index);
-        if (++count >= n) break;
-      }
-    }
-  }
-  return results;
-}
-
 /****************************************************************************/
 Subgraph::Subgraph(const vector<size_t> &indices) {
   edges_.reserve(indices.size());
   for (const size_t &index : indices) {
-    const Edge e {index,1.0};
+    const Edge e{index, 1.0};
    edges_.push_back(e);
   }
 }
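Note on the deleted helpers above: iidSampler drew indices with replacement by turning the weights into a cumulative distribution and binary-searching it with std::lower_bound, seeded from rand(); UniqueSampler then wrapped it in a rejection loop to get distinct indices. A minimal standalone sketch of the same cumulative-sum idea, but using the <random> engine style adopted below (the name sampleWithReplacement is illustrative only and not part of GTSAM; assumes a non-empty weight vector):

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Illustrative only: draw n indices i.i.d., each with probability proportional to weight[i].
std::vector<std::size_t> sampleWithReplacement(const std::vector<double> &weight,
                                               std::size_t n, std::mt19937 &rng) {
  // Cumulative sum of the (unnormalized) weights, as the removed iidSampler built.
  std::vector<double> cdf(weight.size());
  std::partial_sum(weight.begin(), weight.end(), cdf.begin());

  std::uniform_real_distribution<double> uniform(0.0, cdf.back());
  std::vector<std::size_t> samples;
  samples.reserve(n);
  for (std::size_t i = 0; i < n; ++i) {
    // Binary-search for the first cumulative value not less than the uniform draw.
    const auto it = std::lower_bound(cdf.begin(), cdf.end(), uniform(rng));
    samples.push_back(static_cast<std::size_t>(it - cdf.begin()));
  }
  return samples;
}

(std::discrete_distribution would express the same with-replacement draw even more directly; the CDF form is kept here to mirror the removed code.)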
@@ -423,12 +357,13 @@ vector<size_t> SubgraphBuilder::kruskal(const GaussianFactorGraph &gfg,
 /****************************************************************/
 vector<size_t> SubgraphBuilder::sample(const vector<double> &weights,
                                        const size_t t) const {
-  return UniqueSampler(weights, t);
+  std::mt19937 rng(42);  // TODO(frank): allow us to use a different seed
+  WeightedSampler<std::mt19937> sampler(&rng);
+  return sampler.sampleWithoutReplacement(t, weights);
 }

 /****************************************************************/
-Subgraph SubgraphBuilder::operator()(
-    const GaussianFactorGraph &gfg) const {
+Subgraph SubgraphBuilder::operator()(const GaussianFactorGraph &gfg) const {
   const auto &p = parameters_;
   const auto inverse_ordering = Ordering::Natural(gfg);
   const FastMap<Key, size_t> forward_ordering = inverse_ordering.invert();
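The replacement delegates to the new WeightedSampler class, templated on the random engine and constructed from a pointer to it. A small usage sketch based only on the calls visible in this hunk (the weight values and sample count are made-up):

#include <iostream>
#include <random>
#include <vector>

#include <gtsam/base/WeightedSampler.h>

int main() {
  std::mt19937 rng(42);  // fixed seed, as in SubgraphBuilder::sample above
  gtsam::WeightedSampler<std::mt19937> sampler(&rng);

  const std::vector<double> weights{0.1, 0.4, 0.2, 0.3};
  // Draw 2 distinct indices, each chosen with probability proportional to its weight.
  const std::vector<size_t> picked = sampler.sampleWithoutReplacement(2, weights);
  for (const size_t index : picked) std::cout << index << "\n";
  return 0;
}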
@@ -518,15 +453,15 @@ GaussianFactorGraph::shared_ptr buildFactorSubgraph(
   subgraphFactors->reserve(subgraph.size());
   for (const auto &e : subgraph) {
     const auto factor = gfg[e.index];
-    subgraphFactors->push_back(clone ? factor->clone(): factor);
+    subgraphFactors->push_back(clone ? factor->clone() : factor);
   }
   return subgraphFactors;
 }

 /**************************************************************************************************/
-std::pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr>  //
-splitFactorGraph(const GaussianFactorGraph &factorGraph, const Subgraph &subgraph) {
-
+std::pair<GaussianFactorGraph::shared_ptr, GaussianFactorGraph::shared_ptr>  //
+splitFactorGraph(const GaussianFactorGraph &factorGraph,
+                 const Subgraph &subgraph) {
   // Get the subgraph by calling cheaper method
   auto subgraphFactors = buildFactorSubgraph(factorGraph, subgraph, false);
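The hunk is cut off at this point, but the visible signature shows splitFactorGraph returning a pair of factor graphs: the subgraph built first via the cheaper non-cloning buildFactorSubgraph call, and presumably the remaining factors in the truncated part. A toy sketch of that partition pattern over plain indices, not GTSAM types (splitByIndices is a made-up stand-in, not part of this diff):

#include <cstddef>
#include <set>
#include <utility>
#include <vector>

// Toy stand-in: partition factor indices 0..numFactors-1 into (inside subgraph, remainder).
std::pair<std::vector<std::size_t>, std::vector<std::size_t>> splitByIndices(
    std::size_t numFactors, const std::set<std::size_t> &subgraphIndices) {
  std::vector<std::size_t> inside, outside;
  for (std::size_t i = 0; i < numFactors; ++i) {
    (subgraphIndices.count(i) ? inside : outside).push_back(i);
  }
  return {inside, outside};
}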