Merge pull request #1095 from borglab/feature/linear_improvements
commit
3e65779421
|
@ -18,6 +18,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#if BOOST_VERSION >= 107400
|
||||
#include <boost/serialization/library_version_type.hpp>
|
||||
#endif
|
||||
#include <boost/serialization/nvp.hpp>
|
||||
#include <boost/serialization/set.hpp>
|
||||
#include <gtsam/base/FastDefaultAllocator.h>
|
||||
|
|
|
@ -25,6 +25,7 @@
|
|||
#include <string>
|
||||
|
||||
// includes for standard serialization types
|
||||
#include <boost/serialization/version.hpp>
|
||||
#include <boost/serialization/optional.hpp>
|
||||
#include <boost/serialization/shared_ptr.hpp>
|
||||
#include <boost/serialization/vector.hpp>
|
||||
|
|
|
@ -225,13 +225,13 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
|||
|
||||
/* ****************************************************************************/
|
||||
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
|
||||
size_t parent_value) const {
|
||||
size_t frontal) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"Single value likelihood can only be invoked on single-variable "
|
||||
"conditional");
|
||||
DiscreteValues values;
|
||||
values.emplace(keys_[0], parent_value);
|
||||
values.emplace(keys_[0], frontal);
|
||||
return likelihood(values);
|
||||
}
|
||||
|
||||
|
|
|
@ -177,7 +177,7 @@ class GTSAM_EXPORT DiscreteConditional
|
|||
const DiscreteValues& frontalValues) const;
|
||||
|
||||
/** Single variable version of likelihood. */
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t parent_value) const;
|
||||
DecisionTreeFactor::shared_ptr likelihood(size_t frontal) const;
|
||||
|
||||
/**
|
||||
* sample
|
||||
|
|
|
@ -224,6 +224,50 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
JacobianFactor::shared_ptr GaussianConditional::likelihood(
|
||||
const VectorValues& frontalValues) const {
|
||||
// Error is |Rx - (d - Sy - Tz - ...)|^2
|
||||
// so when we instantiate x (which has to be completely known) we beget:
|
||||
// |Sy + Tz + ... - (d - Rx)|^2
|
||||
// The noise model just transfers over!
|
||||
|
||||
// Get frontalValues as vector
|
||||
const Vector x =
|
||||
frontalValues.vector(KeyVector(beginFrontals(), endFrontals()));
|
||||
|
||||
// Copy the augmented Jacobian matrix:
|
||||
auto newAb = Ab_;
|
||||
|
||||
// Restrict view to parent blocks
|
||||
newAb.firstBlock() += nrFrontals_;
|
||||
|
||||
// Update right-hand-side (last column)
|
||||
auto last = newAb.matrix().cols() - 1;
|
||||
const auto RR = R().triangularView<Eigen::Upper>();
|
||||
newAb.matrix().col(last) -= RR * x;
|
||||
|
||||
// The keys now do not include the frontal keys:
|
||||
KeyVector newKeys;
|
||||
newKeys.reserve(nrParents());
|
||||
for (auto&& key : parents()) newKeys.push_back(key);
|
||||
|
||||
// Hopefully second newAb copy below is optimized out...
|
||||
return boost::make_shared<JacobianFactor>(newKeys, newAb, model_);
|
||||
}
|
||||
|
||||
/* **************************************************************************/
|
||||
JacobianFactor::shared_ptr GaussianConditional::likelihood(
|
||||
const Vector& frontal) const {
|
||||
if (nrFrontals() != 1)
|
||||
throw std::invalid_argument(
|
||||
"GaussianConditional Single value likelihood can only be invoked on "
|
||||
"single-variable conditional");
|
||||
VectorValues values;
|
||||
values.insert(keys_[0], frontal);
|
||||
return likelihood(values);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
VectorValues GaussianConditional::sample(const VectorValues& parentsValues,
|
||||
std::mt19937_64* rng) const {
|
||||
|
|
|
@ -151,6 +151,13 @@ namespace gtsam {
|
|||
/** Performs transpose backsubstition in place on values */
|
||||
void solveTransposeInPlace(VectorValues& gy) const;
|
||||
|
||||
/** Convert to a likelihood factor by providing value before bar. */
|
||||
JacobianFactor::shared_ptr likelihood(
|
||||
const VectorValues& frontalValues) const;
|
||||
|
||||
/** Single variable version of likelihood. */
|
||||
JacobianFactor::shared_ptr likelihood(const Vector& frontal) const;
|
||||
|
||||
/**
|
||||
* Sample from conditional, zero parent version
|
||||
* Example:
|
||||
|
|
|
@ -33,7 +33,7 @@ namespace gtsam {
|
|||
using boost::adaptors::map_values;
|
||||
using boost::accumulate;
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues::VectorValues(const VectorValues& first, const VectorValues& second)
|
||||
{
|
||||
// Merge using predicate for comparing first of pair
|
||||
|
@ -44,7 +44,7 @@ namespace gtsam {
|
|||
throw invalid_argument("Requested to merge two VectorValues that have one or more variables in common.");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues::VectorValues(const Vector& x, const Dims& dims) {
|
||||
using Pair = pair<const Key, size_t>;
|
||||
size_t j = 0;
|
||||
|
@ -61,7 +61,7 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues::VectorValues(const Vector& x, const Scatter& scatter) {
|
||||
size_t j = 0;
|
||||
for (const SlotEntry& v : scatter) {
|
||||
|
@ -74,7 +74,7 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues VectorValues::Zero(const VectorValues& other)
|
||||
{
|
||||
VectorValues result;
|
||||
|
@ -87,7 +87,7 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues::iterator VectorValues::insert(const std::pair<Key, Vector>& key_value) {
|
||||
std::pair<iterator, bool> result = values_.insert(key_value);
|
||||
if(!result.second)
|
||||
|
@ -97,7 +97,7 @@ namespace gtsam {
|
|||
return result.first;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
void VectorValues::update(const VectorValues& values)
|
||||
{
|
||||
iterator hint = begin();
|
||||
|
@ -115,7 +115,7 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
void VectorValues::insert(const VectorValues& values)
|
||||
{
|
||||
size_t originalSize = size();
|
||||
|
@ -124,14 +124,14 @@ namespace gtsam {
|
|||
throw invalid_argument("Requested to insert a VectorValues into another VectorValues that already contains one or more of its keys.");
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
void VectorValues::setZero()
|
||||
{
|
||||
for(Vector& v: values_ | map_values)
|
||||
v.setZero();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
GTSAM_EXPORT ostream& operator<<(ostream& os, const VectorValues& v) {
|
||||
// Change print depending on whether we are using TBB
|
||||
#ifdef GTSAM_USE_TBB
|
||||
|
@ -150,7 +150,7 @@ namespace gtsam {
|
|||
return os;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
void VectorValues::print(const string& str,
|
||||
const KeyFormatter& formatter) const {
|
||||
cout << str << ": " << size() << " elements\n";
|
||||
|
@ -158,7 +158,7 @@ namespace gtsam {
|
|||
cout.flush();
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
bool VectorValues::equals(const VectorValues& x, double tol) const {
|
||||
if(this->size() != x.size())
|
||||
return false;
|
||||
|
@ -170,7 +170,7 @@ namespace gtsam {
|
|||
return true;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
Vector VectorValues::vector() const {
|
||||
// Count dimensions
|
||||
DenseIndex totalDim = 0;
|
||||
|
@ -187,7 +187,7 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
Vector VectorValues::vector(const Dims& keys) const
|
||||
{
|
||||
// Count dimensions
|
||||
|
@ -203,12 +203,12 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
void VectorValues::swap(VectorValues& other) {
|
||||
this->values_.swap(other.values_);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
namespace internal
|
||||
{
|
||||
bool structureCompareOp(const boost::tuple<VectorValues::value_type,
|
||||
|
@ -219,14 +219,14 @@ namespace gtsam {
|
|||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
bool VectorValues::hasSameStructure(const VectorValues other) const
|
||||
{
|
||||
return accumulate(combine(*this, other)
|
||||
| transformed(internal::structureCompareOp), true, logical_and<bool>());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
double VectorValues::dot(const VectorValues& v) const
|
||||
{
|
||||
if(this->size() != v.size())
|
||||
|
@ -244,12 +244,12 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
double VectorValues::norm() const {
|
||||
return std::sqrt(this->squaredNorm());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
double VectorValues::squaredNorm() const {
|
||||
double sumSquares = 0.0;
|
||||
using boost::adaptors::map_values;
|
||||
|
@ -258,7 +258,7 @@ namespace gtsam {
|
|||
return sumSquares;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues VectorValues::operator+(const VectorValues& c) const
|
||||
{
|
||||
if(this->size() != c.size())
|
||||
|
@ -278,13 +278,13 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues VectorValues::add(const VectorValues& c) const
|
||||
{
|
||||
return *this + c;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues& VectorValues::operator+=(const VectorValues& c)
|
||||
{
|
||||
if(this->size() != c.size())
|
||||
|
@ -301,13 +301,13 @@ namespace gtsam {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues& VectorValues::addInPlace(const VectorValues& c)
|
||||
{
|
||||
return *this += c;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues& VectorValues::addInPlace_(const VectorValues& c)
|
||||
{
|
||||
for(const_iterator j2 = c.begin(); j2 != c.end(); ++j2) {
|
||||
|
@ -320,7 +320,7 @@ namespace gtsam {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues VectorValues::operator-(const VectorValues& c) const
|
||||
{
|
||||
if(this->size() != c.size())
|
||||
|
@ -340,13 +340,13 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues VectorValues::subtract(const VectorValues& c) const
|
||||
{
|
||||
return *this - c;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues operator*(const double a, const VectorValues &v)
|
||||
{
|
||||
VectorValues result;
|
||||
|
@ -359,13 +359,13 @@ namespace gtsam {
|
|||
return result;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues VectorValues::scale(const double a) const
|
||||
{
|
||||
return a * *this;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues& VectorValues::operator*=(double alpha)
|
||||
{
|
||||
for(Vector& v: *this | map_values)
|
||||
|
@ -373,12 +373,43 @@ namespace gtsam {
|
|||
return *this;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
VectorValues& VectorValues::scaleInPlace(double alpha)
|
||||
{
|
||||
return *this *= alpha;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
/* ************************************************************************ */
|
||||
string VectorValues::html(const KeyFormatter& keyFormatter) const {
|
||||
stringstream ss;
|
||||
|
||||
// Print out preamble.
|
||||
ss << "<div>\n<table class='VectorValues'>\n <thead>\n";
|
||||
|
||||
// Print out header row.
|
||||
ss << " <tr><th>Variable</th><th>value</th></tr>\n";
|
||||
|
||||
// Finish header and start body.
|
||||
ss << " </thead>\n <tbody>\n";
|
||||
|
||||
// Print out all rows.
|
||||
#ifdef GTSAM_USE_TBB
|
||||
// TBB uses un-ordered map, so inefficiently order them:
|
||||
std::map<Key, Vector> ordered;
|
||||
for (const auto& kv : *this) ordered.emplace(kv);
|
||||
for (const auto& kv : ordered) {
|
||||
#else
|
||||
for (const auto& kv : *this) {
|
||||
#endif
|
||||
ss << " <tr>";
|
||||
ss << "<th>" << keyFormatter(kv.first) << "</th><td>"
|
||||
<< kv.second.transpose() << "</td>";
|
||||
ss << "</tr>\n";
|
||||
}
|
||||
ss << " </tbody>\n</table>\n</div>";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
|
||||
} // \namespace gtsam
|
||||
|
|
|
@ -34,7 +34,7 @@
|
|||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* This class represents a collection of vector-valued variables associated
|
||||
* VectorValues represents a collection of vector-valued variables associated
|
||||
* each with a unique integer index. It is typically used to store the variables
|
||||
* of a GaussianFactorGraph. Optimizing a GaussianFactorGraph or GaussianBayesNet
|
||||
* returns this class.
|
||||
|
@ -69,7 +69,7 @@ namespace gtsam {
|
|||
* which is a view on the underlying data structure.
|
||||
*
|
||||
* This class is additionally used in gradient descent and dog leg to store the gradient.
|
||||
* \nosubgrouping
|
||||
* @addtogroup linear
|
||||
*/
|
||||
class GTSAM_EXPORT VectorValues {
|
||||
protected:
|
||||
|
@ -344,11 +344,16 @@ namespace gtsam {
|
|||
|
||||
/// @}
|
||||
|
||||
/// @}
|
||||
/// @name Matlab syntactic sugar for linear algebra operations
|
||||
/// @name Wrapper support
|
||||
/// @{
|
||||
|
||||
//inline VectorValues scale(const double a, const VectorValues& c) const { return a * (*this); }
|
||||
/**
|
||||
* @brief Output as a html table.
|
||||
*
|
||||
* @param keyFormatter function that formats keys.
|
||||
*/
|
||||
std::string html(
|
||||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const;
|
||||
|
||||
/// @}
|
||||
|
||||
|
|
|
@ -255,6 +255,7 @@ class VectorValues {
|
|||
|
||||
// enabling serialization functionality
|
||||
void serialize() const;
|
||||
string html() const;
|
||||
};
|
||||
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
|
@ -491,6 +492,9 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
|
|||
// Standard Interface
|
||||
gtsam::Key firstFrontalKey() const;
|
||||
gtsam::VectorValues solve(const gtsam::VectorValues& parents) const;
|
||||
gtsam::JacobianFactor* likelihood(
|
||||
const gtsam::VectorValues& frontalValues) const;
|
||||
gtsam::JacobianFactor* likelihood(Vector frontal) const;
|
||||
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
|
||||
gtsam::VectorValues sample() const;
|
||||
|
||||
|
|
|
@ -340,6 +340,33 @@ TEST(GaussianConditional, FromMeanAndStddev) {
|
|||
EXPECT_DOUBLES_EQUAL(expected2, conditional2.error(values), 1e-9);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test likelihood method (conversion to JacobianFactor)
|
||||
TEST(GaussianConditional, likelihood) {
|
||||
Matrix A1 = (Matrix(2, 2) << 1., 2., 3., 4.).finished();
|
||||
const Vector2 b(20, 40), x0(1, 2);
|
||||
const double sigma = 0.01;
|
||||
|
||||
// |x0 - A1 x1 - b|^2
|
||||
auto conditional =
|
||||
GaussianConditional::FromMeanAndStddev(X(0), A1, X(1), b, sigma);
|
||||
|
||||
VectorValues frontalValues;
|
||||
frontalValues.insert(X(0), x0);
|
||||
auto actual1 = conditional.likelihood(frontalValues);
|
||||
CHECK(actual1);
|
||||
|
||||
// |(-A1) x1 - (b - x0)|^2
|
||||
JacobianFactor expected(X(1), -A1, b - x0,
|
||||
noiseModel::Isotropic::Sigma(2, sigma));
|
||||
EXPECT(assert_equal(expected, *actual1, tol));
|
||||
|
||||
// Check single vector version
|
||||
auto actual2 = conditional.likelihood(x0);
|
||||
CHECK(actual2);
|
||||
EXPECT(assert_equal(expected, *actual2, tol));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Test sampling
|
||||
TEST(GaussianConditional, sample) {
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/linear/VectorValues.h>
|
||||
#include <gtsam/inference/LabeledSymbol.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
|
@ -248,6 +248,33 @@ TEST(VectorValues, print)
|
|||
EXPECT(expected == actual.str());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check html representation.
|
||||
TEST(VectorValues, html) {
|
||||
VectorValues vv;
|
||||
using symbol_shorthand::X;
|
||||
vv.insert(X(1), Vector2(2, 3.1));
|
||||
vv.insert(X(2), Vector2(4, 5.2));
|
||||
vv.insert(X(5), Vector2(6, 7.3));
|
||||
vv.insert(X(7), Vector2(8, 9.4));
|
||||
string expected =
|
||||
"<div>\n"
|
||||
"<table class='VectorValues'>\n"
|
||||
" <thead>\n"
|
||||
" <tr><th>Variable</th><th>value</th></tr>\n"
|
||||
" </thead>\n"
|
||||
" <tbody>\n"
|
||||
" <tr><th>x1</th><td> 2 3.1</td></tr>\n"
|
||||
" <tr><th>x2</th><td> 4 5.2</td></tr>\n"
|
||||
" <tr><th>x5</th><td> 6 7.3</td></tr>\n"
|
||||
" <tr><th>x7</th><td> 8 9.4</td></tr>\n"
|
||||
" </tbody>\n"
|
||||
"</table>\n"
|
||||
"</div>";
|
||||
string actual = vv.html();
|
||||
EXPECT(actual == expected);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() { TestResult tr; return TestRegistry::runAllTests(tr); }
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue