sampling test
parent
ffd1802cea
commit
bdb7836d0e
|
|
@ -279,6 +279,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::error(
|
||||||
return error_tree;
|
return error_tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
|
AlgebraicDecisionTree<Key> HybridBayesNet::probPrime(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
AlgebraicDecisionTree<Key> error_tree = this->error(continuousValues);
|
||||||
|
|
|
||||||
|
|
@ -432,6 +432,74 @@ TEST(HybridEstimation, ProbabilityMultifrontal) {
|
||||||
EXPECT(assert_equal(discrete_seq, hybrid_values.discrete()));
|
EXPECT(assert_equal(discrete_seq, hybrid_values.discrete()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/****************************************************************************/
|
||||||
|
/**
|
||||||
|
* Test for correctness via sampling.
|
||||||
|
*
|
||||||
|
* Given the conditional P(x0, m0, x1| z0, z1)
|
||||||
|
* with meaasurements z0, z1, we:
|
||||||
|
* 1. Start with the corresponding Factor Graph `FG`.
|
||||||
|
* 2. Eliminate the factor graph into a Bayes Net `BN`.
|
||||||
|
* 3. Sample from the Bayes Net.
|
||||||
|
* 4. Check that the ratio `BN(x)/FG(x) = constant` for all samples `x`.
|
||||||
|
*/
|
||||||
|
TEST(HybridEstimation, CorrectnessViaSampling) {
|
||||||
|
HybridNonlinearFactorGraph nfg;
|
||||||
|
|
||||||
|
auto noise_model = noiseModel::Diagonal::Sigmas(Vector1(1.0));
|
||||||
|
auto zero_motion =
|
||||||
|
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 0, noise_model);
|
||||||
|
auto one_motion =
|
||||||
|
boost::make_shared<BetweenFactor<double>>(X(0), X(1), 1, noise_model);
|
||||||
|
std::vector<NonlinearFactor::shared_ptr> factors = {zero_motion, one_motion};
|
||||||
|
nfg.emplace_nonlinear<PriorFactor<double>>(X(0), 0.0, noise_model);
|
||||||
|
nfg.emplace_hybrid<MixtureFactor>(
|
||||||
|
KeyVector{X(0), X(1)}, DiscreteKeys{DiscreteKey(M(0), 2)}, factors);
|
||||||
|
|
||||||
|
Values initial;
|
||||||
|
double z0 = 0.0, z1 = 1.0;
|
||||||
|
initial.insert<double>(X(0), z0);
|
||||||
|
initial.insert<double>(X(1), z1);
|
||||||
|
|
||||||
|
// 1. Create the factor graph from the nonlinear factor graph.
|
||||||
|
HybridGaussianFactorGraph::shared_ptr fg = nfg.linearize(initial);
|
||||||
|
// 2. Eliminate into BN
|
||||||
|
Ordering ordering = fg->getHybridOrdering();
|
||||||
|
HybridBayesNet::shared_ptr bn = fg->eliminateSequential(ordering);
|
||||||
|
|
||||||
|
// Set up sampling
|
||||||
|
std::random_device rd;
|
||||||
|
std::mt19937_64 gen(rd());
|
||||||
|
// Discrete distribution with 50/50 weightage on both discrete variables.
|
||||||
|
std::discrete_distribution<> ddist({50, 50});
|
||||||
|
|
||||||
|
// 3. Do sampling
|
||||||
|
std::vector<double> ratios;
|
||||||
|
int num_samples = 1000;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < num_samples; i++) {
|
||||||
|
// Sample a discrete value
|
||||||
|
DiscreteValues assignment;
|
||||||
|
assignment[M(0)] = ddist(gen);
|
||||||
|
|
||||||
|
// Using the discrete sample, get the corresponding bayes net.
|
||||||
|
GaussianBayesNet gbn = bn->choose(assignment);
|
||||||
|
// Sample from the bayes net
|
||||||
|
VectorValues sample = gbn.sample(&gen, noise_model);
|
||||||
|
// Compute the ratio in log form and canonical form
|
||||||
|
double log_ratio = bn->error(sample, assignment) - fg->error(sample, assignment);
|
||||||
|
double ratio = exp(-log_ratio);
|
||||||
|
|
||||||
|
// Store the ratio for post-processing
|
||||||
|
ratios.push_back(ratio);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Check that all samples == 1.0 (constant)
|
||||||
|
double ratio_sum = std::accumulate(ratios.begin(), ratios.end(),
|
||||||
|
decltype(ratios)::value_type(0));
|
||||||
|
EXPECT_DOUBLES_EQUAL(1.0, ratio_sum / num_samples, 1e-9);
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue