Fixed tests

release/4.3a0
Frank Dellaert 2023-01-06 23:02:49 -08:00
parent 876e2e822e
commit 88f27a210a
2 changed files with 27 additions and 28 deletions

View File

@ -56,6 +56,10 @@ namespace gtsam {
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
using boost::dynamic_pointer_cast;
/* ************************************************************************ */
// Throw a runtime exception for method specified in string s, and factor f:
static void throwRuntimeError(const std::string &s,
@ -88,8 +92,6 @@ static GaussianFactorGraphTree addGaussian(
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
using boost::dynamic_pointer_cast;
gttic(assembleGraphTree);
GaussianFactorGraphTree result;
@ -113,14 +115,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
// Don't do anything for discrete-only factors
// since we want to eliminate continuous values only.
continue;
} else if (auto orphan = dynamic_pointer_cast<
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f)) {
// We need to handle the case where the object is actually an
// BayesTreeOrphanWrapper!
throw std::invalid_argument(
"gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
"yet.");
} else {
// TODO(dellaert): there was an unattributed comment here: We need to
// handle the case where the object is actually an BayesTreeOrphanWrapper!
throwRuntimeError("gtsam::assembleGraphTree", f);
}
}
@ -134,17 +131,19 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
continuousElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
using boost::dynamic_pointer_cast;
GaussianFactorGraph gfg;
for (auto &fp : factors) {
if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
for (auto &f : factors) {
if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
gfg.push_back(hgf->inner());
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(fp)) {
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique.
// TODO(dellaert): is this correct? If so explain here.
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto gc = hc->asGaussian();
assert(gc);
if (!gc) throwRuntimeError("continuousElimination", gc);
gfg.push_back(gc);
} else {
// It is an orphan wrapped conditional
throwRuntimeError("continuousElimination", f);
}
}
@ -159,14 +158,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
const Ordering &frontalKeys) {
DiscreteFactorGraph dfg;
for (auto &factor : factors) {
if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
for (auto &f : factors) {
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
dfg.push_back(dtf);
} else if (auto hc =
boost::static_pointer_cast<HybridConditional>(factor)) {
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
// Ignore orphaned clique.
// TODO(dellaert): is this correct? If so explain here.
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
auto dc = hc->asDiscrete();
if (!dc) throwRuntimeError("continuousElimination", dc);
dfg.push_back(hc->asDiscrete());
} else {
// It is an orphan wrapper
throwRuntimeError("continuousElimination", f);
}
}
@ -456,8 +459,6 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
/* ************************************************************************ */
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
const VectorValues &continuousValues) const {
using boost::dynamic_pointer_cast;
AlgebraicDecisionTree<Key> error_tree(0.0);
// Iterate over each factor.
@ -496,10 +497,10 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
double error = 0.0;
for (auto &f : factors_) {
if (auto hf = boost::dynamic_pointer_cast<HybridFactor>(f)) {
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
// TODO(dellaert): needs to change when we discard other wrappers.
error += hf->error(values);
} else if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
error -= log((*dtf)(values.discrete()));
} else {
throwRuntimeError("HybridGaussianFactorGraph::error", f);

View File

@ -563,12 +563,10 @@ factor 4: Continuous [x2]
]
b = [ -10 ]
No noise model
factor 5: Discrete [m0]
P( m0 ):
factor 5: P( m0 ):
Leaf 0.5
factor 6: Discrete [m1 m0]
P( m1 | m0 ):
factor 6: P( m1 | m0 ):
Choice(m1)
0 Choice(m0)
0 0 Leaf 0.33333333