HybridBayesNet::optimize

release/4.3a0
Varun Agrawal 2022-08-26 11:35:24 -04:00
parent a6101b2d8f
commit f0df82ac04
4 changed files with 56 additions and 9 deletions

View File

@ -16,8 +16,8 @@
*/
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
namespace gtsam {
@ -112,13 +112,12 @@ HybridBayesNet HybridBayesNet::prune(
/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner());
return factors_.at(i)->asMixture();
}
/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return boost::dynamic_pointer_cast<DiscreteConditional>(
factors_.at(i)->inner());
return factors_.at(i)->asDiscreteConditional();
}
/* ************************************************************************* */
@ -138,4 +137,10 @@ HybridValues HybridBayesNet::optimize() const {
return dag.argmax();
}
/* *******************************************************************************/
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);
return gbn.optimize();
}
} // namespace gtsam

View File

@ -72,6 +72,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
HybridValues optimize() const;
/**
* @brief Given the discrete assignment, return the optimized estimate for the
* selected Gaussian BayesNet.
*
* @param assignment An assignment of discrete values.
* @return Values
*/
VectorValues optimize(const DiscreteValues &assignment) const;
};
} // namespace gtsam

View File

@ -69,7 +69,7 @@ class GTSAM_EXPORT HybridConditional
BaseConditional; ///< Typedef to our conditional base class
protected:
// Type-erased pointer to the inner type
/// Type-erased pointer to the inner type
boost::shared_ptr<Factor> inner_;
public:
@ -127,8 +127,7 @@ class GTSAM_EXPORT HybridConditional
* @param gaussianMixture Gaussian Mixture Conditional used to create the
* HybridConditional.
*/
HybridConditional(
boost::shared_ptr<GaussianMixture> gaussianMixture);
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);
/**
* @brief Return HybridConditional as a GaussianMixture
@ -168,10 +167,10 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; }
}; // DiscreteConditional
}; // HybridConditional
// traits
template <>
struct traits<HybridConditional> : public Testable<DiscreteConditional> {};
struct traits<HybridConditional> : public Testable<HybridConditional> {};
} // namespace gtsam

View File

@ -85,6 +85,40 @@ TEST(HybridBayesNet, Choose) {
*gbn.at(3)));
}
/* ****************************************************************************/
// Test bayes net optimize
TEST(HybridBayesNet, Optimize) {
Switching s(4);
Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}
HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);
DiscreteValues assignment;
assignment[M(1)] = 1;
assignment[M(2)] = 1;
assignment[M(3)] = 1;
VectorValues delta = hybridBayesNet->optimize(assignment);
// The linearization point has the same value as the key index,
// e.g. X(1) = 1, X(2) = 2,
// but the factors specify X(k) = k-1, so delta should be -1.
VectorValues expected_delta;
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
expected_delta.insert(make_pair(X(4), -Vector1::Ones()));
EXPECT(assert_equal(expected_delta, delta));
}
/* ************************************************************************* */
int main() {
TestResult tr;