diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 4665a3136..f93d21651 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -16,8 +16,8 @@ */ #include -#include #include +#include namespace gtsam { @@ -112,13 +112,12 @@ HybridBayesNet HybridBayesNet::prune( /* ************************************************************************* */ GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const { - return boost::dynamic_pointer_cast(factors_.at(i)->inner()); + return factors_.at(i)->asMixture(); } /* ************************************************************************* */ DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const { - return boost::dynamic_pointer_cast( - factors_.at(i)->inner()); + return factors_.at(i)->asDiscreteConditional(); } /* ************************************************************************* */ @@ -138,4 +137,10 @@ HybridValues HybridBayesNet::optimize() const { return dag.argmax(); } +/* *******************************************************************************/ +VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const { + GaussianBayesNet gbn = this->choose(assignment); + return gbn.optimize(); +} + } // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 0d2dc3642..a16a4f42c 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -72,6 +72,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { /// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and /// put this method there? HybridValues optimize() const; + + /** + * @brief Given the discrete assignment, return the optimized estimate for the + * selected Gaussian BayesNet. + * + * @param assignment An assignment of discrete values. + * @return Values + */ + VectorValues optimize(const DiscreteValues &assignment) const; }; } // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h index 3ba5da393..91c9f8495 100644 --- a/gtsam/hybrid/HybridConditional.h +++ b/gtsam/hybrid/HybridConditional.h @@ -69,7 +69,7 @@ class GTSAM_EXPORT HybridConditional BaseConditional; ///< Typedef to our conditional base class protected: - // Type-erased pointer to the inner type + /// Type-erased pointer to the inner type boost::shared_ptr inner_; public: @@ -127,8 +127,7 @@ class GTSAM_EXPORT HybridConditional * @param gaussianMixture Gaussian Mixture Conditional used to create the * HybridConditional. */ - HybridConditional( - boost::shared_ptr gaussianMixture); + HybridConditional(boost::shared_ptr gaussianMixture); /** * @brief Return HybridConditional as a GaussianMixture @@ -168,10 +167,10 @@ class GTSAM_EXPORT HybridConditional /// Get the type-erased pointer to the inner type boost::shared_ptr inner() { return inner_; } -}; // DiscreteConditional +}; // HybridConditional // traits template <> -struct traits : public Testable {}; +struct traits : public Testable {}; } // namespace gtsam diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index f3db83955..d447bcce2 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -85,6 +85,40 @@ TEST(HybridBayesNet, Choose) { *gbn.at(3))); } +/* ****************************************************************************/ +// Test bayes net optimize +TEST(HybridBayesNet, Optimize) { + Switching s(4); + + Ordering ordering; + for (auto&& kvp : s.linearizationPoint) { + ordering += kvp.key; + } + + HybridBayesNet::shared_ptr hybridBayesNet; + HybridGaussianFactorGraph::shared_ptr remainingFactorGraph; + std::tie(hybridBayesNet, remainingFactorGraph) = + s.linearizedFactorGraph.eliminatePartialSequential(ordering); + + DiscreteValues assignment; + assignment[M(1)] = 1; + assignment[M(2)] = 1; + assignment[M(3)] = 1; + + VectorValues delta = hybridBayesNet->optimize(assignment); + + // The linearization point has the same value as the key index, + // e.g. X(1) = 1, X(2) = 2, + // but the factors specify X(k) = k-1, so delta should be -1. + VectorValues expected_delta; + expected_delta.insert(make_pair(X(1), -Vector1::Ones())); + expected_delta.insert(make_pair(X(2), -Vector1::Ones())); + expected_delta.insert(make_pair(X(3), -Vector1::Ones())); + expected_delta.insert(make_pair(X(4), -Vector1::Ones())); + + EXPECT(assert_equal(expected_delta, delta)); +} + /* ************************************************************************* */ int main() { TestResult tr;