diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index bc0d8e95e..be9cdba85 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -293,21 +293,21 @@ HybridValues HybridBayesNet::sample() const { /* ************************************************************************* */ AlgebraicDecisionTree HybridBayesNet::logProbability( const VectorValues &continuousValues) const { - AlgebraicDecisionTree error_tree(0.0); + AlgebraicDecisionTree result(0.0); // Iterate over each conditional. for (auto &&conditional : *this) { if (auto gm = conditional->asMixture()) { // If conditional is hybrid, select based on assignment and compute // logProbability. - error_tree = error_tree + gm->logProbability(continuousValues); + result = result + gm->logProbability(continuousValues); } else if (auto gc = conditional->asGaussian()) { // If continuous, get the (double) logProbability and add it to the - // error_tree + // result double logProbability = gc->logProbability(continuousValues); // Add the computed logProbability to every leaf of the logProbability // tree. - error_tree = error_tree.apply([logProbability](double leaf_value) { + result = result.apply([logProbability](double leaf_value) { return leaf_value + logProbability; }); } else if (auto dc = conditional->asDiscrete()) { @@ -317,7 +317,7 @@ AlgebraicDecisionTree HybridBayesNet::logProbability( } } - return error_tree; + return result; } /* ************************************************************************* */ diff --git a/gtsam/inference/BayesNet.h b/gtsam/inference/BayesNet.h index 4d266df46..3e6d55281 100644 --- a/gtsam/inference/BayesNet.h +++ b/gtsam/inference/BayesNet.h @@ -54,9 +54,11 @@ class BayesNet : public FactorGraph { /** * Constructor that takes an initializer list of shared pointers. - * BayesNet bn = {make_shared(), ...}; + * BayesNet bn = {make_shared(), + * ...}; */ - BayesNet(std::initializer_list conditionals): Base(conditionals) {} + BayesNet(std::initializer_list conditionals) + : Base(conditionals) {} /// @} @@ -91,7 +93,10 @@ class BayesNet : public FactorGraph { /// @name HybridValues methods /// @{ + // Expose HybridValues version of logProbability. double logProbability(const HybridValues& x) const; + + // Expose HybridValues version of evaluate. double evaluate(const HybridValues& c) const; /// @} diff --git a/python/gtsam/tests/test_GaussianBayesNet.py b/python/gtsam/tests/test_GaussianBayesNet.py index 9065c7bee..05522441b 100644 --- a/python/gtsam/tests/test_GaussianBayesNet.py +++ b/python/gtsam/tests/test_GaussianBayesNet.py @@ -10,9 +10,6 @@ Author: Frank Dellaert """ # pylint: disable=invalid-name, no-name-in-module, no-member -from __future__ import print_function - -import math import unittest import numpy as np @@ -55,9 +52,9 @@ class TestGaussianBayesNet(GtsamTestCase): values.insert(_y_, np.array([5.0])) for i in [0, 1]: self.assertAlmostEqual(bayesNet.at(i).logProbability(values), - math.log(bayesNet.at(i).evaluate(values))) + np.log(bayesNet.at(i).evaluate(values))) self.assertAlmostEqual(bayesNet.logProbability(values), - math.log(bayesNet.evaluate(values))) + np.log(bayesNet.evaluate(values))) def test_sample(self): """Test sample method"""