Resolved review comments.

release/4.3a0
Frank Dellaert 2023-01-11 18:11:28 -08:00
parent 4340b46828
commit 28f440a623
3 changed files with 14 additions and 12 deletions

View File

@ -293,21 +293,21 @@ HybridValues HybridBayesNet::sample() const {
/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
const VectorValues &continuousValues) const {
AlgebraicDecisionTree<Key> error_tree(0.0);
AlgebraicDecisionTree<Key> result(0.0);
// Iterate over each conditional.
for (auto &&conditional : *this) {
if (auto gm = conditional->asMixture()) {
// If conditional is hybrid, select based on assignment and compute
// logProbability.
error_tree = error_tree + gm->logProbability(continuousValues);
result = result + gm->logProbability(continuousValues);
} else if (auto gc = conditional->asGaussian()) {
// If continuous, get the (double) logProbability and add it to the
// error_tree
// result
double logProbability = gc->logProbability(continuousValues);
// Add the computed logProbability to every leaf of the logProbability
// tree.
error_tree = error_tree.apply([logProbability](double leaf_value) {
result = result.apply([logProbability](double leaf_value) {
return leaf_value + logProbability;
});
} else if (auto dc = conditional->asDiscrete()) {
@ -317,7 +317,7 @@ AlgebraicDecisionTree<Key> HybridBayesNet::logProbability(
}
}
return error_tree;
return result;
}
/* ************************************************************************* */

View File

@ -54,9 +54,11 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
/**
* Constructor that takes an initializer list of shared pointers.
* BayesNet<SymbolicConditional> bn = {make_shared<SymbolicConditional>(), ...};
* BayesNet<SymbolicConditional> bn = {make_shared<SymbolicConditional>(),
* ...};
*/
BayesNet(std::initializer_list<sharedConditional> conditionals): Base(conditionals) {}
BayesNet(std::initializer_list<sharedConditional> conditionals)
: Base(conditionals) {}
/// @}
@ -91,7 +93,10 @@ class BayesNet : public FactorGraph<CONDITIONAL> {
/// @name HybridValues methods
/// @{
// Expose HybridValues version of logProbability.
double logProbability(const HybridValues& x) const;
// Expose HybridValues version of evaluate.
double evaluate(const HybridValues& c) const;
/// @}

View File

@ -10,9 +10,6 @@ Author: Frank Dellaert
"""
# pylint: disable=invalid-name, no-name-in-module, no-member
from __future__ import print_function
import math
import unittest
import numpy as np
@ -55,9 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase):
values.insert(_y_, np.array([5.0]))
for i in [0, 1]:
self.assertAlmostEqual(bayesNet.at(i).logProbability(values),
math.log(bayesNet.at(i).evaluate(values)))
np.log(bayesNet.at(i).evaluate(values)))
self.assertAlmostEqual(bayesNet.logProbability(values),
math.log(bayesNet.evaluate(values)))
np.log(bayesNet.evaluate(values)))
def test_sample(self):
"""Test sample method"""