Merge remote-tracking branch 'upstream/develop' into develop

release/4.3a0
senselessDev 2022-01-29 22:03:36 +01:00
commit a82ddcc4d4
32 changed files with 685 additions and 483 deletions

View File

@ -11,7 +11,7 @@ endif()
set (GTSAM_VERSION_MAJOR 4) set (GTSAM_VERSION_MAJOR 4)
set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_MINOR 2)
set (GTSAM_VERSION_PATCH 0) set (GTSAM_VERSION_PATCH 0)
set (GTSAM_PRERELEASE_VERSION "a3") set (GTSAM_PRERELEASE_VERSION "a4")
math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}")
if (${GTSAM_VERSION_PATCH} EQUAL 0) if (${GTSAM_VERSION_PATCH} EQUAL 0)

View File

@ -31,11 +31,12 @@
namespace gtsam { namespace gtsam {
/** A Bayes net made from discrete conditional distributions. */ /**
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> * A Bayes net made from discrete conditional distributions.
{ * @addtogroup discrete
public: */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
public:
typedef BayesNet<DiscreteConditional> Base; typedef BayesNet<DiscreteConditional> Base;
typedef DiscreteBayesNet This; typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType; typedef DiscreteConditional ConditionalType;
@ -49,16 +50,20 @@ namespace gtsam {
DiscreteBayesNet() {} DiscreteBayesNet() {}
/** Construct from iterator over conditionals */ /** Construct from iterator over conditionals */
template<typename ITERATOR> template <typename ITERATOR>
DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} DiscreteBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */ /** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER> template <class CONTAINER>
explicit DiscreteBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} explicit DiscreteBayesNet(const CONTAINER& conditionals)
: Base(conditionals) {}
/** Implicit copy/downcast constructor to override explicit template container constructor */ /** Implicit copy/downcast constructor to override explicit template
template<class DERIVEDCONDITIONAL> * container constructor */
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {} template <class DERIVEDCONDITIONAL>
DiscreteBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor /// Destructor
virtual ~DiscreteBayesNet() {} virtual ~DiscreteBayesNet() {}

View File

@ -102,6 +102,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor {
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const;
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const; size_t nrFrontals() const;
size_t nrParents() const; size_t nrParents() const;
void printSignature( void printSignature(
@ -156,13 +157,17 @@ class DiscreteBayesNet {
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const; bool equals(const gtsam::DiscreteBayesNet& other, double tol = 1e-9) const;
string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
void saveGraph(string s, const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues sample() const; gtsam::DiscreteValues sample() const;
gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const; gtsam::DiscreteValues sample(gtsam::DiscreteValues given) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter, string markdown(const gtsam::KeyFormatter& keyFormatter,
@ -228,19 +233,6 @@ class DiscreteLookupDAG {
gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const; gtsam::DiscreteValues argmax(gtsam::DiscreteValues given) const;
}; };
#include <gtsam/inference/DotWriter.h>
class DotWriter {
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
bool plotFactorPoints = true, bool connectKeysToFactor = true,
bool binaryEdges = true);
double figureWidthInches;
double figureHeightInches;
bool plotFactorPoints;
bool connectKeysToFactor;
bool binaryEdges;
};
#include <gtsam/discrete/DiscreteFactorGraph.h> #include <gtsam/discrete/DiscreteFactorGraph.h>
class DiscreteFactorGraph { class DiscreteFactorGraph {
DiscreteFactorGraph(); DiscreteFactorGraph();
@ -265,14 +257,6 @@ class DiscreteFactorGraph {
void print(string s = "") const; void print(string s = "") const;
bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const; bool equals(const gtsam::DiscreteFactorGraph& fg, double tol = 1e-9) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& dotWriter = gtsam::DotWriter()) const;
gtsam::DecisionTreeFactor product() const; gtsam::DecisionTreeFactor product() const;
double operator()(const gtsam::DiscreteValues& values) const; double operator()(const gtsam::DiscreteValues& values) const;
gtsam::DiscreteValues optimize() const; gtsam::DiscreteValues optimize() const;
@ -294,6 +278,14 @@ class DiscreteFactorGraph {
std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph> std::pair<gtsam::DiscreteBayesTree, gtsam::DiscreteFactorGraph>
eliminatePartialMultifrontal(const gtsam::Ordering& ordering); eliminatePartialMultifrontal(const gtsam::Ordering& ordering);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
string markdown(const gtsam::KeyFormatter& keyFormatter = string markdown(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
string markdown(const gtsam::KeyFormatter& keyFormatter, string markdown(const gtsam::KeyFormatter& keyFormatter,

View File

@ -150,12 +150,21 @@ TEST(DiscreteBayesNet, Dot) {
fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
string actual = fragment.dot(); string actual = fragment.dot();
cout << actual << endl;
EXPECT(actual == EXPECT(actual ==
"digraph G{\n" "digraph {\n"
"0->3\n" " size=\"5,5\";\n"
"4->6\n" "\n"
"3->5\n" " var0[label=\"0\"];\n"
"6->5\n" " var3[label=\"3\"];\n"
" var4[label=\"4\"];\n"
" var5[label=\"5\"];\n"
" var6[label=\"6\"];\n"
"\n"
" var3->var5\n"
" var6->var5\n"
" var4->var6\n"
" var0->var3\n"
"}"); "}");
} }

View File

@ -49,16 +49,14 @@
namespace gtsam { namespace gtsam {
/** /**
* @brief A 3D rotation represented as a rotation matrix if the preprocessor * @brief Rot3 is a 3D rotation represented as a rotation matrix if the
* symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion if it * preprocessor symbol GTSAM_USE_QUATERNIONS is not defined, or as a quaternion
* is defined. * if it is defined.
* @addtogroup geometry * @addtogroup geometry
* \nosubgrouping */
*/ class GTSAM_EXPORT Rot3 : public LieGroup<Rot3, 3> {
class GTSAM_EXPORT Rot3 : public LieGroup<Rot3,3> { private:
private:
#ifdef GTSAM_USE_QUATERNIONS #ifdef GTSAM_USE_QUATERNIONS
/** Internal Eigen Quaternion */ /** Internal Eigen Quaternion */
@ -67,8 +65,7 @@ namespace gtsam {
SO3 rot_; SO3 rot_;
#endif #endif
public: public:
/// @name Constructors and named constructors /// @name Constructors and named constructors
/// @{ /// @{
@ -83,7 +80,7 @@ namespace gtsam {
*/ */
Rot3(const Point3& col1, const Point3& col2, const Point3& col3); Rot3(const Point3& col1, const Point3& col2, const Point3& col3);
/** constructor from a rotation matrix, as doubles in *row-major* order !!! */ /// Construct from a rotation matrix, as doubles in *row-major* order !!!
Rot3(double R11, double R12, double R13, Rot3(double R11, double R12, double R13,
double R21, double R22, double R23, double R21, double R22, double R23,
double R31, double R32, double R33); double R31, double R32, double R33);
@ -567,6 +564,9 @@ namespace gtsam {
#endif #endif
}; };
/// std::vector of Rot3s, mainly for wrapper
using Rot3Vector = std::vector<Rot3, Eigen::aligned_allocator<Rot3> >;
/** /**
* [RQ] receives a 3 by 3 matrix and returns an upper triangular matrix R * [RQ] receives a 3 by 3 matrix and returns an upper triangular matrix R
* and 3 rotation angles corresponding to the rotation matrix Q=Qz'*Qy'*Qx' * and 3 rotation angles corresponding to the rotation matrix Q=Qz'*Qy'*Qx'
@ -585,5 +585,6 @@ namespace gtsam {
template<> template<>
struct traits<const Rot3> : public internal::LieGroup<Rot3> {}; struct traits<const Rot3> : public internal::LieGroup<Rot3> {};
}
} // namespace gtsam

View File

@ -10,41 +10,51 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file BayesNet.h * @file BayesNet.h
* @brief Bayes network * @brief Bayes network
* @author Frank Dellaert * @author Frank Dellaert
* @author Richard Roberts * @author Richard Roberts
*/ */
#pragma once #pragma once
#include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph-inst.h>
#include <boost/range/adaptor/reversed.hpp> #include <boost/range/adaptor/reversed.hpp>
#include <fstream> #include <fstream>
#include <string>
namespace gtsam { namespace gtsam {
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::print( void BayesNet<CONDITIONAL>::print(const std::string& s,
const std::string& s, const KeyFormatter& formatter) const { const KeyFormatter& formatter) const {
Base::print(s, formatter); Base::print(s, formatter);
} }
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::dot(std::ostream& os, void BayesNet<CONDITIONAL>::dot(std::ostream& os,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter,
os << "digraph G{\n"; const DotWriter& writer) const {
writer.digraphPreamble(&os);
for (auto conditional : *this) { // Create nodes for each variable in the graph
for (Key key : this->keys()) {
auto position = writer.variablePos(key);
writer.drawVariable(key, keyFormatter, position, &os);
}
os << "\n";
// Reverse order as typically Bayes nets stored in reverse topological sort.
for (auto conditional : boost::adaptors::reverse(*this)) {
auto frontals = conditional->frontals(); auto frontals = conditional->frontals();
const Key me = frontals.front(); const Key me = frontals.front();
auto parents = conditional->parents(); auto parents = conditional->parents();
for (const Key& p : parents) for (const Key& p : parents)
os << keyFormatter(p) << "->" << keyFormatter(me) << "\n"; os << " var" << keyFormatter(p) << "->var" << keyFormatter(me) << "\n";
} }
os << "}"; os << "}";
@ -53,18 +63,20 @@ void BayesNet<CONDITIONAL>::dot(std::ostream& os,
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter) const { std::string BayesNet<CONDITIONAL>::dot(const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::stringstream ss; std::stringstream ss;
dot(ss, keyFormatter); dot(ss, keyFormatter, writer);
return ss.str(); return ss.str();
} }
/* ************************************************************************* */ /* ************************************************************************* */
template <class CONDITIONAL> template <class CONDITIONAL>
void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename, void BayesNet<CONDITIONAL>::saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter) const { const KeyFormatter& keyFormatter,
const DotWriter& writer) const {
std::ofstream of(filename.c_str()); std::ofstream of(filename.c_str());
dot(of, keyFormatter); dot(of, keyFormatter, writer);
of.close(); of.close();
} }

View File

@ -10,77 +10,79 @@
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
/** /**
* @file BayesNet.h * @file BayesNet.h
* @brief Bayes network * @brief Bayes network
* @author Frank Dellaert * @author Frank Dellaert
* @author Richard Roberts * @author Richard Roberts
*/ */
#pragma once #pragma once
#include <boost/shared_ptr.hpp>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <boost/shared_ptr.hpp>
#include <string>
namespace gtsam { namespace gtsam {
/** /**
* A BayesNet is a tree of conditionals, stored in elimination order. * A BayesNet is a tree of conditionals, stored in elimination order.
* * @addtogroup inference
* todo: how to handle Bayes nets with an optimize function? Currently using global functions. */
* \nosubgrouping template <class CONDITIONAL>
*/ class BayesNet : public FactorGraph<CONDITIONAL> {
template<class CONDITIONAL> private:
class BayesNet : public FactorGraph<CONDITIONAL> { typedef FactorGraph<CONDITIONAL> Base;
private: public:
typedef typename boost::shared_ptr<CONDITIONAL>
sharedConditional; ///< A shared pointer to a conditional
typedef FactorGraph<CONDITIONAL> Base; protected:
/// @name Standard Constructors
/// @{
public: /** Default constructor as an empty BayesNet */
typedef typename boost::shared_ptr<CONDITIONAL> sharedConditional; ///< A shared pointer to a conditional BayesNet() {}
protected: /** Construct from iterator over conditionals */
/// @name Standard Constructors template <typename ITERATOR>
/// @{ BayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Default constructor as an empty BayesNet */ /// @}
BayesNet() {};
/** Construct from iterator over conditionals */ public:
template<typename ITERATOR> /// @name Testable
BayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} /// @{
/// @} /** print out graph */
void print(
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
public: /// @}
/// @name Testable
/// @{
/** print out graph */ /// @name Graph Display
void print( /// @{
const std::string& s = "BayesNet",
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
/// @} /// Output to graphviz format, stream version.
void dot(std::ostream& os,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// @name Graph Display /// Output to graphviz format string.
/// @{ std::string dot(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// Output to graphviz format, stream version. /// output to file with graphviz format.
void dot(std::ostream& os, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const; void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DotWriter& writer = DotWriter()) const;
/// Output to graphviz format string. /// @}
std::string dot( };
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// output to file with graphviz format. } // namespace gtsam
void saveGraph(const std::string& filename,
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
};
}
#include <gtsam/inference/BayesNet-inst.h> #include <gtsam/inference/BayesNet-inst.h>

View File

@ -16,30 +16,41 @@
* @date December, 2021 * @date December, 2021
*/ */
#include <gtsam/base/Vector.h>
#include <gtsam/inference/DotWriter.h> #include <gtsam/inference/DotWriter.h>
#include <gtsam/base/Vector.h>
#include <gtsam/inference/Symbol.h>
#include <ostream> #include <ostream>
using namespace std; using namespace std;
namespace gtsam { namespace gtsam {
void DotWriter::writePreamble(ostream* os) const { void DotWriter::graphPreamble(ostream* os) const {
*os << "graph {\n"; *os << "graph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches *os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n"; << "\";\n\n";
} }
void DotWriter::DrawVariable(Key key, const KeyFormatter& keyFormatter, void DotWriter::digraphPreamble(ostream* os) const {
*os << "digraph {\n";
*os << " size=\"" << figureWidthInches << "," << figureHeightInches
<< "\";\n\n";
}
void DotWriter::drawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position, const boost::optional<Vector2>& position,
ostream* os) { ostream* os) const {
// Label the node with the label from the KeyFormatter // Label the node with the label from the KeyFormatter
*os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key) *os << " var" << keyFormatter(key) << "[label=\"" << keyFormatter(key)
<< "\""; << "\"";
if (position) { if (position) {
*os << ", pos=\"" << position->x() << "," << position->y() << "!\""; *os << ", pos=\"" << position->x() << "," << position->y() << "!\"";
} }
if (boxes.count(key)) {
*os << ", shape=box";
}
*os << "];\n"; *os << "];\n";
} }
@ -53,18 +64,35 @@ void DotWriter::DrawFactor(size_t i, const boost::optional<Vector2>& position,
} }
static void ConnectVariables(Key key1, Key key2, static void ConnectVariables(Key key1, Key key2,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter, ostream* os) {
ostream* os) {
*os << " var" << keyFormatter(key1) << "--" *os << " var" << keyFormatter(key1) << "--"
<< "var" << keyFormatter(key2) << ";\n"; << "var" << keyFormatter(key2) << ";\n";
} }
static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter, static void ConnectVariableFactor(Key key, const KeyFormatter& keyFormatter,
size_t i, ostream* os) { size_t i, ostream* os) {
*os << " var" << keyFormatter(key) << "--" *os << " var" << keyFormatter(key) << "--"
<< "factor" << i << ";\n"; << "factor" << i << ";\n";
} }
/// Return variable position or none
boost::optional<Vector2> DotWriter::variablePos(Key key) const {
boost::optional<Vector2> result = boost::none;
// Check position hint
Symbol symbol(key);
auto hint = positionHints.find(symbol.chr());
if (hint != positionHints.end())
result.reset(Vector2(symbol.index(), hint->second));
// Override with explicit position, if given.
auto pos = variablePositions.find(key);
if (pos != variablePositions.end())
result.reset(pos->second);
return result;
}
void DotWriter::processFactor(size_t i, const KeyVector& keys, void DotWriter::processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position, const boost::optional<Vector2>& position,
@ -74,7 +102,10 @@ void DotWriter::processFactor(size_t i, const KeyVector& keys,
ConnectVariables(keys[0], keys[1], keyFormatter, os); ConnectVariables(keys[0], keys[1], keyFormatter, os);
} else { } else {
// Create dot for the factor. // Create dot for the factor.
DrawFactor(i, position, os); if (!position && factorPositions.count(i))
DrawFactor(i, factorPositions.at(i), os);
else
DrawFactor(i, position, os);
// Make factor-variable connections // Make factor-variable connections
if (connectKeysToFactor) { if (connectKeysToFactor) {

View File

@ -23,10 +23,15 @@
#include <gtsam/inference/Key.h> #include <gtsam/inference/Key.h>
#include <iosfwd> #include <iosfwd>
#include <map>
#include <set>
namespace gtsam { namespace gtsam {
/// Graphviz formatter. /**
* @brief DotWriter is a helper class for writing graphviz .dot files.
* @addtogroup inference
*/
struct GTSAM_EXPORT DotWriter { struct GTSAM_EXPORT DotWriter {
double figureWidthInches; ///< The figure width on paper in inches double figureWidthInches; ///< The figure width on paper in inches
double figureHeightInches; ///< The figure height on paper in inches double figureHeightInches; ///< The figure height on paper in inches
@ -35,6 +40,28 @@ struct GTSAM_EXPORT DotWriter {
///< the dot of the factor ///< the dot of the factor
bool binaryEdges; ///< just use non-dotted edges for binary factors bool binaryEdges; ///< just use non-dotted edges for binary factors
/**
* Variable positions can be optionally specified and will be included in the
* dot file with a "!' sign, so "neato" can use it to render them.
*/
std::map<Key, Vector2> variablePositions;
/**
* The position hints allow one to use symbol character and index to specify
* position. Unless variable positions are specified, if a hint is present for
* a given symbol, it will be used to calculate the positions as (index,hint).
*/
std::map<char, double> positionHints;
/** A set of keys that will be displayed as a box */
std::set<Key> boxes;
/**
* Factor positions can be optionally specified and will be included in the
* dot file with a "!' sign, so "neato" can use it to render them.
*/
std::map<size_t, Vector2> factorPositions;
explicit DotWriter(double figureWidthInches = 5, explicit DotWriter(double figureWidthInches = 5,
double figureHeightInches = 5, double figureHeightInches = 5,
bool plotFactorPoints = true, bool plotFactorPoints = true,
@ -45,18 +72,24 @@ struct GTSAM_EXPORT DotWriter {
connectKeysToFactor(connectKeysToFactor), connectKeysToFactor(connectKeysToFactor),
binaryEdges(binaryEdges) {} binaryEdges(binaryEdges) {}
/// Write out preamble, including size. /// Write out preamble for graph, including size.
void writePreamble(std::ostream* os) const; void graphPreamble(std::ostream* os) const;
/// Write out preamble for digraph, including size.
void digraphPreamble(std::ostream* os) const;
/// Create a variable dot fragment. /// Create a variable dot fragment.
static void DrawVariable(Key key, const KeyFormatter& keyFormatter, void drawVariable(Key key, const KeyFormatter& keyFormatter,
const boost::optional<Vector2>& position, const boost::optional<Vector2>& position,
std::ostream* os); std::ostream* os) const;
/// Create factor dot. /// Create factor dot.
static void DrawFactor(size_t i, const boost::optional<Vector2>& position, static void DrawFactor(size_t i, const boost::optional<Vector2>& position,
std::ostream* os); std::ostream* os);
/// Return variable position or none
boost::optional<Vector2> variablePos(Key key) const;
/// Draw a single factor, specified by its index i and its variable keys. /// Draw a single factor, specified by its index i and its variable keys.
void processFactor(size_t i, const KeyVector& keys, void processFactor(size_t i, const KeyVector& keys,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,

View File

@ -131,11 +131,12 @@ template <class FACTOR>
void FactorGraph<FACTOR>::dot(std::ostream& os, void FactorGraph<FACTOR>::dot(std::ostream& os,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
const DotWriter& writer) const { const DotWriter& writer) const {
writer.writePreamble(&os); writer.graphPreamble(&os);
// Create nodes for each variable in the graph // Create nodes for each variable in the graph
for (Key key : keys()) { for (Key key : keys()) {
writer.DrawVariable(key, keyFormatter, boost::none, &os); auto position = writer.variablePos(key);
writer.drawVariable(key, keyFormatter, position, &os);
} }
os << "\n"; os << "\n";

168
gtsam/inference/inference.i Normal file
View File

@ -0,0 +1,168 @@
//*************************************************************************
// inference
//*************************************************************************
namespace gtsam {
#include <gtsam/inference/Key.h>
// Default keyformatter
void PrintKeyList(
const gtsam::KeyList& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
void PrintKeyVector(
const gtsam::KeyVector& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
void PrintKeySet(
const gtsam::KeySet& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
#include <gtsam/inference/Symbol.h>
class Symbol {
Symbol();
Symbol(char c, uint64_t j);
Symbol(size_t key);
size_t key() const;
void print(const string& s = "") const;
bool equals(const gtsam::Symbol& expected, double tol) const;
char chr() const;
uint64_t index() const;
string string() const;
};
size_t symbol(char chr, size_t index);
char symbolChr(size_t key);
size_t symbolIndex(size_t key);
namespace symbol_shorthand {
size_t A(size_t j);
size_t B(size_t j);
size_t C(size_t j);
size_t D(size_t j);
size_t E(size_t j);
size_t F(size_t j);
size_t G(size_t j);
size_t H(size_t j);
size_t I(size_t j);
size_t J(size_t j);
size_t K(size_t j);
size_t L(size_t j);
size_t M(size_t j);
size_t N(size_t j);
size_t O(size_t j);
size_t P(size_t j);
size_t Q(size_t j);
size_t R(size_t j);
size_t S(size_t j);
size_t T(size_t j);
size_t U(size_t j);
size_t V(size_t j);
size_t W(size_t j);
size_t X(size_t j);
size_t Y(size_t j);
size_t Z(size_t j);
} // namespace symbol_shorthand
#include <gtsam/inference/LabeledSymbol.h>
class LabeledSymbol {
LabeledSymbol(size_t full_key);
LabeledSymbol(const gtsam::LabeledSymbol& key);
LabeledSymbol(unsigned char valType, unsigned char label, size_t j);
size_t key() const;
unsigned char label() const;
unsigned char chr() const;
size_t index() const;
gtsam::LabeledSymbol upper() const;
gtsam::LabeledSymbol lower() const;
gtsam::LabeledSymbol newChr(unsigned char c) const;
gtsam::LabeledSymbol newLabel(unsigned char label) const;
void print(string s = "") const;
};
size_t mrsymbol(unsigned char c, unsigned char label, size_t j);
unsigned char mrsymbolChr(size_t key);
unsigned char mrsymbolLabel(size_t key);
size_t mrsymbolIndex(size_t key);
#include <gtsam/inference/Ordering.h>
class Ordering {
/// Type of ordering to use
enum OrderingType { COLAMD, METIS, NATURAL, CUSTOM };
// Standard Constructors and Named Constructors
Ordering();
Ordering(const gtsam::Ordering& other);
template <FACTOR_GRAPH = {gtsam::NonlinearFactorGraph,
gtsam::GaussianFactorGraph}>
static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph);
// Testable
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::Ordering& ord, double tol) const;
// Standard interface
size_t size() const;
size_t at(size_t key) const;
void push_back(size_t key);
// enabling serialization functionality
void serialize() const;
};
#include <gtsam/inference/DotWriter.h>
class DotWriter {
DotWriter(double figureWidthInches = 5, double figureHeightInches = 5,
bool plotFactorPoints = true, bool connectKeysToFactor = true,
bool binaryEdges = true);
double figureWidthInches;
double figureHeightInches;
bool plotFactorPoints;
bool connectKeysToFactor;
bool binaryEdges;
std::map<gtsam::Key, gtsam::Vector2> variablePositions;
std::map<char, double> positionHints;
std::set<Key> boxes;
std::map<size_t, gtsam::Vector2> factorPositions;
};
#include <gtsam/inference/VariableIndex.h>
// Headers for overloaded methods below, break hierarchy :-/
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
#include <gtsam/symbolic/SymbolicFactorGraph.h>
class VariableIndex {
// Standard Constructors and Named Constructors
VariableIndex();
// TODO: Templetize constructor when wrap supports it
// template<T = {gtsam::FactorGraph}>
// VariableIndex(const T& factorGraph, size_t nVariables);
// VariableIndex(const T& factorGraph);
VariableIndex(const gtsam::SymbolicFactorGraph& sfg);
VariableIndex(const gtsam::GaussianFactorGraph& gfg);
VariableIndex(const gtsam::NonlinearFactorGraph& fg);
VariableIndex(const gtsam::VariableIndex& other);
// Testable
bool equals(const gtsam::VariableIndex& other, double tol) const;
void print(string s = "VariableIndex: ",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
// Standard interface
size_t size() const;
size_t nFactors() const;
size_t nEntries() const;
};
} // namespace gtsam

View File

@ -205,23 +205,5 @@ namespace gtsam {
} }
/* ************************************************************************* */ /* ************************************************************************* */
void GaussianBayesNet::saveGraph(const std::string& s,
const KeyFormatter& keyFormatter) const {
std::ofstream of(s.c_str());
of << "digraph G{\n";
for (auto conditional : boost::adaptors::reverse(*this)) {
typename GaussianConditional::Frontals frontals = conditional->frontals();
Key me = frontals.front();
typename GaussianConditional::Parents parents = conditional->parents();
for (Key p : parents)
of << keyFormatter(p) << "->" << keyFormatter(me) << std::endl;
}
of << "}";
of.close();
}
/* ************************************************************************* */
} // namespace gtsam } // namespace gtsam

View File

@ -21,17 +21,22 @@
#pragma once #pragma once
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/global_includes.h> #include <gtsam/global_includes.h>
#include <utility>
namespace gtsam { namespace gtsam {
/** A Bayes net made from linear-Gaussian densities */ /**
class GTSAM_EXPORT GaussianBayesNet: public FactorGraph<GaussianConditional> * GaussianBayesNet is a Bayes net made from linear-Gaussian conditionals.
* @addtogroup linear
*/
class GTSAM_EXPORT GaussianBayesNet: public BayesNet<GaussianConditional>
{ {
public: public:
typedef FactorGraph<GaussianConditional> Base; typedef BayesNet<GaussianConditional> Base;
typedef GaussianBayesNet This; typedef GaussianBayesNet This;
typedef GaussianConditional ConditionalType; typedef GaussianConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
@ -44,16 +49,21 @@ namespace gtsam {
GaussianBayesNet() {} GaussianBayesNet() {}
/** Construct from iterator over conditionals */ /** Construct from iterator over conditionals */
template<typename ITERATOR> template <typename ITERATOR>
GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} GaussianBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */ /** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER> template <class CONTAINER>
explicit GaussianBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} explicit GaussianBayesNet(const CONTAINER& conditionals) {
push_back(conditionals);
}
/** Implicit copy/downcast constructor to override explicit template container constructor */ /** Implicit copy/downcast constructor to override explicit template
template<class DERIVEDCONDITIONAL> * container constructor */
GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {} template <class DERIVEDCONDITIONAL>
explicit GaussianBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor /// Destructor
virtual ~GaussianBayesNet() {} virtual ~GaussianBayesNet() {}
@ -66,6 +76,13 @@ namespace gtsam {
/** Check equality */ /** Check equality */
bool equals(const This& bn, double tol = 1e-9) const; bool equals(const This& bn, double tol = 1e-9) const;
/// print graph
void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
/// @} /// @}
/// @name Standard Interface /// @name Standard Interface
@ -180,23 +197,6 @@ namespace gtsam {
*/ */
VectorValues backSubstituteTranspose(const VectorValues& gx) const; VectorValues backSubstituteTranspose(const VectorValues& gx) const;
/// print graph
void print(
const std::string& s = "",
const KeyFormatter& formatter = DefaultKeyFormatter) const override {
Base::print(s, formatter);
}
/**
* @brief Save the GaussianBayesNet as an image. Requires `dot` to be
* installed.
*
* @param s The name of the figure.
* @param keyFormatter Formatter to use for styling keys in the graph.
*/
void saveGraph(const std::string& s, const KeyFormatter& keyFormatter =
DefaultKeyFormatter) const;
/// @} /// @}
private: private:

View File

@ -437,42 +437,53 @@ class GaussianFactorGraph {
pair<Matrix,Vector> hessian() const; pair<Matrix,Vector> hessian() const;
pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const; pair<Matrix,Vector> hessian(const gtsam::Ordering& ordering) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
// enabling serialization functionality // enabling serialization functionality
void serialize() const; void serialize() const;
}; };
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
virtual class GaussianConditional : gtsam::JacobianFactor { virtual class GaussianConditional : gtsam::JacobianFactor {
//Constructors // Constructors
GaussianConditional(size_t key, Vector d, Matrix R, const gtsam::noiseModel::Diagonal* sigmas); GaussianConditional(size_t key, Vector d, Matrix R,
const gtsam::noiseModel::Diagonal* sigmas);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
const gtsam::noiseModel::Diagonal* sigmas); const gtsam::noiseModel::Diagonal* sigmas);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
size_t name2, Matrix T, const gtsam::noiseModel::Diagonal* sigmas); size_t name2, Matrix T,
const gtsam::noiseModel::Diagonal* sigmas);
//Constructors with no noise model // Constructors with no noise model
GaussianConditional(size_t key, Vector d, Matrix R); GaussianConditional(size_t key, Vector d, Matrix R);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S); GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S);
GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S, GaussianConditional(size_t key, Vector d, Matrix R, size_t name1, Matrix S,
size_t name2, Matrix T); size_t name2, Matrix T);
//Standard Interface // Standard Interface
void print(string s = "GaussianConditional", void print(string s = "GaussianConditional",
const gtsam::KeyFormatter& keyFormatter = const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const; gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::GaussianConditional& cg, double tol) const; bool equals(const gtsam::GaussianConditional& cg, double tol) const;
gtsam::Key firstFrontalKey() const;
// Advanced Interface
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
const gtsam::VectorValues& rhs) const;
void solveTransposeInPlace(gtsam::VectorValues& gy) const;
Matrix R() const;
Matrix S() const;
Vector d() const;
// Advanced Interface // enabling serialization functionality
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const; void serialize() const;
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
const gtsam::VectorValues& rhs) const;
void solveTransposeInPlace(gtsam::VectorValues& gy) const;
Matrix R() const;
Matrix S() const;
Vector d() const;
// enabling serialization functionality
void serialize() const;
}; };
#include <gtsam/linear/GaussianDensity.h> #include <gtsam/linear/GaussianDensity.h>
@ -524,6 +535,14 @@ virtual class GaussianBayesNet {
double logDeterminant() const; double logDeterminant() const;
gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const; gtsam::VectorValues backSubstitute(const gtsam::VectorValues& gx) const;
gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const; gtsam::VectorValues backSubstituteTranspose(const gtsam::VectorValues& gx) const;
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
}; };
#include <gtsam/linear/GaussianBayesTree.h> #include <gtsam/linear/GaussianBayesTree.h>

View File

@ -301,5 +301,31 @@ TEST(GaussianBayesNet, ComputeSteepestDescentPoint) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
int main() { TestResult tr; return TestRegistry::runAllTests(tr);} TEST(GaussianBayesNet, Dot) {
GaussianBayesNet fragment;
DotWriter writer;
writer.variablePositions.emplace(_x_, Vector2(10, 20));
writer.variablePositions.emplace(_y_, Vector2(50, 20));
auto position = writer.variablePos(_x_);
CHECK(position);
EXPECT(assert_equal(Vector2(10, 20), *position, 1e-5));
string actual = noisyBayesNet.dot(DefaultKeyFormatter, writer);
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" var11[label=\"11\", pos=\"10,20!\"];\n"
" var22[label=\"22\", pos=\"50,20!\"];\n"
"\n"
" var22->var11\n"
"}");
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -34,7 +34,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
min.y() = std::numeric_limits<double>::infinity(); min.y() = std::numeric_limits<double>::infinity();
for (const Key& key : keys) { for (const Key& key : keys) {
if (values.exists(key)) { if (values.exists(key)) {
boost::optional<Vector2> xy = operator()(values.at(key)); boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) { if (xy) {
if (xy->x() < min.x()) min.x() = xy->x(); if (xy->x() < min.x()) min.x() = xy->x();
if (xy->y() < min.y()) min.y() = xy->y(); if (xy->y() < min.y()) min.y() = xy->y();
@ -44,7 +44,7 @@ Vector2 GraphvizFormatting::findBounds(const Values& values,
return min; return min;
} }
boost::optional<Vector2> GraphvizFormatting::operator()( boost::optional<Vector2> GraphvizFormatting::extractPosition(
const Value& value) const { const Value& value) const {
Vector3 t; Vector3 t;
if (const GenericValue<Pose2>* p = if (const GenericValue<Pose2>* p =
@ -121,12 +121,11 @@ boost::optional<Vector2> GraphvizFormatting::operator()(
return Vector2(x, y); return Vector2(x, y);
} }
// Return affinely transformed variable position if it exists.
boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values, boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
const Vector2& min, const Vector2& min,
Key key) const { Key key) const {
if (!values.exists(key)) return boost::none; if (!values.exists(key)) return DotWriter::variablePos(key);
boost::optional<Vector2> xy = operator()(values.at(key)); boost::optional<Vector2> xy = extractPosition(values.at(key));
if (xy) { if (xy) {
xy->x() = scale * (xy->x() - min.x()); xy->x() = scale * (xy->x() - min.x());
xy->y() = scale * (xy->y() - min.y()); xy->y() = scale * (xy->y() - min.y());
@ -134,7 +133,6 @@ boost::optional<Vector2> GraphvizFormatting::variablePos(const Values& values,
return xy; return xy;
} }
// Return affinely transformed factor position if it exists.
boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min, boost::optional<Vector2> GraphvizFormatting::factorPos(const Vector2& min,
size_t i) const { size_t i) const {
if (factorPositions.size() == 0) return boost::none; if (factorPositions.size() == 0) return boost::none;

View File

@ -33,17 +33,14 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
/// World axes to be assigned to paper axes /// World axes to be assigned to paper axes
enum Axis { X, Y, Z, NEGX, NEGY, NEGZ }; enum Axis { X, Y, Z, NEGX, NEGY, NEGZ };
Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal Axis paperHorizontalAxis; ///< The world axis assigned to the horizontal
///< paper axis ///< paper axis
Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper Axis paperVerticalAxis; ///< The world axis assigned to the vertical paper
///< axis ///< axis
double scale; ///< Scale all positions to reduce / increase density double scale; ///< Scale all positions to reduce / increase density
bool mergeSimilarFactors; ///< Merge multiple factors that have the same bool mergeSimilarFactors; ///< Merge multiple factors that have the same
///< connectivity ///< connectivity
/// (optional for each factor) Manually specify factor "dot" positions:
std::map<size_t, Vector2> factorPositions;
/// Default constructor sets up robot coordinates. Paper horizontal is robot /// Default constructor sets up robot coordinates. Paper horizontal is robot
/// Y, paper vertical is robot X. Default figure size of 5x5 in. /// Y, paper vertical is robot X. Default figure size of 5x5 in.
GraphvizFormatting() GraphvizFormatting()
@ -55,8 +52,8 @@ struct GTSAM_EXPORT GraphvizFormatting : public DotWriter {
// Find bounds // Find bounds
Vector2 findBounds(const Values& values, const KeySet& keys) const; Vector2 findBounds(const Values& values, const KeySet& keys) const;
/// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3 /// Extract a Vector2 from either Vector2, Pose2, Pose3, or Point3
boost::optional<Vector2> operator()(const Value& value) const; boost::optional<Vector2> extractPosition(const Value& value) const;
/// Return affinely transformed variable position if it exists. /// Return affinely transformed variable position if it exists.
boost::optional<Vector2> variablePos(const Values& values, const Vector2& min, boost::optional<Vector2> variablePos(const Values& values, const Vector2& min,

View File

@ -102,7 +102,7 @@ bool NonlinearFactorGraph::equals(const NonlinearFactorGraph& other, double tol)
void NonlinearFactorGraph::dot(std::ostream& os, const Values& values, void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter, const KeyFormatter& keyFormatter,
const GraphvizFormatting& writer) const { const GraphvizFormatting& writer) const {
writer.writePreamble(&os); writer.graphPreamble(&os);
// Find bounds (imperative) // Find bounds (imperative)
KeySet keys = this->keys(); KeySet keys = this->keys();
@ -111,7 +111,7 @@ void NonlinearFactorGraph::dot(std::ostream& os, const Values& values,
// Create nodes for each variable in the graph // Create nodes for each variable in the graph
for (Key key : keys) { for (Key key : keys) {
auto position = writer.variablePos(values, min, key); auto position = writer.variablePos(values, min, key);
writer.DrawVariable(key, keyFormatter, position, &os); writer.drawVariable(key, keyFormatter, position, &os);
} }
os << "\n"; os << "\n";

View File

@ -43,12 +43,14 @@ namespace gtsam {
class ExpressionFactor; class ExpressionFactor;
/** /**
* A non-linear factor graph is a graph of non-Gaussian, i.e. non-linear factors, * A NonlinearFactorGraph is a graph of non-Gaussian, i.e. non-linear factors,
* which derive from NonlinearFactor. The values structures are typically (in SAM) more general * which derive from NonlinearFactor. The values structures are typically (in
* than just vectors, e.g., Rot3 or Pose3, which are objects in non-linear manifolds. * SAM) more general than just vectors, e.g., Rot3 or Pose3, which are objects
* Linearizing the non-linear factor graph creates a linear factor graph on the * in non-linear manifolds. Linearizing the non-linear factor graph creates a
* tangent vector space at the linearization point. Because the tangent space is a true * linear factor graph on the tangent vector space at the linearization point.
* vector space, the config type will be an VectorValues in that linearized factor graph. * Because the tangent space is a true vector space, the config type will be
* an VectorValues in that linearized factor graph.
* @addtogroup nonlinear
*/ */
class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> { class GTSAM_EXPORT NonlinearFactorGraph: public FactorGraph<NonlinearFactor> {
@ -58,6 +60,9 @@ namespace gtsam {
typedef NonlinearFactorGraph This; typedef NonlinearFactorGraph This;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
/// @name Standard Constructors
/// @{
/** Default constructor */ /** Default constructor */
NonlinearFactorGraph() {} NonlinearFactorGraph() {}
@ -76,6 +81,10 @@ namespace gtsam {
/// Destructor /// Destructor
virtual ~NonlinearFactorGraph() {} virtual ~NonlinearFactorGraph() {}
/// @}
/// @name Testable
/// @{
/** print */ /** print */
void print( void print(
const std::string& str = "NonlinearFactorGraph: ", const std::string& str = "NonlinearFactorGraph: ",
@ -90,6 +99,10 @@ namespace gtsam {
/** Test equality */ /** Test equality */
bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const; bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const;
/// @}
/// @name Standard Interface
/// @{
/** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */ /** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */
double error(const Values& values) const; double error(const Values& values) const;
@ -206,6 +219,7 @@ namespace gtsam {
emplace_shared<PriorFactor<T>>(key, prior, covariance); emplace_shared<PriorFactor<T>>(key, prior, covariance);
} }
/// @}
/// @name Graph Display /// @name Graph Display
/// @{ /// @{
@ -215,20 +229,19 @@ namespace gtsam {
/// Output to graphviz format, stream version, with Values/extra options. /// Output to graphviz format, stream version, with Values/extra options.
void dot(std::ostream& os, const Values& values, void dot(std::ostream& os, const Values& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const GraphvizFormatting& graphvizFormatting = const GraphvizFormatting& writer = GraphvizFormatting()) const;
GraphvizFormatting()) const;
/// Output to graphviz format string, with Values/extra options. /// Output to graphviz format string, with Values/extra options.
std::string dot(const Values& values, std::string dot(
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const Values& values,
const GraphvizFormatting& graphvizFormatting = const KeyFormatter& keyFormatter = DefaultKeyFormatter,
GraphvizFormatting()) const; const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// output to file with graphviz format, with Values/extra options. /// output to file with graphviz format, with Values/extra options.
void saveGraph(const std::string& filename, const Values& values, void saveGraph(
const KeyFormatter& keyFormatter = DefaultKeyFormatter, const std::string& filename, const Values& values,
const GraphvizFormatting& graphvizFormatting = const KeyFormatter& keyFormatter = DefaultKeyFormatter,
GraphvizFormatting()) const; const GraphvizFormatting& writer = GraphvizFormatting()) const;
/// @} /// @}
private: private:
@ -251,6 +264,8 @@ namespace gtsam {
public: public:
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42 #ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated
/// @{
/** @deprecated */ /** @deprecated */
boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor( boost::shared_ptr<HessianFactor> GTSAM_DEPRECATED linearizeToHessianFactor(
const Values& values, boost::none_t, const Dampen& dampen = nullptr) const const Values& values, boost::none_t, const Dampen& dampen = nullptr) const
@ -275,6 +290,7 @@ namespace gtsam {
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
saveGraph(filename, values, keyFormatter, graphvizFormatting); saveGraph(filename, values, keyFormatter, graphvizFormatting);
} }
/// @}
#endif #endif
}; };

View File

@ -23,121 +23,9 @@ namespace gtsam {
#include <gtsam/geometry/SOn.h> #include <gtsam/geometry/SOn.h>
#include <gtsam/geometry/StereoPoint2.h> #include <gtsam/geometry/StereoPoint2.h>
#include <gtsam/geometry/Unit3.h> #include <gtsam/geometry/Unit3.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/navigation/ImuBias.h> #include <gtsam/navigation/ImuBias.h>
#include <gtsam/navigation/NavState.h> #include <gtsam/navigation/NavState.h>
class Symbol {
Symbol();
Symbol(char c, uint64_t j);
Symbol(size_t key);
size_t key() const;
void print(const string& s = "") const;
bool equals(const gtsam::Symbol& expected, double tol) const;
char chr() const;
uint64_t index() const;
string string() const;
};
size_t symbol(char chr, size_t index);
char symbolChr(size_t key);
size_t symbolIndex(size_t key);
namespace symbol_shorthand {
size_t A(size_t j);
size_t B(size_t j);
size_t C(size_t j);
size_t D(size_t j);
size_t E(size_t j);
size_t F(size_t j);
size_t G(size_t j);
size_t H(size_t j);
size_t I(size_t j);
size_t J(size_t j);
size_t K(size_t j);
size_t L(size_t j);
size_t M(size_t j);
size_t N(size_t j);
size_t O(size_t j);
size_t P(size_t j);
size_t Q(size_t j);
size_t R(size_t j);
size_t S(size_t j);
size_t T(size_t j);
size_t U(size_t j);
size_t V(size_t j);
size_t W(size_t j);
size_t X(size_t j);
size_t Y(size_t j);
size_t Z(size_t j);
} // namespace symbol_shorthand
// Default keyformatter
void PrintKeyList(
const gtsam::KeyList& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
void PrintKeyVector(
const gtsam::KeyVector& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
void PrintKeySet(
const gtsam::KeySet& keys, const string& s = "",
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
#include <gtsam/inference/LabeledSymbol.h>
class LabeledSymbol {
LabeledSymbol(size_t full_key);
LabeledSymbol(const gtsam::LabeledSymbol& key);
LabeledSymbol(unsigned char valType, unsigned char label, size_t j);
size_t key() const;
unsigned char label() const;
unsigned char chr() const;
size_t index() const;
gtsam::LabeledSymbol upper() const;
gtsam::LabeledSymbol lower() const;
gtsam::LabeledSymbol newChr(unsigned char c) const;
gtsam::LabeledSymbol newLabel(unsigned char label) const;
void print(string s = "") const;
};
size_t mrsymbol(unsigned char c, unsigned char label, size_t j);
unsigned char mrsymbolChr(size_t key);
unsigned char mrsymbolLabel(size_t key);
size_t mrsymbolIndex(size_t key);
#include <gtsam/inference/Ordering.h>
class Ordering {
/// Type of ordering to use
enum OrderingType {
COLAMD, METIS, NATURAL, CUSTOM
};
// Standard Constructors and Named Constructors
Ordering();
Ordering(const gtsam::Ordering& other);
template <FACTOR_GRAPH = {gtsam::NonlinearFactorGraph,
gtsam::GaussianFactorGraph}>
static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph);
// Testable
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
bool equals(const gtsam::Ordering& ord, double tol) const;
// Standard interface
size_t size() const;
size_t at(size_t key) const;
void push_back(size_t key);
// enabling serialization functionality
void serialize() const;
};
#include <gtsam/nonlinear/GraphvizFormatting.h> #include <gtsam/nonlinear/GraphvizFormatting.h>
class GraphvizFormatting : gtsam::DotWriter { class GraphvizFormatting : gtsam::DotWriter {
GraphvizFormatting(); GraphvizFormatting();
@ -207,18 +95,17 @@ class NonlinearFactorGraph {
gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const; gtsam::GaussianFactorGraph* linearize(const gtsam::Values& values) const;
gtsam::NonlinearFactorGraph clone() const; gtsam::NonlinearFactorGraph clone() const;
// enabling serialization functionality
void serialize() const;
string dot( string dot(
const gtsam::Values& values, const gtsam::Values& values,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const GraphvizFormatting& writer = GraphvizFormatting()); const GraphvizFormatting& formatting = GraphvizFormatting());
void saveGraph(const string& s, const gtsam::Values& values, void saveGraph(
const gtsam::KeyFormatter& keyFormatter = const string& s, const gtsam::Values& values,
gtsam::DefaultKeyFormatter, const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const GraphvizFormatting& writer = const GraphvizFormatting& formatting = GraphvizFormatting()) const;
GraphvizFormatting()) const;
// enabling serialization functionality
void serialize() const;
}; };
#include <gtsam/nonlinear/NonlinearFactor.h> #include <gtsam/nonlinear/NonlinearFactor.h>

View File

@ -323,6 +323,8 @@ virtual class KarcherMeanFactor : gtsam::NonlinearFactor {
KarcherMeanFactor(const gtsam::KeyVector& keys); KarcherMeanFactor(const gtsam::KeyVector& keys);
}; };
gtsam::Rot3 FindKarcherMean(const gtsam::Rot3Vector& rotations);
#include <gtsam/slam/FrobeniusFactor.h> #include <gtsam/slam/FrobeniusFactor.h>
gtsam::noiseModel::Isotropic* ConvertNoiseModel(gtsam::noiseModel::Base* model, gtsam::noiseModel::Isotropic* ConvertNoiseModel(gtsam::noiseModel::Base* model,
size_t d); size_t d);

View File

@ -16,41 +16,16 @@
* @author Richard Roberts * @author Richard Roberts
*/ */
#include <gtsam/symbolic/SymbolicBayesNet.h>
#include <gtsam/symbolic/SymbolicConditional.h>
#include <gtsam/inference/FactorGraph-inst.h> #include <gtsam/inference/FactorGraph-inst.h>
#include <gtsam/symbolic/SymbolicBayesNet.h>
#include <boost/range/adaptor/reversed.hpp>
#include <fstream>
namespace gtsam { namespace gtsam {
// Instantiate base class // Instantiate base class
template class FactorGraph<SymbolicConditional>; template class FactorGraph<SymbolicConditional>;
/* ************************************************************************* */
bool SymbolicBayesNet::equals(const This& bn, double tol) const
{
return Base::equals(bn, tol);
}
/* ************************************************************************* */
void SymbolicBayesNet::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const
{
std::ofstream of(s.c_str());
of << "digraph G{\n";
for (auto conditional: boost::adaptors::reverse(*this)) {
SymbolicConditional::Frontals frontals = conditional->frontals();
Key me = frontals.front();
SymbolicConditional::Parents parents = conditional->parents();
for(Key p: parents)
of << p << "->" << me << std::endl;
}
of << "}";
of.close();
}
/* ************************************************************************* */
bool SymbolicBayesNet::equals(const This& bn, double tol) const {
return Base::equals(bn, tol);
} }
} // namespace gtsam

View File

@ -19,19 +19,19 @@
#pragma once #pragma once
#include <gtsam/symbolic/SymbolicConditional.h> #include <gtsam/symbolic/SymbolicConditional.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h> #include <gtsam/inference/FactorGraph.h>
#include <gtsam/base/types.h> #include <gtsam/base/types.h>
namespace gtsam { namespace gtsam {
/** Symbolic Bayes Net /**
* \nosubgrouping * A SymbolicBayesNet is a Bayes Net of purely symbolic conditionals.
* @addtogroup symbolic
*/ */
class SymbolicBayesNet : public FactorGraph<SymbolicConditional> { class SymbolicBayesNet : public BayesNet<SymbolicConditional> {
public:
public: typedef BayesNet<SymbolicConditional> Base;
typedef FactorGraph<SymbolicConditional> Base;
typedef SymbolicBayesNet This; typedef SymbolicBayesNet This;
typedef SymbolicConditional ConditionalType; typedef SymbolicConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr; typedef boost::shared_ptr<This> shared_ptr;
@ -44,16 +44,21 @@ namespace gtsam {
SymbolicBayesNet() {} SymbolicBayesNet() {}
/** Construct from iterator over conditionals */ /** Construct from iterator over conditionals */
template<typename ITERATOR> template <typename ITERATOR>
SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional) : Base(firstConditional, lastConditional) {} SymbolicBayesNet(ITERATOR firstConditional, ITERATOR lastConditional)
: Base(firstConditional, lastConditional) {}
/** Construct from container of factors (shared_ptr or plain objects) */ /** Construct from container of factors (shared_ptr or plain objects) */
template<class CONTAINER> template <class CONTAINER>
explicit SymbolicBayesNet(const CONTAINER& conditionals) : Base(conditionals) {} explicit SymbolicBayesNet(const CONTAINER& conditionals) {
push_back(conditionals);
}
/** Implicit copy/downcast constructor to override explicit template container constructor */ /** Implicit copy/downcast constructor to override explicit template
template<class DERIVEDCONDITIONAL> * container constructor */
SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph) : Base(graph) {} template <class DERIVEDCONDITIONAL>
explicit SymbolicBayesNet(const FactorGraph<DERIVEDCONDITIONAL>& graph)
: Base(graph) {}
/// Destructor /// Destructor
virtual ~SymbolicBayesNet() {} virtual ~SymbolicBayesNet() {}
@ -75,13 +80,6 @@ namespace gtsam {
/// @} /// @}
/// @name Standard Interface
/// @{
GTSAM_EXPORT void saveGraph(const std::string &s, const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
/// @}
private: private:
/** Serialization function */ /** Serialization function */
friend class boost::serialization::access; friend class boost::serialization::access;

View File

@ -3,11 +3,6 @@
//************************************************************************* //*************************************************************************
namespace gtsam { namespace gtsam {
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
// ###################
#include <gtsam/symbolic/SymbolicFactor.h> #include <gtsam/symbolic/SymbolicFactor.h>
virtual class SymbolicFactor { virtual class SymbolicFactor {
// Standard Constructors and Named Constructors // Standard Constructors and Named Constructors
@ -82,6 +77,14 @@ virtual class SymbolicFactorGraph {
const gtsam::KeyVector& key_vector, const gtsam::KeyVector& key_vector,
const gtsam::Ordering& marginalizedVariableOrdering); const gtsam::Ordering& marginalizedVariableOrdering);
gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector); gtsam::SymbolicFactorGraph* marginal(const gtsam::KeyVector& key_vector);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
}; };
#include <gtsam/symbolic/SymbolicConditional.h> #include <gtsam/symbolic/SymbolicConditional.h>
@ -103,6 +106,7 @@ virtual class SymbolicConditional : gtsam::SymbolicFactor {
bool equals(const gtsam::SymbolicConditional& other, double tol) const; bool equals(const gtsam::SymbolicConditional& other, double tol) const;
// Standard interface // Standard interface
gtsam::Key firstFrontalKey() const;
size_t nrFrontals() const; size_t nrFrontals() const;
size_t nrParents() const; size_t nrParents() const;
}; };
@ -125,6 +129,14 @@ class SymbolicBayesNet {
gtsam::SymbolicConditional* back() const; gtsam::SymbolicConditional* back() const;
void push_back(gtsam::SymbolicConditional* conditional); void push_back(gtsam::SymbolicConditional* conditional);
void push_back(const gtsam::SymbolicBayesNet& bayesNet); void push_back(const gtsam::SymbolicBayesNet& bayesNet);
string dot(
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
void saveGraph(
string s,
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter,
const gtsam::DotWriter& writer = gtsam::DotWriter()) const;
}; };
#include <gtsam/symbolic/SymbolicBayesTree.h> #include <gtsam/symbolic/SymbolicBayesTree.h>
@ -173,29 +185,4 @@ class SymbolicBayesTreeClique {
void deleteCachedShortcuts(); void deleteCachedShortcuts();
}; };
#include <gtsam/inference/VariableIndex.h>
class VariableIndex {
// Standard Constructors and Named Constructors
VariableIndex();
// TODO: Templetize constructor when wrap supports it
// template<T = {gtsam::FactorGraph}>
// VariableIndex(const T& factorGraph, size_t nVariables);
// VariableIndex(const T& factorGraph);
VariableIndex(const gtsam::SymbolicFactorGraph& sfg);
VariableIndex(const gtsam::GaussianFactorGraph& gfg);
VariableIndex(const gtsam::NonlinearFactorGraph& fg);
VariableIndex(const gtsam::VariableIndex& other);
// Testable
bool equals(const gtsam::VariableIndex& other, double tol) const;
void print(string s = "VariableIndex: ",
const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
// Standard interface
size_t size() const;
size_t nFactors() const;
size_t nEntries() const;
};
} // namespace gtsam } // namespace gtsam

View File

@ -15,13 +15,16 @@
* @author Frank Dellaert * @author Frank Dellaert
*/ */
#include <boost/make_shared.hpp> #include <gtsam/symbolic/SymbolicBayesNet.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/Vector.h>
#include <gtsam/base/VectorSpace.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/symbolic/SymbolicConditional.h>
#include <CppUnitLite/TestHarness.h> #include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h> #include <boost/make_shared.hpp>
#include <gtsam/symbolic/SymbolicBayesNet.h>
#include <gtsam/symbolic/SymbolicConditional.h>
using namespace std; using namespace std;
using namespace gtsam; using namespace gtsam;
@ -30,7 +33,6 @@ static const Key _L_ = 0;
static const Key _A_ = 1; static const Key _A_ = 1;
static const Key _B_ = 2; static const Key _B_ = 2;
static const Key _C_ = 3; static const Key _C_ = 3;
static const Key _D_ = 4;
static SymbolicConditional::shared_ptr static SymbolicConditional::shared_ptr
B(new SymbolicConditional(_B_)), B(new SymbolicConditional(_B_)),
@ -78,14 +80,41 @@ TEST( SymbolicBayesNet, combine )
} }
/* ************************************************************************* */ /* ************************************************************************* */
TEST(SymbolicBayesNet, saveGraph) { TEST(SymbolicBayesNet, Dot) {
using symbol_shorthand::A;
using symbol_shorthand::X;
SymbolicBayesNet bn; SymbolicBayesNet bn;
bn += SymbolicConditional(_A_, _B_); bn += SymbolicConditional(X(3), X(2), A(2));
KeyVector keys {_B_, _C_, _D_}; bn += SymbolicConditional(X(2), X(1), A(1));
bn += SymbolicConditional::FromKeys(keys,2); bn += SymbolicConditional(X(1));
bn += SymbolicConditional(_D_);
bn.saveGraph("SymbolicBayesNet.dot"); DotWriter writer;
writer.positionHints.emplace('a', 2);
writer.positionHints.emplace('x', 1);
writer.boxes.emplace(A(1));
writer.boxes.emplace(A(2));
auto position = writer.variablePos(A(1));
CHECK(position);
EXPECT(assert_equal(Vector2(1, 2), *position, 1e-5));
string actual = bn.dot(DefaultKeyFormatter, writer);
bn.saveGraph("bn.dot", DefaultKeyFormatter, writer);
EXPECT(actual ==
"digraph {\n"
" size=\"5,5\";\n"
"\n"
" vara1[label=\"a1\", pos=\"1,2!\", shape=box];\n"
" vara2[label=\"a2\", pos=\"2,2!\", shape=box];\n"
" varx1[label=\"x1\", pos=\"1,1!\"];\n"
" varx2[label=\"x2\", pos=\"2,1!\"];\n"
" varx3[label=\"x3\", pos=\"3,1!\"];\n"
"\n"
" varx1->varx2\n"
" vara1->varx2\n"
" varx2->varx3\n"
" vara2->varx3\n"
"}");
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -45,6 +45,7 @@ set(ignore
gtsam::Point3Pairs gtsam::Point3Pairs
gtsam::Pose3Pairs gtsam::Pose3Pairs
gtsam::Pose3Vector gtsam::Pose3Vector
gtsam::Rot3Vector
gtsam::KeyVector gtsam::KeyVector
gtsam::BinaryMeasurementsUnit3 gtsam::BinaryMeasurementsUnit3
gtsam::DiscreteKey gtsam::DiscreteKey
@ -53,6 +54,7 @@ set(ignore
set(interface_headers set(interface_headers
${PROJECT_SOURCE_DIR}/gtsam/gtsam.i ${PROJECT_SOURCE_DIR}/gtsam/gtsam.i
${PROJECT_SOURCE_DIR}/gtsam/base/base.i ${PROJECT_SOURCE_DIR}/gtsam/base/base.i
${PROJECT_SOURCE_DIR}/gtsam/inference/inference.i
${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i ${PROJECT_SOURCE_DIR}/gtsam/discrete/discrete.i
${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i ${PROJECT_SOURCE_DIR}/gtsam/geometry/geometry.i
${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i ${PROJECT_SOURCE_DIR}/gtsam/linear/linear.i

View File

@ -0,0 +1,15 @@
/* Please refer to:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
* These are required to save one copy operation on Python calls.
*
* NOTES
* =================
*
* `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11
* automatic STL binding, such that the raw objects can be accessed in Python.
* Without this they will be automatically converted to a Python object, and all
* mutations on Python side will not be reflected on C++.
*/
#include <pybind11/stl.h>

View File

@ -15,3 +15,4 @@ PYBIND11_MAKE_OPAQUE(
std::vector<boost::shared_ptr<gtsam::BetweenFactor<gtsam::Pose3> > >); std::vector<boost::shared_ptr<gtsam::BetweenFactor<gtsam::Pose3> > >);
PYBIND11_MAKE_OPAQUE( PYBIND11_MAKE_OPAQUE(
std::vector<boost::shared_ptr<gtsam::BetweenFactor<gtsam::Pose2> > >); std::vector<boost::shared_ptr<gtsam::BetweenFactor<gtsam::Pose2> > >);
PYBIND11_MAKE_OPAQUE(gtsam::Rot3Vector);

View File

@ -0,0 +1,13 @@
/* Please refer to:
* https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
* These are required to save one copy operation on Python calls.
*
* NOTES
* =================
*
* `py::bind_vector` and similar machinery gives the std container a Python-like
* interface, but without the `<pybind11/stl.h>` copying mechanism. Combined
* with `PYBIND11_MAKE_OPAQUE` this allows the types to be modified with Python,
* and saves one copy operation.
*/

View File

@ -12,8 +12,9 @@
*/ */
py::bind_vector< py::bind_vector<
std::vector<boost::shared_ptr<gtsam::BetweenFactor<gtsam::Pose3> > > >( std::vector<boost::shared_ptr<gtsam::BetweenFactor<gtsam::Pose3>>>>(
m_, "BetweenFactorPose3s"); m_, "BetweenFactorPose3s");
py::bind_vector< py::bind_vector<
std::vector<boost::shared_ptr<gtsam::BetweenFactor<gtsam::Pose2> > > >( std::vector<boost::shared_ptr<gtsam::BetweenFactor<gtsam::Pose2>>>>(
m_, "BetweenFactorPose2s"); m_, "BetweenFactorPose2s");
py::bind_vector<gtsam::Rot3Vector>(m_, "Rot3Vector");

View File

@ -78,7 +78,7 @@ class TestGraphvizFormatting(GtsamTestCase):
graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X graphviz_formatting.paperHorizontalAxis = gtsam.GraphvizFormatting.Axis.X
graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y graphviz_formatting.paperVerticalAxis = gtsam.GraphvizFormatting.Axis.Y
self.assertEqual(self.graph.dot(self.values, self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting), formatting=graphviz_formatting),
textwrap.dedent(expected_result)) textwrap.dedent(expected_result))
def test_factor_points(self): def test_factor_points(self):
@ -100,7 +100,7 @@ class TestGraphvizFormatting(GtsamTestCase):
graphviz_formatting.plotFactorPoints = False graphviz_formatting.plotFactorPoints = False
self.assertEqual(self.graph.dot(self.values, self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting), formatting=graphviz_formatting),
textwrap.dedent(expected_result)) textwrap.dedent(expected_result))
def test_width_height(self): def test_width_height(self):
@ -127,7 +127,7 @@ class TestGraphvizFormatting(GtsamTestCase):
graphviz_formatting.figureHeightInches = 10 graphviz_formatting.figureHeightInches = 10
self.assertEqual(self.graph.dot(self.values, self.assertEqual(self.graph.dot(self.values,
writer=graphviz_formatting), formatting=graphviz_formatting),
textwrap.dedent(expected_result)) textwrap.dedent(expected_result))

View File

@ -15,27 +15,15 @@ import unittest
import gtsam import gtsam
import numpy as np import numpy as np
from gtsam import Rot3
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
KEY = 0 KEY = 0
MODEL = gtsam.noiseModel.Unit.Create(3) MODEL = gtsam.noiseModel.Unit.Create(3)
def find_Karcher_mean_Rot3(rotations):
"""Find the Karcher mean of given values."""
# Cost function C(R) = \sum PriorFactor(R_i)::error(R)
# No closed form solution.
graph = gtsam.NonlinearFactorGraph()
for R in rotations:
graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
initial = gtsam.Values()
initial.insert(KEY, gtsam.Rot3())
result = gtsam.GaussNewtonOptimizer(graph, initial).optimize()
return result.atRot3(KEY)
# Rot3 version # Rot3 version
R = gtsam.Rot3.Expmap(np.array([0.1, 0, 0])) R = Rot3.Expmap(np.array([0.1, 0, 0]))
class TestKarcherMean(GtsamTestCase): class TestKarcherMean(GtsamTestCase):
@ -43,11 +31,23 @@ class TestKarcherMean(GtsamTestCase):
def test_find(self): def test_find(self):
# Check that optimizing for Karcher mean (which minimizes Between distance) # Check that optimizing for Karcher mean (which minimizes Between distance)
# gets correct result. # gets correct result.
rotations = {R, R.inverse()} rotations = gtsam.Rot3Vector([R, R.inverse()])
expected = gtsam.Rot3() expected = Rot3()
actual = find_Karcher_mean_Rot3(rotations) actual = gtsam.FindKarcherMean(rotations)
self.gtsamAssertEquals(expected, actual) self.gtsamAssertEquals(expected, actual)
def test_find_karcher_mean_identity(self):
"""Averaging 3 identity rotations should yield the identity."""
a1Rb1 = Rot3()
a2Rb2 = Rot3()
a3Rb3 = Rot3()
aRb_list = gtsam.Rot3Vector([a1Rb1, a2Rb2, a3Rb3])
aRb_expected = Rot3()
aRb = gtsam.FindKarcherMean(aRb_list)
self.gtsamAssertEquals(aRb, aRb_expected)
def test_factor(self): def test_factor(self):
"""Check that the InnerConstraint factor leaves the mean unchanged.""" """Check that the InnerConstraint factor leaves the mean unchanged."""
# Make a graph with two variables, one between, and one InnerConstraint # Make a graph with two variables, one between, and one InnerConstraint
@ -66,11 +66,11 @@ class TestKarcherMean(GtsamTestCase):
initial = gtsam.Values() initial = gtsam.Values()
initial.insert(1, R.inverse()) initial.insert(1, R.inverse())
initial.insert(2, R) initial.insert(2, R)
expected = find_Karcher_mean_Rot3([R, R.inverse()]) expected = Rot3()
result = gtsam.GaussNewtonOptimizer(graph, initial).optimize() result = gtsam.GaussNewtonOptimizer(graph, initial).optimize()
actual = find_Karcher_mean_Rot3( actual = gtsam.FindKarcherMean(
[result.atRot3(1), result.atRot3(2)]) gtsam.Rot3Vector([result.atRot3(1), result.atRot3(2)]))
self.gtsamAssertEquals(expected, actual) self.gtsamAssertEquals(expected, actual)
self.gtsamAssertEquals( self.gtsamAssertEquals(
R12, result.atRot3(1).between(result.atRot3(2))) R12, result.atRot3(1).between(result.atRot3(2)))