HybridBayesNet::optimize
parent
a6101b2d8f
commit
f0df82ac04
|
|
@ -16,8 +16,8 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <gtsam/hybrid/HybridBayesNet.h>
|
#include <gtsam/hybrid/HybridBayesNet.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
|
||||||
#include <gtsam/hybrid/HybridLookupDAG.h>
|
#include <gtsam/hybrid/HybridLookupDAG.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
|
@ -112,13 +112,12 @@ HybridBayesNet HybridBayesNet::prune(
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
|
||||||
return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner());
|
return factors_.at(i)->asMixture();
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
|
||||||
return boost::dynamic_pointer_cast<DiscreteConditional>(
|
return factors_.at(i)->asDiscreteConditional();
|
||||||
factors_.at(i)->inner());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
@ -138,4 +137,10 @@ HybridValues HybridBayesNet::optimize() const {
|
||||||
return dag.argmax();
|
return dag.argmax();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
|
||||||
|
GaussianBayesNet gbn = this->choose(assignment);
|
||||||
|
return gbn.optimize();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
|
||||||
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
|
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
|
||||||
/// put this method there?
|
/// put this method there?
|
||||||
HybridValues optimize() const;
|
HybridValues optimize() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Given the discrete assignment, return the optimized estimate for the
|
||||||
|
* selected Gaussian BayesNet.
|
||||||
|
*
|
||||||
|
* @param assignment An assignment of discrete values.
|
||||||
|
* @return Values
|
||||||
|
*/
|
||||||
|
VectorValues optimize(const DiscreteValues &assignment) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,7 @@ class GTSAM_EXPORT HybridConditional
|
||||||
BaseConditional; ///< Typedef to our conditional base class
|
BaseConditional; ///< Typedef to our conditional base class
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Type-erased pointer to the inner type
|
/// Type-erased pointer to the inner type
|
||||||
boost::shared_ptr<Factor> inner_;
|
boost::shared_ptr<Factor> inner_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
|
@ -127,8 +127,7 @@ class GTSAM_EXPORT HybridConditional
|
||||||
* @param gaussianMixture Gaussian Mixture Conditional used to create the
|
* @param gaussianMixture Gaussian Mixture Conditional used to create the
|
||||||
* HybridConditional.
|
* HybridConditional.
|
||||||
*/
|
*/
|
||||||
HybridConditional(
|
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
|
||||||
boost::shared_ptr<GaussianMixture> gaussianMixture);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Return HybridConditional as a GaussianMixture
|
* @brief Return HybridConditional as a GaussianMixture
|
||||||
|
|
@ -168,10 +167,10 @@ class GTSAM_EXPORT HybridConditional
|
||||||
/// Get the type-erased pointer to the inner type
|
/// Get the type-erased pointer to the inner type
|
||||||
boost::shared_ptr<Factor> inner() { return inner_; }
|
boost::shared_ptr<Factor> inner() { return inner_; }
|
||||||
|
|
||||||
}; // DiscreteConditional
|
}; // HybridConditional
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template <>
|
template <>
|
||||||
struct traits<HybridConditional> : public Testable<DiscreteConditional> {};
|
struct traits<HybridConditional> : public Testable<HybridConditional> {};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -85,6 +85,40 @@ TEST(HybridBayesNet, Choose) {
|
||||||
*gbn.at(3)));
|
*gbn.at(3)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
// Test bayes net optimize
|
||||||
|
TEST(HybridBayesNet, Optimize) {
|
||||||
|
Switching s(4);
|
||||||
|
|
||||||
|
Ordering ordering;
|
||||||
|
for (auto&& kvp : s.linearizationPoint) {
|
||||||
|
ordering += kvp.key;
|
||||||
|
}
|
||||||
|
|
||||||
|
HybridBayesNet::shared_ptr hybridBayesNet;
|
||||||
|
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
|
||||||
|
std::tie(hybridBayesNet, remainingFactorGraph) =
|
||||||
|
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
|
||||||
|
|
||||||
|
DiscreteValues assignment;
|
||||||
|
assignment[M(1)] = 1;
|
||||||
|
assignment[M(2)] = 1;
|
||||||
|
assignment[M(3)] = 1;
|
||||||
|
|
||||||
|
VectorValues delta = hybridBayesNet->optimize(assignment);
|
||||||
|
|
||||||
|
// The linearization point has the same value as the key index,
|
||||||
|
// e.g. X(1) = 1, X(2) = 2,
|
||||||
|
// but the factors specify X(k) = k-1, so delta should be -1.
|
||||||
|
VectorValues expected_delta;
|
||||||
|
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
|
||||||
|
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
|
||||||
|
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
|
||||||
|
expected_delta.insert(make_pair(X(4), -Vector1::Ones()));
|
||||||
|
|
||||||
|
EXPECT(assert_equal(expected_delta, delta));
|
||||||
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
int main() {
|
int main() {
|
||||||
TestResult tr;
|
TestResult tr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue