Wrap single-argument methods
							parent
							
								
									10628a0ddc
								
							
						
					
					
						commit
						a1b8f52da8
					
				| 
						 | 
				
			
			@ -79,6 +79,8 @@ virtual class DiscretePrior : gtsam::DiscreteConditional {
 | 
			
		|||
  void print(string s = "Discrete Prior\n",
 | 
			
		||||
             const gtsam::KeyFormatter& keyFormatter =
 | 
			
		||||
                 gtsam::DefaultKeyFormatter) const;
 | 
			
		||||
  double operator()(size_t value) const;
 | 
			
		||||
  std::vector<double> pmf() const;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteBayesNet.h>
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,16 +13,18 @@ Author: Varun Agrawal
 | 
			
		|||
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
from gtsam import DiscretePrior, DecisionTreeFactor, DiscreteKeys
 | 
			
		||||
import numpy as np
 | 
			
		||||
from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior
 | 
			
		||||
from gtsam.utils.test_case import GtsamTestCase
 | 
			
		||||
 | 
			
		||||
X = 0, 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestDiscretePrior(GtsamTestCase):
 | 
			
		||||
    """Tests for Discrete Priors."""
 | 
			
		||||
 | 
			
		||||
    def test_constructor(self):
 | 
			
		||||
        """Test various constructors."""
 | 
			
		||||
        X = 0, 2
 | 
			
		||||
        actual = DiscretePrior(X, "2/3")
 | 
			
		||||
        keys = DiscreteKeys()
 | 
			
		||||
        keys.push_back(X)
 | 
			
		||||
| 
						 | 
				
			
			@ -30,10 +32,19 @@ class TestDiscretePrior(GtsamTestCase):
 | 
			
		|||
        expected = DiscretePrior(f)
 | 
			
		||||
        self.gtsamAssertEquals(actual, expected)
 | 
			
		||||
 | 
			
		||||
    def test_operator(self):
 | 
			
		||||
        prior = DiscretePrior(X, "2/3")
 | 
			
		||||
        self.assertAlmostEqual(prior(0), 0.4)
 | 
			
		||||
        self.assertAlmostEqual(prior(1), 0.6)
 | 
			
		||||
 | 
			
		||||
    def test_pmf(self):
 | 
			
		||||
        prior = DiscretePrior(X, "2/3")
 | 
			
		||||
        expected = np.array([0.4, 0.6])
 | 
			
		||||
        np.testing.assert_allclose(expected, prior.pmf())
 | 
			
		||||
 | 
			
		||||
    def test_markdown(self):
 | 
			
		||||
        """Test the _repr_markdown_ method."""
 | 
			
		||||
 | 
			
		||||
        X = 0, 2
 | 
			
		||||
        prior = DiscretePrior(X, "2/3")
 | 
			
		||||
        expected = " $P(0)$:\n" \
 | 
			
		||||
            "|0|value|\n" \
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue