diff --git a/gtsam.h b/gtsam.h index aebbea8b8..4aa1480df 100644 --- a/gtsam.h +++ b/gtsam.h @@ -787,6 +787,8 @@ class SymbolicFactorGraph { // Standard interface // FIXME: Must wrap FastSet for this to work //FastSet keys() const; + + pair eliminateFrontals(size_t nFrontals) const; }; #include @@ -996,6 +998,9 @@ class GaussianFactorGraph { size_t size() const; gtsam::GaussianFactor* at(size_t idx) const; + // Inference + pair eliminateFrontals(size_t nFrontals) const; + // Building the graph void push_back(gtsam::GaussianFactor* factor); void add(Vector b); diff --git a/gtsam/inference/FactorGraph-inl.h b/gtsam/inference/FactorGraph-inl.h index 66c44db52..6bd9c340c 100644 --- a/gtsam/inference/FactorGraph-inl.h +++ b/gtsam/inference/FactorGraph-inl.h @@ -23,6 +23,7 @@ #pragma once #include +#include #include #include @@ -85,6 +86,57 @@ namespace gtsam { return size_; } + /* ************************************************************************* */ + template + std::pair::sharedConditional, FactorGraph > + FactorGraph::eliminateFrontals(size_t nFrontals, const Eliminate& eliminate) const + { + // Build variable index + VariableIndex variableIndex(*this); + + // Find first variable + Index firstIndex = 0; + while(firstIndex < variableIndex.size() && variableIndex[firstIndex].empty()) + ++ firstIndex; + + // Check that number of variables is in bounds + if(firstIndex + nFrontals >= variableIndex.size()) + throw std::invalid_argument("Requested to eliminate more frontal variables than exist in the factor graph."); + + // Get set of involved factors + FastSet involvedFactorIs; + for(Index j = firstIndex; j < firstIndex + nFrontals; ++j) { + BOOST_FOREACH(size_t i, variableIndex[j]) { + involvedFactorIs.insert(i); + } + } + + // Separate factors into involved and remaining + FactorGraph involvedFactors; + FactorGraph remainingFactors; + FastSet::const_iterator involvedFactorIsIt = involvedFactorIs.begin(); + for(size_t i = 0; i < this->size(); ++i) { + if(*involvedFactorIsIt == i) { + // If the current factor is involved, add it to involved and increment involved iterator + involvedFactors.push_back((*this)[i]); + ++ involvedFactorIsIt; + } else { + // If not involved, add to remaining + remainingFactors.push_back((*this)[i]); + } + } + + // Do dense elimination on the involved factors + typename FactorGraph::EliminationResult eliminationResult = + eliminate(involvedFactors, nFrontals); + + // Add the remaining factor back into the factor graph + remainingFactors.push_back(eliminationResult.second); + + // Return the eliminated factor and remaining factor graph + return std::make_pair(eliminationResult.first, remainingFactors); + } + /* ************************************************************************* */ template void FactorGraph::replace(size_t index, sharedFactor factor) { diff --git a/gtsam/inference/FactorGraph.h b/gtsam/inference/FactorGraph.h index 5295a4c58..ed3e7d952 100644 --- a/gtsam/inference/FactorGraph.h +++ b/gtsam/inference/FactorGraph.h @@ -175,6 +175,13 @@ template class BayesTree; /** Get the last factor */ sharedFactor back() const { return factors_.back(); } + /** Eliminate the first \c n frontal variables, returning the resulting + * conditional and remaining factor graph - this is very inefficient for + * eliminating all variables, to do that use EliminationTree or + * JunctionTree. + */ + std::pair > eliminateFrontals(size_t nFrontals, const Eliminate& eliminate) const; + /// @} /// @name Modifying Factor Graphs (imperative, discouraged) /// @{ diff --git a/gtsam/inference/SymbolicFactorGraph.cpp b/gtsam/inference/SymbolicFactorGraph.cpp index 00f3439a0..ff7a91ca6 100644 --- a/gtsam/inference/SymbolicFactorGraph.cpp +++ b/gtsam/inference/SymbolicFactorGraph.cpp @@ -63,6 +63,13 @@ namespace gtsam { return keys; } + /* ************************************************************************* */ + std::pair + SymbolicFactorGraph::eliminateFrontals(size_t nFrontals) const + { + return FactorGraph::eliminateFrontals(nFrontals, EliminateSymbolic); + } + /* ************************************************************************* */ IndexFactor::shared_ptr CombineSymbolic( const FactorGraph& factors, const FastMap SymbolicFactorGraph(const FactorGraph& fg); + + /** Eliminate the first \c n frontal variables, returning the resulting + * conditional and remaining factor graph - this is very inefficient for + * eliminating all variables, to do that use EliminationTree or + * JunctionTree. Note that this version simply calls + * FactorGraph::eliminateFrontals with EliminateSymbolic + * as the eliminate function argument. + */ + std::pair eliminateFrontals(size_t nFrontals) const; /// @} /// @name Standard Interface @@ -68,6 +77,8 @@ namespace gtsam { */ FastSet keys() const; + + /// @} /// @name Advanced Interface /// @{ @@ -87,9 +98,8 @@ namespace gtsam { }; /** Create a combined joint factor (new style for EliminationTree). */ - IndexFactor::shared_ptr CombineSymbolic( - const FactorGraph& factors, const FastMap >& variableSlots); + IndexFactor::shared_ptr CombineSymbolic(const FactorGraph& factors, + const FastMap >& variableSlots); /** * CombineAndEliminate provides symbolic elimination. diff --git a/gtsam/inference/tests/testFactorGraph.cpp b/gtsam/inference/tests/testFactorGraph.cpp index 4743d1102..6121a7f99 100644 --- a/gtsam/inference/tests/testFactorGraph.cpp +++ b/gtsam/inference/tests/testFactorGraph.cpp @@ -21,16 +21,44 @@ #include #include #include // for operator += +#include using namespace boost::assign; #include +#include #include using namespace std; using namespace gtsam; -typedef boost::shared_ptr shared; +/* ************************************************************************* */ +TEST(FactorGraph, eliminateFrontals) { + + SymbolicFactorGraph sfgOrig; + sfgOrig.push_factor(0,1); + sfgOrig.push_factor(0,2); + sfgOrig.push_factor(1,3); + sfgOrig.push_factor(1,4); + sfgOrig.push_factor(2,3); + sfgOrig.push_factor(4,5); + + IndexConditional::shared_ptr actualCond; + SymbolicFactorGraph actualSfg; + boost::tie(actualCond, actualSfg) = sfgOrig.eliminateFrontals(2, EliminateSymbolic); + + vector condIndices; + condIndices += 0,1,2,3,4; + IndexConditional expectedCond(condIndices, 2); + + SymbolicFactorGraph expectedSfg; + expectedSfg.push_factor(2,3); + expectedSfg.push_factor(4,5); + expectedSfg.push_factor(2,3,4); + + EXPECT(assert_equal(expectedSfg, actualSfg)); + EXPECT(assert_equal(expectedCond, *actualCond)); +} ///* ************************************************************************* */ // SL-FIX TEST( FactorGraph, splitMinimumSpanningTree ) diff --git a/gtsam/linear/GaussianFactorGraph.cpp b/gtsam/linear/GaussianFactorGraph.cpp index e2598b1ac..02f37915c 100644 --- a/gtsam/linear/GaussianFactorGraph.cpp +++ b/gtsam/linear/GaussianFactorGraph.cpp @@ -47,6 +47,13 @@ namespace gtsam { return keys; } + /* ************************************************************************* */ + std::pair + GaussianFactorGraph::eliminateFrontals(size_t nFrontals) const + { + return FactorGraph::eliminateFrontals(nFrontals, EliminateQR); + } + /* ************************************************************************* */ void GaussianFactorGraph::permuteWithInverse( const Permutation& inversePermutation) { diff --git a/gtsam/linear/GaussianFactorGraph.h b/gtsam/linear/GaussianFactorGraph.h index fd1ec1f6c..51748b79b 100644 --- a/gtsam/linear/GaussianFactorGraph.h +++ b/gtsam/linear/GaussianFactorGraph.h @@ -132,6 +132,16 @@ namespace gtsam { typedef FastSet Keys; Keys keys() const; + + /** Eliminate the first \c n frontal variables, returning the resulting + * conditional and remaining factor graph - this is very inefficient for + * eliminating all variables, to do that use EliminationTree or + * JunctionTree. Note that this version simply calls + * FactorGraph::eliminateFrontals with EliminateQR as the + * eliminate function argument. + */ + std::pair eliminateFrontals(size_t nFrontals) const; + /** Permute the variables in the factors */ void permuteWithInverse(const Permutation& inversePermutation);