Revert "Add optional model parameter to sample methods"
This reverts commit 4fc02a6aa2.
release/4.3a0
parent
e9978284c8
commit
789b5d2eb6
|
|
@ -59,30 +59,27 @@ namespace gtsam {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
VectorValues GaussianBayesNet::sample(std::mt19937_64* rng,
|
VectorValues GaussianBayesNet::sample(std::mt19937_64* rng) const {
|
||||||
const SharedDiagonal& model) const {
|
|
||||||
VectorValues result; // no missing variables -> create an empty vector
|
VectorValues result; // no missing variables -> create an empty vector
|
||||||
return sample(result, rng, model);
|
return sample(result, rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorValues GaussianBayesNet::sample(VectorValues result,
|
VectorValues GaussianBayesNet::sample(VectorValues result,
|
||||||
std::mt19937_64* rng,
|
std::mt19937_64* rng) const {
|
||||||
const SharedDiagonal& model) const {
|
|
||||||
// sample each node in reverse topological sort order (parents first)
|
// sample each node in reverse topological sort order (parents first)
|
||||||
for (auto cg : boost::adaptors::reverse(*this)) {
|
for (auto cg : boost::adaptors::reverse(*this)) {
|
||||||
const VectorValues sampled = cg->sample(result, rng, model);
|
const VectorValues sampled = cg->sample(result, rng);
|
||||||
result.insert(sampled);
|
result.insert(sampled);
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
VectorValues GaussianBayesNet::sample(const SharedDiagonal& model) const {
|
VectorValues GaussianBayesNet::sample() const {
|
||||||
return sample(&kRandomNumberGenerator);
|
return sample(&kRandomNumberGenerator);
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorValues GaussianBayesNet::sample(VectorValues given,
|
VectorValues GaussianBayesNet::sample(VectorValues given) const {
|
||||||
const SharedDiagonal& model) const {
|
|
||||||
return sample(given, &kRandomNumberGenerator);
|
return sample(given, &kRandomNumberGenerator);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -101,8 +101,7 @@ namespace gtsam {
|
||||||
* std::mt19937_64 rng(42);
|
* std::mt19937_64 rng(42);
|
||||||
* auto sample = gbn.sample(&rng);
|
* auto sample = gbn.sample(&rng);
|
||||||
*/
|
*/
|
||||||
VectorValues sample(std::mt19937_64* rng,
|
VectorValues sample(std::mt19937_64* rng) const;
|
||||||
const SharedDiagonal& model = nullptr) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sample from an incomplete BayesNet, given missing variables
|
* Sample from an incomplete BayesNet, given missing variables
|
||||||
|
|
@ -111,15 +110,13 @@ namespace gtsam {
|
||||||
* VectorValues given = ...;
|
* VectorValues given = ...;
|
||||||
* auto sample = gbn.sample(given, &rng);
|
* auto sample = gbn.sample(given, &rng);
|
||||||
*/
|
*/
|
||||||
VectorValues sample(VectorValues given, std::mt19937_64* rng,
|
VectorValues sample(VectorValues given, std::mt19937_64* rng) const;
|
||||||
const SharedDiagonal& model = nullptr) const;
|
|
||||||
|
|
||||||
/// Sample using ancestral sampling, use default rng
|
/// Sample using ancestral sampling, use default rng
|
||||||
VectorValues sample(const SharedDiagonal& model = nullptr) const;
|
VectorValues sample() const;
|
||||||
|
|
||||||
/// Sample from an incomplete BayesNet, use default rng
|
/// Sample from an incomplete BayesNet, use default rng
|
||||||
VectorValues sample(VectorValues given,
|
VectorValues sample(VectorValues given) const;
|
||||||
const SharedDiagonal& model = nullptr) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return ordering corresponding to a topological sort.
|
* Return ordering corresponding to a topological sort.
|
||||||
|
|
|
||||||
|
|
@ -293,48 +293,39 @@ double GaussianConditional::logDeterminant() const {
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
|
VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
|
||||||
std::mt19937_64* rng,
|
std::mt19937_64* rng) const {
|
||||||
const SharedDiagonal& model) const {
|
|
||||||
if (nrFrontals() != 1) {
|
if (nrFrontals() != 1) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"GaussianConditional::sample can only be called on single variable "
|
"GaussianConditional::sample can only be called on single variable "
|
||||||
"conditionals");
|
"conditionals");
|
||||||
}
|
}
|
||||||
|
if (!model_) {
|
||||||
VectorValues solution = solve(parentsValues);
|
|
||||||
Key key = firstFrontalKey();
|
|
||||||
|
|
||||||
Vector sigmas;
|
|
||||||
if (model_) {
|
|
||||||
sigmas = model_->sigmas();
|
|
||||||
} else if (model) {
|
|
||||||
sigmas = model->sigmas();
|
|
||||||
} else {
|
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"GaussianConditional::sample can only be called if a diagonal noise "
|
"GaussianConditional::sample can only be called if a diagonal noise "
|
||||||
"model was specified at construction.");
|
"model was specified at construction.");
|
||||||
}
|
}
|
||||||
|
VectorValues solution = solve(parentsValues);
|
||||||
|
Key key = firstFrontalKey();
|
||||||
|
const Vector& sigmas = model_->sigmas();
|
||||||
solution[key] += Sampler::sampleDiagonal(sigmas, rng);
|
solution[key] += Sampler::sampleDiagonal(sigmas, rng);
|
||||||
return solution;
|
return solution;
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorValues GaussianConditional::sample(std::mt19937_64* rng,
|
VectorValues GaussianConditional::sample(std::mt19937_64* rng) const {
|
||||||
const SharedDiagonal& model) const {
|
|
||||||
if (nrParents() != 0)
|
if (nrParents() != 0)
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"sample() can only be invoked on no-parent prior");
|
"sample() can only be invoked on no-parent prior");
|
||||||
VectorValues values;
|
VectorValues values;
|
||||||
return sample(values, rng, model);
|
return sample(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
VectorValues GaussianConditional::sample(const SharedDiagonal& model) const {
|
VectorValues GaussianConditional::sample() const {
|
||||||
return sample(&kRandomNumberGenerator, model);
|
return sample(&kRandomNumberGenerator);
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorValues GaussianConditional::sample(const VectorValues& given,
|
VectorValues GaussianConditional::sample(const VectorValues& given) const {
|
||||||
const SharedDiagonal& model) const {
|
return sample(given, &kRandomNumberGenerator);
|
||||||
return sample(given, &kRandomNumberGenerator, model);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
|
||||||
|
|
@ -188,8 +188,7 @@ namespace gtsam {
|
||||||
* std::mt19937_64 rng(42);
|
* std::mt19937_64 rng(42);
|
||||||
* auto sample = gbn.sample(&rng);
|
* auto sample = gbn.sample(&rng);
|
||||||
*/
|
*/
|
||||||
VectorValues sample(std::mt19937_64* rng,
|
VectorValues sample(std::mt19937_64* rng) const;
|
||||||
const SharedDiagonal& model = nullptr) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sample from conditional, given missing variables
|
* Sample from conditional, given missing variables
|
||||||
|
|
@ -199,15 +198,13 @@ namespace gtsam {
|
||||||
* auto sample = gbn.sample(given, &rng);
|
* auto sample = gbn.sample(given, &rng);
|
||||||
*/
|
*/
|
||||||
VectorValues sample(const VectorValues& parentsValues,
|
VectorValues sample(const VectorValues& parentsValues,
|
||||||
std::mt19937_64* rng,
|
std::mt19937_64* rng) const;
|
||||||
const SharedDiagonal& model = nullptr) const;
|
|
||||||
|
|
||||||
/// Sample, use default rng
|
/// Sample, use default rng
|
||||||
VectorValues sample(const SharedDiagonal& model = nullptr) const;
|
VectorValues sample() const;
|
||||||
|
|
||||||
/// Sample with given values, use default rng
|
/// Sample with given values, use default rng
|
||||||
VectorValues sample(const VectorValues& parentsValues,
|
VectorValues sample(const VectorValues& parentsValues) const;
|
||||||
const SharedDiagonal& model = nullptr) const;
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue