diff --git a/gtsam/linear/SubgraphBuilder.cpp b/gtsam/linear/SubgraphBuilder.cpp index a999b3a71..c6b3ca15f 100644 --- a/gtsam/linear/SubgraphBuilder.cpp +++ b/gtsam/linear/SubgraphBuilder.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -35,8 +36,11 @@ #include #include // accumulate #include +#include #include #include +#include +#include #include using std::cout; @@ -68,81 +72,11 @@ static vector sort_idx(const Container &src) { return idx; } -/*****************************************************************************/ -static vector iidSampler(const vector &weight, const size_t n) { - /* compute the sum of the weights */ - const double sum = std::accumulate(weight.begin(), weight.end(), 0.0); - if (sum == 0.0) { - throw std::runtime_error("Weight vector has no non-zero weights"); - } - - /* make a normalized and accumulated version of the weight vector */ - const size_t m = weight.size(); - vector cdf; - cdf.reserve(m); - for (size_t i = 0; i < m; ++i) { - cdf.push_back(weight[i] / sum); - } - - vector acc(m); - std::partial_sum(cdf.begin(), cdf.end(), acc.begin()); - - /* iid sample n times */ - vector samples; - samples.reserve(n); - const double denominator = (double)RAND_MAX; - for (size_t i = 0; i < n; ++i) { - const double value = rand() / denominator; - /* binary search the interval containing "value" */ - const auto it = std::lower_bound(acc.begin(), acc.end(), value); - const size_t index = it - acc.begin(); - samples.push_back(index); - } - return samples; -} - -/*****************************************************************************/ -static vector UniqueSampler(const vector &weight, - const size_t n) { - const size_t m = weight.size(); - if (n > m) throw std::invalid_argument("UniqueSampler: invalid input size"); - - vector results; - - size_t count = 0; - vector touched(m, false); - while (count < n) { - vector localIndices; - localIndices.reserve(n - count); - vector localWeights; - localWeights.reserve(n - count); - - /* collect data */ - for (size_t i = 0; i < m; ++i) { - if (!touched[i]) { - localIndices.push_back(i); - localWeights.push_back(weight[i]); - } - } - - /* sampling and cache results */ - vector samples = iidSampler(localWeights, n - count); - for (const size_t &index : samples) { - if (touched[index] == false) { - touched[index] = true; - results.push_back(index); - if (++count >= n) break; - } - } - } - return results; -} - /****************************************************************************/ Subgraph::Subgraph(const vector &indices) { edges_.reserve(indices.size()); for (const size_t &index : indices) { - const Edge e {index,1.0}; + const Edge e{index, 1.0}; edges_.push_back(e); } } @@ -423,12 +357,13 @@ vector SubgraphBuilder::kruskal(const GaussianFactorGraph &gfg, /****************************************************************/ vector SubgraphBuilder::sample(const vector &weights, const size_t t) const { - return UniqueSampler(weights, t); + std::mt19937 rng(42); // TODO(frank): allow us to use a different seed + WeightedSampler sampler(&rng); + return sampler.sampleWithoutReplacement(t, weights); } /****************************************************************/ -Subgraph SubgraphBuilder::operator()( - const GaussianFactorGraph &gfg) const { +Subgraph SubgraphBuilder::operator()(const GaussianFactorGraph &gfg) const { const auto &p = parameters_; const auto inverse_ordering = Ordering::Natural(gfg); const FastMap forward_ordering = inverse_ordering.invert(); @@ -518,15 +453,15 @@ GaussianFactorGraph::shared_ptr buildFactorSubgraph( subgraphFactors->reserve(subgraph.size()); for (const auto &e : subgraph) { const auto factor = gfg[e.index]; - subgraphFactors->push_back(clone ? factor->clone(): factor); + subgraphFactors->push_back(clone ? factor->clone() : factor); } return subgraphFactors; } /**************************************************************************************************/ -std::pair // -splitFactorGraph(const GaussianFactorGraph &factorGraph, const Subgraph &subgraph) { - +std::pair // +splitFactorGraph(const GaussianFactorGraph &factorGraph, + const Subgraph &subgraph) { // Get the subgraph by calling cheaper method auto subgraphFactors = buildFactorSubgraph(factorGraph, subgraph, false);