run formatting and rename wrappedFactors to asGaussianFactorGraphTree

release/4.3a0
Varun Agrawal 2022-05-24 11:36:12 -04:00
parent e325cd1c4b
commit eb074e7424
13 changed files with 103 additions and 82 deletions

View File

@ -36,7 +36,8 @@ GaussianMixtureConditional::GaussianMixtureConditional(
conditionals_(conditionals) {}
/* *******************************************************************************/
const GaussianMixtureConditional::Conditionals &GaussianMixtureConditional::conditionals() {
const GaussianMixtureConditional::Conditionals &
GaussianMixtureConditional::conditionals() {
return conditionals_;
}
@ -47,8 +48,8 @@ GaussianMixtureConditional GaussianMixtureConditional::FromConditionalList(
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
Conditionals dt(discreteParents, conditionalsList);
return GaussianMixtureConditional(continuousFrontals, continuousParents, discreteParents,
dt);
return GaussianMixtureConditional(continuousFrontals, continuousParents,
discreteParents, dt);
}
/* *******************************************************************************/
@ -60,12 +61,13 @@ GaussianMixtureConditional::Sum GaussianMixtureConditional::add(
result.push_back(graph2);
return result;
};
const Sum wrapped = asGraph();
return sum.empty() ? wrapped : sum.apply(wrapped, add);
const Sum tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/
GaussianMixtureConditional::Sum GaussianMixtureConditional::asGraph() const {
GaussianMixtureConditional::Sum
GaussianMixtureConditional::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
GaussianFactorGraph result;
result.push_back(factor);
@ -74,14 +76,15 @@ GaussianMixtureConditional::Sum GaussianMixtureConditional::asGraph() const {
return {conditionals_, lambda};
}
/* TODO(fan): this (for Testable) is not implemented! */
bool GaussianMixtureConditional::equals(const HybridFactor &lf, double tol) const {
return false;
/* *******************************************************************************/
bool GaussianMixtureConditional::equals(const HybridFactor &lf,
double tol) const {
return BaseFactor::equals(lf, tol);
}
/* *******************************************************************************/
void GaussianMixtureConditional::print(const std::string &s,
const KeyFormatter &formatter) const {
const KeyFormatter &formatter) const {
std::cout << s << ": ";
if (isContinuous_) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. ";

View File

@ -25,8 +25,9 @@
#include <gtsam/linear/GaussianConditional.h>
namespace gtsam {
class GaussianMixtureConditional : public HybridFactor,
public Conditional<HybridFactor, GaussianMixtureConditional> {
class GaussianMixtureConditional
: public HybridFactor,
public Conditional<HybridFactor, GaussianMixtureConditional> {
public:
using This = GaussianMixtureConditional;
using shared_ptr = boost::shared_ptr<GaussianMixtureConditional>;
@ -47,9 +48,9 @@ class GaussianMixtureConditional : public HybridFactor,
* @param conditionals a decision tree of GaussianConditionals.
*/
GaussianMixtureConditional(const KeyVector &continuousFrontals,
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const Conditionals &conditionals);
const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const Conditionals &conditionals);
using Sum = DecisionTree<Key, GaussianFactorGraph>;
@ -60,30 +61,32 @@ class GaussianMixtureConditional : public HybridFactor,
*/
Sum add(const Sum &sum) const;
/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/
Sum asGraph() const;
/**
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
* @param discreteParents Discrete parents variables
* @param conditionals List of conditionals
*/
/**
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
*
* @param continuousFrontals The continuous frontal variables
* @param continuousParents The continuous parent variables
* @param discreteParents Discrete parents variables
* @param conditionals List of conditionals
*/
static This FromConditionalList(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals);
/* TODO: this is only a stub */
/// Test equality with base HybridFactor
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
/* print utility */
/* print utility */
void print(
const std::string &s = "GaussianMixtureConditional\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
protected:
/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/
Sum asGaussianFactorGraphTree() const;
};
} // namespace gtsam

View File

@ -26,10 +26,13 @@
namespace gtsam {
/* *******************************************************************************/
GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const Factors &factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}
/* *******************************************************************************/
bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const {
return false;
}
@ -43,7 +46,6 @@ GaussianMixtureFactor GaussianMixtureFactor::FromFactorList(
return GaussianMixtureFactor(continuousKeys, discreteKeys, dt);
}
/* *******************************************************************************/
void GaussianMixtureFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
@ -74,12 +76,13 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
result.push_back(graph2);
return result;
};
const Sum wrapped = wrappedFactors();
return sum.empty() ? wrapped : sum.apply(wrapped, add);
const Sum tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::wrappedFactors() const {
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const GaussianFactor::shared_ptr &factor) {
GaussianFactorGraph result;
result.push_back(factor);

View File

@ -57,13 +57,20 @@ class GaussianMixtureFactor : public HybridFactor {
Sum add(const Sum &sum) const;
Sum wrappedFactors() const;
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
void print(
const std::string &s = "HybridFactor\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
protected:
/**
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
*
* @return Sum (DecisionTree<Key, GaussianFactorGraph)
*/
Sum asGaussianFactorGraphTree() const;
};
} // namespace gtsam

View File

@ -17,10 +17,10 @@
*/
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/BayesTree-inst.h>
#include <gtsam/inference/BayesTreeCliqueBase-inst.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridBayesNet.h>
namespace gtsam {

View File

@ -73,8 +73,8 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/* This does special stuff for the hybrid case */
template <class CLIQUE>
class BayesTreeOrphanWrapper<
CLIQUE,
typename std::enable_if<boost::is_same<CLIQUE, HybridBayesTreeClique>::value> >
CLIQUE, typename std::enable_if<
boost::is_same<CLIQUE, HybridBayesTreeClique>::value> >
: public CLIQUE::ConditionalType {
public:
typedef CLIQUE CliqueType;

View File

@ -93,7 +93,8 @@ class GTSAM_EXPORT HybridConditional
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
HybridConditional(boost::shared_ptr<GaussianMixtureConditional> gaussianMixture);
HybridConditional(
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture);
GaussianMixtureConditional::shared_ptr asMixture() {
if (!isHybrid_) throw std::invalid_argument("Not a mixture");

View File

@ -19,15 +19,16 @@
#include <gtsam/hybrid/HybridDiscreteFactor.h>
#include <boost/make_shared.hpp>
#include "gtsam/discrete/DecisionTreeFactor.h"
namespace gtsam {
// TODO(fan): THIS IS VERY VERY DIRTY! We need to get DiscreteFactor right!
HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other)
: Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)->discreteKeys()) {
: Base(boost::dynamic_pointer_cast<DecisionTreeFactor>(other)
->discreteKeys()) {
inner = other;
}
HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)

View File

@ -24,8 +24,9 @@
namespace gtsam {
/**
* A HybridDiscreteFactor is a wrapper for DiscreteFactor, so we hide the
* implementation of DiscreteFactor, and thus avoiding diamond inheritance.
* A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows
* us to hide the implementation of DiscreteFactor and thus avoid diamond
* inheritance.
*/
class HybridDiscreteFactor : public HybridFactor {
public:

View File

@ -15,8 +15,8 @@
* @author Fan Jiang
*/
#include <gtsam/inference/EliminationTree-inst.h>
#include <gtsam/hybrid/HybridEliminationTree.h>
#include <gtsam/inference/EliminationTree-inst.h>
namespace gtsam {
@ -26,18 +26,17 @@ template class EliminationTree<HybridBayesNet, HybridFactorGraph>;
/* ************************************************************************* */
HybridEliminationTree::HybridEliminationTree(
const HybridFactorGraph& factorGraph, const VariableIndex& structure,
const Ordering& order) :
Base(factorGraph, structure, order) {}
const Ordering& order)
: Base(factorGraph, structure, order) {}
/* ************************************************************************* */
HybridEliminationTree::HybridEliminationTree(
const HybridFactorGraph& factorGraph, const Ordering& order) :
Base(factorGraph, order) {}
const HybridFactorGraph& factorGraph, const Ordering& order)
: Base(factorGraph, order) {}
/* ************************************************************************* */
bool HybridEliminationTree::equals(const This& other, double tol) const
{
bool HybridEliminationTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}
}
} // namespace gtsam

View File

@ -64,9 +64,8 @@ HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys)
isDiscrete_(true),
discreteKeys_(discreteKeys) {}
void HybridFactor::print(
const std::string &s,
const KeyFormatter &formatter) const {
void HybridFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << s;
if (isContinuous_) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. ";

View File

@ -229,8 +229,8 @@ EliminateHybrid(const HybridFactorGraph &factors, const Ordering &frontalKeys) {
if (p) {
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
} else {
// It is an orphan wrapper
if (DEBUG) std::cout << "Got an orphan wrapper conditional\n";
// It is an orphan wrapped conditional
if (DEBUG) std::cout << "Got an orphan conditional\n";
}
}
}

View File

@ -17,8 +17,8 @@
#pragma once
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/inference/JunctionTree.h>
namespace gtsam {
@ -27,41 +27,45 @@ namespace gtsam {
class HybridEliminationTree;
/**
* An EliminatableClusterTree, i.e., a set of variable clusters with factors, arranged in a tree,
* with the additional property that it represents the clique tree associated with a Bayes net.
* An EliminatableClusterTree, i.e., a set of variable clusters with factors,
* arranged in a tree, with the additional property that it represents the
* clique tree associated with a Bayes net.
*
* In GTSAM a junction tree is an intermediate data structure in multifrontal
* variable elimination. Each node is a cluster of factors, along with a
* clique of variables that are eliminated all at once. In detail, every node k represents
* a clique (maximal fully connected subset) of an associated chordal graph, such as a
* chordal Bayes net resulting from elimination.
* clique of variables that are eliminated all at once. In detail, every node k
* represents a clique (maximal fully connected subset) of an associated chordal
* graph, such as a chordal Bayes net resulting from elimination.
*
* The difference with the BayesTree is that a JunctionTree stores factors, whereas a
* BayesTree stores conditionals, that are the product of eliminating the factors in the
* corresponding JunctionTree cliques.
* The difference with the BayesTree is that a JunctionTree stores factors,
* whereas a BayesTree stores conditionals, that are the product of eliminating
* the factors in the corresponding JunctionTree cliques.
*
* The tree structure and elimination method are exactly analogous to the EliminationTree,
* except that in the JunctionTree, at each node multiple variables are eliminated at a time.
* The tree structure and elimination method are exactly analogous to the
* EliminationTree, except that in the JunctionTree, at each node multiple
* variables are eliminated at a time.
*
* \addtogroup Multifrontal
* \nosubgrouping
*/
class GTSAM_EXPORT HybridJunctionTree :
public JunctionTree<HybridBayesTree, HybridFactorGraph> {
class GTSAM_EXPORT HybridJunctionTree
: public JunctionTree<HybridBayesTree, HybridFactorGraph> {
public:
typedef JunctionTree<HybridBayesTree, HybridFactorGraph> Base; ///< Base class
typedef HybridJunctionTree This; ///< This class
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
typedef JunctionTree<HybridBayesTree, HybridFactorGraph>
Base; ///< Base class
typedef HybridJunctionTree This; ///< This class
typedef boost::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
/**
* Build the elimination tree of a factor graph using precomputed column structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is not
* precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
* Build the elimination tree of a factor graph using precomputed column
* structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is
* not precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
HybridJunctionTree(const HybridEliminationTree& eliminationTree);
};
}
} // namespace gtsam