Fixed tests
parent
876e2e822e
commit
88f27a210a
|
@ -56,6 +56,10 @@ namespace gtsam {
|
|||
/// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph:
|
||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||
|
||||
using OrphanWrapper = BayesTreeOrphanWrapper<HybridBayesTree::Clique>;
|
||||
|
||||
using boost::dynamic_pointer_cast;
|
||||
|
||||
/* ************************************************************************ */
|
||||
// Throw a runtime exception for method specified in string s, and factor f:
|
||||
static void throwRuntimeError(const std::string &s,
|
||||
|
@ -88,8 +92,6 @@ static GaussianFactorGraphTree addGaussian(
|
|||
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
||||
// keys, and then loop over all assignments to populate a vector.
|
||||
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||
using boost::dynamic_pointer_cast;
|
||||
|
||||
gttic(assembleGraphTree);
|
||||
|
||||
GaussianFactorGraphTree result;
|
||||
|
@ -113,14 +115,9 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
// Don't do anything for discrete-only factors
|
||||
// since we want to eliminate continuous values only.
|
||||
continue;
|
||||
} else if (auto orphan = dynamic_pointer_cast<
|
||||
BayesTreeOrphanWrapper<HybridBayesTree::Clique>>(f)) {
|
||||
// We need to handle the case where the object is actually an
|
||||
// BayesTreeOrphanWrapper!
|
||||
throw std::invalid_argument(
|
||||
"gtsam::assembleGraphTree: BayesTreeOrphanWrapper is not implemented "
|
||||
"yet.");
|
||||
} else {
|
||||
// TODO(dellaert): there was an unattributed comment here: We need to
|
||||
// handle the case where the object is actually an BayesTreeOrphanWrapper!
|
||||
throwRuntimeError("gtsam::assembleGraphTree", f);
|
||||
}
|
||||
}
|
||||
|
@ -134,17 +131,19 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
static std::pair<HybridConditional::shared_ptr, boost::shared_ptr<Factor>>
|
||||
continuousElimination(const HybridGaussianFactorGraph &factors,
|
||||
const Ordering &frontalKeys) {
|
||||
using boost::dynamic_pointer_cast;
|
||||
GaussianFactorGraph gfg;
|
||||
for (auto &fp : factors) {
|
||||
if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
|
||||
for (auto &f : factors) {
|
||||
if (auto hgf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
gfg.push_back(hgf->inner());
|
||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(fp)) {
|
||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||
// Ignore orphaned clique.
|
||||
// TODO(dellaert): is this correct? If so explain here.
|
||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
auto gc = hc->asGaussian();
|
||||
assert(gc);
|
||||
if (!gc) throwRuntimeError("continuousElimination", gc);
|
||||
gfg.push_back(gc);
|
||||
} else {
|
||||
// It is an orphan wrapped conditional
|
||||
throwRuntimeError("continuousElimination", f);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -159,14 +158,18 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
const Ordering &frontalKeys) {
|
||||
DiscreteFactorGraph dfg;
|
||||
|
||||
for (auto &factor : factors) {
|
||||
if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
|
||||
for (auto &f : factors) {
|
||||
if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
dfg.push_back(dtf);
|
||||
} else if (auto hc =
|
||||
boost::static_pointer_cast<HybridConditional>(factor)) {
|
||||
} else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
|
||||
// Ignore orphaned clique.
|
||||
// TODO(dellaert): is this correct? If so explain here.
|
||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
auto dc = hc->asDiscrete();
|
||||
if (!dc) throwRuntimeError("continuousElimination", dc);
|
||||
dfg.push_back(hc->asDiscrete());
|
||||
} else {
|
||||
// It is an orphan wrapper
|
||||
throwRuntimeError("continuousElimination", f);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -456,8 +459,6 @@ const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
|
|||
/* ************************************************************************ */
|
||||
AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
||||
const VectorValues &continuousValues) const {
|
||||
using boost::dynamic_pointer_cast;
|
||||
|
||||
AlgebraicDecisionTree<Key> error_tree(0.0);
|
||||
|
||||
// Iterate over each factor.
|
||||
|
@ -496,10 +497,10 @@ AlgebraicDecisionTree<Key> HybridGaussianFactorGraph::error(
|
|||
double HybridGaussianFactorGraph::error(const HybridValues &values) const {
|
||||
double error = 0.0;
|
||||
for (auto &f : factors_) {
|
||||
if (auto hf = boost::dynamic_pointer_cast<HybridFactor>(f)) {
|
||||
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
|
||||
// TODO(dellaert): needs to change when we discard other wrappers.
|
||||
error += hf->error(values);
|
||||
} else if (auto dtf = boost::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
} else if (auto dtf = dynamic_pointer_cast<DecisionTreeFactor>(f)) {
|
||||
error -= log((*dtf)(values.discrete()));
|
||||
} else {
|
||||
throwRuntimeError("HybridGaussianFactorGraph::error", f);
|
||||
|
|
|
@ -563,12 +563,10 @@ factor 4: Continuous [x2]
|
|||
]
|
||||
b = [ -10 ]
|
||||
No noise model
|
||||
factor 5: Discrete [m0]
|
||||
P( m0 ):
|
||||
factor 5: P( m0 ):
|
||||
Leaf 0.5
|
||||
|
||||
factor 6: Discrete [m1 m0]
|
||||
P( m1 | m0 ):
|
||||
factor 6: P( m1 | m0 ):
|
||||
Choice(m1)
|
||||
0 Choice(m0)
|
||||
0 0 Leaf 0.33333333
|
||||
|
|
Loading…
Reference in New Issue