From 789b5d2eb681e700867294c8e75006671f77c77a Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 24 Dec 2022 06:58:49 +0530 Subject: [PATCH] Revert "Add optional model parameter to sample methods" This reverts commit 4fc02a6aa296b3214b850f76cb07e5a977faf023. --- gtsam/linear/GaussianBayesNet.cpp | 15 ++++++-------- gtsam/linear/GaussianBayesNet.h | 11 ++++------ gtsam/linear/GaussianConditional.cpp | 31 ++++++++++------------------ gtsam/linear/GaussianConditional.h | 11 ++++------ 4 files changed, 25 insertions(+), 43 deletions(-) diff --git a/gtsam/linear/GaussianBayesNet.cpp b/gtsam/linear/GaussianBayesNet.cpp index d42fbe772..41a734b34 100644 --- a/gtsam/linear/GaussianBayesNet.cpp +++ b/gtsam/linear/GaussianBayesNet.cpp @@ -59,30 +59,27 @@ namespace gtsam { } /* ************************************************************************ */ - VectorValues GaussianBayesNet::sample(std::mt19937_64* rng, - const SharedDiagonal& model) const { + VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const { VectorValues result; // no missing variables -> create an empty vector - return sample(result, rng, model); + return sample(result, rng); } VectorValues GaussianBayesNet::sample(VectorValues result, - std::mt19937_64* rng, - const SharedDiagonal& model) const { + std::mt19937_64* rng) const { // sample each node in reverse topological sort order (parents first) for (auto cg : boost::adaptors::reverse(*this)) { - const VectorValues sampled = cg->sample(result, rng, model); + const VectorValues sampled = cg->sample(result, rng); result.insert(sampled); } return result; } /* ************************************************************************ */ - VectorValues GaussianBayesNet::sample(const SharedDiagonal& model) const { + VectorValues GaussianBayesNet::sample() const { return sample(&kRandomNumberGenerator); } - VectorValues GaussianBayesNet::sample(VectorValues given, - const SharedDiagonal& model) const { + VectorValues GaussianBayesNet::sample(VectorValues given) const { return sample(given, &kRandomNumberGenerator); } diff --git a/gtsam/linear/GaussianBayesNet.h b/gtsam/linear/GaussianBayesNet.h index 570bfef58..83328576f 100644 --- a/gtsam/linear/GaussianBayesNet.h +++ b/gtsam/linear/GaussianBayesNet.h @@ -101,8 +101,7 @@ namespace gtsam { * std::mt19937_64 rng(42); * auto sample = gbn.sample(&rng); */ - VectorValues sample(std::mt19937_64* rng, - const SharedDiagonal& model = nullptr) const; + VectorValues sample(std::mt19937_64* rng) const; /** * Sample from an incomplete BayesNet, given missing variables @@ -111,15 +110,13 @@ namespace gtsam { * VectorValues given = ...; * auto sample = gbn.sample(given, &rng); */ - VectorValues sample(VectorValues given, std::mt19937_64* rng, - const SharedDiagonal& model = nullptr) const; + VectorValues sample(VectorValues given, std::mt19937_64* rng) const; /// Sample using ancestral sampling, use default rng - VectorValues sample(const SharedDiagonal& model = nullptr) const; + VectorValues sample() const; /// Sample from an incomplete BayesNet, use default rng - VectorValues sample(VectorValues given, - const SharedDiagonal& model = nullptr) const; + VectorValues sample(VectorValues given) const; /** * Return ordering corresponding to a topological sort. diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 363d25d11..60ddb1b7d 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -293,48 +293,39 @@ double GaussianConditional::logDeterminant() const { /* ************************************************************************ */ VectorValues GaussianConditional::sample(const VectorValues& parentsValues, - std::mt19937_64* rng, - const SharedDiagonal& model) const { + std::mt19937_64* rng) const { if (nrFrontals() != 1) { throw std::invalid_argument( "GaussianConditional::sample can only be called on single variable " "conditionals"); } - - VectorValues solution = solve(parentsValues); - Key key = firstFrontalKey(); - - Vector sigmas; - if (model_) { - sigmas = model_->sigmas(); - } else if (model) { - sigmas = model->sigmas(); - } else { + if (!model_) { throw std::invalid_argument( "GaussianConditional::sample can only be called if a diagonal noise " "model was specified at construction."); } + VectorValues solution = solve(parentsValues); + Key key = firstFrontalKey(); + const Vector& sigmas = model_->sigmas(); solution[key] += Sampler::sampleDiagonal(sigmas, rng); return solution; } - VectorValues GaussianConditional::sample(std::mt19937_64* rng, - const SharedDiagonal& model) const { + VectorValues GaussianConditional::sample(std::mt19937_64* rng) const { if (nrParents() != 0) throw std::invalid_argument( "sample() can only be invoked on no-parent prior"); VectorValues values; - return sample(values, rng, model); + return sample(values); } /* ************************************************************************ */ - VectorValues GaussianConditional::sample(const SharedDiagonal& model) const { - return sample(&kRandomNumberGenerator, model); + VectorValues GaussianConditional::sample() const { + return sample(&kRandomNumberGenerator); } - VectorValues GaussianConditional::sample(const VectorValues& given, - const SharedDiagonal& model) const { - return sample(given, &kRandomNumberGenerator, model); + VectorValues GaussianConditional::sample(const VectorValues& given) const { + return sample(given, &kRandomNumberGenerator); } /* ************************************************************************ */ diff --git a/gtsam/linear/GaussianConditional.h b/gtsam/linear/GaussianConditional.h index 1ca9b7d53..8af7f6602 100644 --- a/gtsam/linear/GaussianConditional.h +++ b/gtsam/linear/GaussianConditional.h @@ -188,8 +188,7 @@ namespace gtsam { * std::mt19937_64 rng(42); * auto sample = gbn.sample(&rng); */ - VectorValues sample(std::mt19937_64* rng, - const SharedDiagonal& model = nullptr) const; + VectorValues sample(std::mt19937_64* rng) const; /** * Sample from conditional, given missing variables @@ -199,15 +198,13 @@ namespace gtsam { * auto sample = gbn.sample(given, &rng); */ VectorValues sample(const VectorValues& parentsValues, - std::mt19937_64* rng, - const SharedDiagonal& model = nullptr) const; + std::mt19937_64* rng) const; /// Sample, use default rng - VectorValues sample(const SharedDiagonal& model = nullptr) const; + VectorValues sample() const; /// Sample with given values, use default rng - VectorValues sample(const VectorValues& parentsValues, - const SharedDiagonal& model = nullptr) const; + VectorValues sample(const VectorValues& parentsValues) const; /// @}