Resolved review comments.
parent
4340b46828
commit
28f440a623
|
@ -293,21 +293,21 @@ HybridValues HybridBayesNet::sample() const {
|
|||
/* ************************************************************************* */
|
||||
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
||||
const VectorValues &continuousValues) const {
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
AlgebraicDecisionTree<Key> result(0.0);
|
||||
|
||||
// Iterate over each conditional.
|
||||
for (auto &&conditional : *this) {
|
||||
if (auto gm = conditional->asMixture()) {
|
||||
// If conditional is hybrid, select based on assignment and compute
|
||||
// logProbability.
|
||||
error_tree = error_tree + gm->logProbability(continuousValues);
|
||||
result = result + gm->logProbability(continuousValues);
|
||||
} else if (auto gc = conditional->asGaussian()) {
|
||||
// If continuous, get the (double) logProbability and add it to the
|
||||
// error_tree
|
||||
// result
|
||||
double logProbability = gc->logProbability(continuousValues);
|
||||
// Add the computed logProbability to every leaf of the logProbability
|
||||
// tree.
|
||||
error_tree = error_tree.apply([logProbability](double leaf_value) {
|
||||
result = result.apply([logProbability](double leaf_value) {
|
||||
return leaf_value + logProbability;
|
||||
});
|
||||
} else if (auto dc = conditional->asDiscrete()) {
|
||||
|
@ -317,7 +317,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
|
|||
}
|
||||
}
|
||||
|
||||
return error_tree;
|
||||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
@ -54,9 +54,11 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
|
|||
|
||||
/**
|
||||
* Constructor that takes an initializer list of shared pointers.
|
||||
* BayesNet<SymbolicConditional> bn = {make_shared<SymbolicConditional>(), ...};
|
||||
* BayesNet<SymbolicConditional> bn = {make_shared<SymbolicConditional>(),
|
||||
* ...};
|
||||
*/
|
||||
BayesNet(std::initializer_list<sharedConditional> conditionals): Base(conditionals) {}
|
||||
BayesNet(std::initializer_list<sharedConditional> conditionals)
|
||||
: Base(conditionals) {}
|
||||
|
||||
/// @}
|
||||
|
||||
|
@ -91,7 +93,10 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
|
|||
/// @name HybridValues methods
|
||||
/// @{
|
||||
|
||||
// Expose HybridValues version of logProbability.
|
||||
double logProbability(const HybridValues& x) const;
|
||||
|
||||
// Expose HybridValues version of evaluate.
|
||||
double evaluate(const HybridValues& c) const;
|
||||
|
||||
/// @}
|
||||
|
|
|
@ -10,9 +10,6 @@ Author: Frank Dellaert
|
|||
"""
|
||||
# pylint: disable=invalid-name, no-name-in-module, no-member
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
@ -55,9 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase):
|
|||
values.insert(_y_, np.array([5.0]))
|
||||
for i in [0, 1]:
|
||||
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
|
||||
math.log(bayesNet.at(i).evaluate(values)))
|
||||
np.log(bayesNet.at(i).evaluate(values)))
|
||||
self.assertAlmostEqual(bayesNet.logProbability(values),
|
||||
math.log(bayesNet.evaluate(values)))
|
||||
np.log(bayesNet.evaluate(values)))
|
||||
|
||||
def test_sample(self):
|
||||
"""Test sample method"""
|
||||
|
|
Loading…
Reference in New Issue