Resolved review comments.
parent
4340b46828
commit
28f440a623
|
|
@ -293,21 +293,21 @@ HybridValues HybridBayesNet::sample() const {
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
AlgebraicDecisionTree<Key> result(0.0);
|
||||||
|
|
||||||
// Iterate over each conditional.
|
// Iterate over each conditional.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asMixture()) {
|
||||||
// If conditional is hybrid, select based on assignment and compute
|
// If conditional is hybrid, select based on assignment and compute
|
||||||
// logProbability.
|
// logProbability.
|
||||||
error_tree = error_tree + gm->logProbability(continuousValues);
|
result = result + gm->logProbability(continuousValues);
|
||||||
} else if (auto gc = conditional->asGaussian()) {
|
} else if (auto gc = conditional->asGaussian()) {
|
||||||
// If continuous, get the (double) logProbability and add it to the
|
// If continuous, get the (double) logProbability and add it to the
|
||||||
// error_tree
|
// result
|
||||||
double logProbability = gc->logProbability(continuousValues);
|
double logProbability = gc->logProbability(continuousValues);
|
||||||
// Add the computed logProbability to every leaf of the logProbability
|
// Add the computed logProbability to every leaf of the logProbability
|
||||||
// tree.
|
// tree.
|
||||||
error_tree = error_tree.apply([logProbability](double leaf_value) {
|
result = result.apply([logProbability](double leaf_value) {
|
||||||
return leaf_value + logProbability;
|
return leaf_value + logProbability;
|
||||||
});
|
});
|
||||||
} else if (auto dc = conditional->asDiscrete()) {
|
} else if (auto dc = conditional->asDiscrete()) {
|
||||||
|
|
@ -317,7 +317,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return error_tree;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
|
||||||
|
|
@ -54,9 +54,11 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructor that takes an initializer list of shared pointers.
|
* Constructor that takes an initializer list of shared pointers.
|
||||||
* BayesNet<SymbolicConditional> bn = {make_shared<SymbolicConditional>(), ...};
|
* BayesNet<SymbolicConditional> bn = {make_shared<SymbolicConditional>(),
|
||||||
|
* ...};
|
||||||
*/
|
*/
|
||||||
BayesNet(std::initializer_list<sharedConditional> conditionals): Base(conditionals) {}
|
BayesNet(std::initializer_list<sharedConditional> conditionals)
|
||||||
|
: Base(conditionals) {}
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
@ -91,7 +93,10 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
|
||||||
/// @name HybridValues methods
|
/// @name HybridValues methods
|
||||||
/// @{
|
/// @{
|
||||||
|
|
||||||
|
// Expose HybridValues version of logProbability.
|
||||||
double logProbability(const HybridValues& x) const;
|
double logProbability(const HybridValues& x) const;
|
||||||
|
|
||||||
|
// Expose HybridValues version of evaluate.
|
||||||
double evaluate(const HybridValues& c) const;
|
double evaluate(const HybridValues& c) const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,6 @@ Author: Frank Dellaert
|
||||||
"""
|
"""
|
||||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import math
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -55,9 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase):
|
||||||
values.insert(_y_, np.array([5.0]))
|
values.insert(_y_, np.array([5.0]))
|
||||||
for i in [0, 1]:
|
for i in [0, 1]:
|
||||||
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
|
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
|
||||||
math.log(bayesNet.at(i).evaluate(values)))
|
np.log(bayesNet.at(i).evaluate(values)))
|
||||||
self.assertAlmostEqual(bayesNet.logProbability(values),
|
self.assertAlmostEqual(bayesNet.logProbability(values),
|
||||||
math.log(bayesNet.evaluate(values)))
|
np.log(bayesNet.evaluate(values)))
|
||||||
|
|
||||||
def test_sample(self):
|
def test_sample(self):
|
||||||
"""Test sample method"""
|
"""Test sample method"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue