Add test for dot
parent
efe922bc85
commit
13092f6218
|
@ -12,7 +12,9 @@ Author: Frank Dellaert
|
||||||
# pylint: disable=no-name-in-module, invalid-name
|
# pylint: disable=no-name-in-module, invalid-name
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
import gtsam
|
||||||
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph,
|
||||||
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering)
|
||||||
from gtsam.utils.test_case import GtsamTestCase
|
from gtsam.utils.test_case import GtsamTestCase
|
||||||
|
@ -126,6 +128,39 @@ class TestDiscreteBayesNet(GtsamTestCase):
|
||||||
actual = fragment.sample(given)
|
actual = fragment.sample(given)
|
||||||
self.assertEqual(len(actual), 5)
|
self.assertEqual(len(actual), 5)
|
||||||
|
|
||||||
|
def test_dot(self):
|
||||||
|
"""Check that dot works with position hints."""
|
||||||
|
fragment = DiscreteBayesNet()
|
||||||
|
fragment.add(Either, [Tuberculosis, LungCancer], "F T T T")
|
||||||
|
MyAsia = gtsam.symbol('a', 0), 2 # use a symbol!
|
||||||
|
fragment.add(Tuberculosis, [MyAsia], "99/1 95/5")
|
||||||
|
fragment.add(LungCancer, [Smoking], "99/1 90/10")
|
||||||
|
|
||||||
|
# Make sure we can *update* position hints
|
||||||
|
writer = gtsam.DotWriter()
|
||||||
|
ph: dict = writer.positionHints
|
||||||
|
ph.update({'a': 2}) # hint at symbol position
|
||||||
|
writer.positionHints = ph
|
||||||
|
|
||||||
|
# Check the output of dot
|
||||||
|
actual = fragment.dot(writer=writer)
|
||||||
|
expected_result = """\
|
||||||
|
digraph {
|
||||||
|
size="5,5";
|
||||||
|
|
||||||
|
var3[label="3"];
|
||||||
|
var4[label="4"];
|
||||||
|
var5[label="5"];
|
||||||
|
var6[label="6"];
|
||||||
|
vara0[label="a0", pos="0,2!"];
|
||||||
|
|
||||||
|
var4->var6
|
||||||
|
vara0->var3
|
||||||
|
var3->var5
|
||||||
|
var6->var5
|
||||||
|
}"""
|
||||||
|
self.assertEqual(actual, textwrap.dedent(expected_result))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Reference in New Issue