gtsam/gtsam/hybrid/HybridGaussianConditional.cpp

376 lines
14 KiB
C++
Raw Normal View History

2022-03-13 05:29:26 +08:00
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
 * @file HybridGaussianConditional.cpp
 * @brief A hybrid conditional in the Conditional Linear Gaussian scheme
 * @author Fan Jiang
 * @author Varun Agrawal
 * @author Frank Dellaert
 * @date Mar 12, 2022
 */
#include <gtsam/base/utilities.h>
2022-06-07 22:07:53 +08:00
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
2022-12-31 01:10:16 +08:00
#include <gtsam/hybrid/HybridValues.h>
2022-03-16 00:50:31 +08:00
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
2022-03-24 08:16:05 +08:00
#include <gtsam/linear/GaussianFactorGraph.h>
2022-03-13 23:42:36 +08:00
namespace gtsam {
/**
 * Construct from frontal/parent continuous keys, discrete parents and a
 * decision tree of Gaussian conditionals (one leaf per discrete assignment).
 *
 * logConstant_ is set to the largest logNormalizationConstant() among the
 * non-null leaves; it remains -infinity if every leaf is null.
 */
HybridGaussianConditional::HybridGaussianConditional(
    const KeyVector &continuousFrontals, const KeyVector &continuousParents,
    const DiscreteKeys &discreteParents,
    const HybridGaussianConditional::Conditionals &conditionals)
    : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
                 discreteParents),
      BaseConditional(continuousFrontals.size()),
      conditionals_(conditionals) {
  double largestLogConstant = -std::numeric_limits<double>::infinity();
  conditionals_.visit(
      [&largestLogConstant](const GaussianConditional::shared_ptr &leaf) {
        // Null leaves carry no normalization constant; skip them.
        if (!leaf) return;
        largestLogConstant =
            std::max(largestLogConstant, leaf->logNormalizationConstant());
      });
  logConstant_ = largestLogConstant;
}
2022-03-13 23:42:36 +08:00
2022-05-23 12:29:12 +08:00
/* *******************************************************************************/
2024-09-13 18:20:46 +08:00
/// Read-only access to the decision tree of Gaussian conditionals (no copy).
const HybridGaussianConditional::Conditionals &
HybridGaussianConditional::conditionals() const {
  const Conditionals &tree = this->conditionals_;
  return tree;
}
/* *******************************************************************************/
/**
 * Convenience constructor: builds the Conditionals decision tree over
 * `discreteParents` from a flat vector of conditionals, then delegates to the
 * tree-based constructor.
 */
HybridGaussianConditional::HybridGaussianConditional(
    const KeyVector &continuousFrontals, const KeyVector &continuousParents,
    const DiscreteKeys &discreteParents,
    const std::vector<GaussianConditional::shared_ptr> &conditionals)
    : HybridGaussianConditional(
          continuousFrontals, continuousParents, discreteParents,
          Conditionals(discreteParents, conditionals)) {}
2022-03-24 08:16:05 +08:00
/* *******************************************************************************/
2024-09-13 18:20:46 +08:00
// TODO(dellaert): This is copy/paste: HybridGaussianConditional should be
// derived from HybridGaussianFactor, no?
/**
 * Combine this conditional's factor-graph tree with `sum`, leaf by leaf.
 *
 * @param sum Tree of factor graphs accumulated so far (may be empty).
 * @return This conditional's own tree when `sum` is empty; otherwise a tree
 *         whose leaves are the concatenation of `sum`'s graph followed by the
 *         corresponding graph from asGaussianFactorGraphTree().
 */
GaussianFactorGraphTree HybridGaussianConditional::add(
    const GaussianFactorGraphTree &sum) const {
  const GaussianFactorGraphTree ownTree = asGaussianFactorGraphTree();
  if (sum.empty()) {
    return ownTree;
  }
  auto concatenate = [](const GaussianFactorGraph &lhs,
                        const GaussianFactorGraph &rhs) {
    GaussianFactorGraph merged = lhs;
    merged.push_back(rhs);
    return merged;
  };
  return sum.apply(ownTree, concatenate);
}
/* *******************************************************************************/
2024-09-13 18:20:46 +08:00
/**
 * Convert each leaf conditional into a small GaussianFactorGraph.
 *
 * A non-null leaf whose logNormalizationConstant() is strictly below
 * logConstant_ additionally gets a key-less JacobianFactor with
 * b = sqrt(2 * (logConstant_ - logNormalizationConstant())), so that the
 * mode-dependent covariance is reflected when the discrete variables are
 * scored. Null (pruned) leaves and leaves at the maximum constant map to a
 * graph holding just the leaf pointer.
 */
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
    const {
  auto leafToGraph = [this](const GaussianConditional::shared_ptr &gc) {
    if (!gc) {
      // Pruned component: nothing to augment.
      return GaussianFactorGraph{gc};
    }
    const double Cgm_Kgcm =
        this->logConstant_ - gc->logNormalizationConstant();
    if (Cgm_Kgcm > 0.0) {
      // Encode the offset as a constant factor for the discrete scoring.
      Vector c(1);
      c << std::sqrt(2.0 * Cgm_Kgcm);
      auto constantFactor = std::make_shared<JacobianFactor>(c);
      return GaussianFactorGraph{gc, constantFactor};
    }
    return GaussianFactorGraph{gc};
  };
  return {conditionals_, leafToGraph};
}
/* *******************************************************************************/
size_t HybridGaussianConditional::nrComponents() const {
2022-06-07 22:07:53 +08:00
size_t total = 0;
conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
if (node) total += 1;
});
return total;
}
/* *******************************************************************************/
/**
 * Look up the Gaussian conditional selected by a discrete assignment.
 *
 * @param discreteValues Assignment to the discrete parents.
 * @return The selected conditional, or nullptr if that component was pruned.
 * @throws std::logic_error if the leaf is non-null but not a
 *         GaussianConditional.
 */
GaussianConditional::shared_ptr HybridGaussianConditional::operator()(
    const DiscreteValues &discreteValues) const {
  const auto &leaf = conditionals_(discreteValues);
  if (leaf == nullptr) {
    return nullptr;
  }
  if (auto conditional = std::dynamic_pointer_cast<GaussianConditional>(leaf)) {
    return conditional;
  }
  throw std::logic_error(
      "A HybridGaussianConditional unexpectedly contained a non-conditional");
}
/* *******************************************************************************/
2024-09-13 18:20:46 +08:00
bool HybridGaussianConditional::equals(const HybridFactor &lf,
double tol) const {
2022-05-29 06:03:52 +08:00
const This *e = dynamic_cast<const This *>(&lf);
2023-01-02 05:36:46 +08:00
if (e == nullptr) return false;
// This will return false if either conditionals_ is empty or e->conditionals_
// is empty, but not if both are empty or both are not empty:
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
2023-01-03 16:47:34 +08:00
2023-01-02 05:36:46 +08:00
// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return f1->equals(*(f2), tol);
});
2022-03-13 23:42:36 +08:00
}
2022-05-23 12:29:12 +08:00
/* *******************************************************************************/
void HybridGaussianConditional::print(const std::string &s,
2024-09-13 18:20:46 +08:00
const KeyFormatter &formatter) const {
2022-12-03 09:13:46 +08:00
std::cout << (s.empty() ? "" : s + "\n");
if (isContinuous()) std::cout << "Continuous ";
if (isDiscrete()) std::cout << "Discrete ";
if (isHybrid()) std::cout << "Hybrid ";
2022-03-13 23:42:36 +08:00
BaseConditional::print("", formatter);
2022-06-08 02:11:49 +08:00
std::cout << " Discrete Keys = ";
for (auto &dk : discreteKeys()) {
2022-03-13 23:42:36 +08:00
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}
2024-08-25 13:50:57 +08:00
std::cout << std::endl
<< " logNormalizationConstant: " << logConstant_ << std::endl
<< std::endl;
conditionals_.print(
2022-03-16 00:50:31 +08:00
"", [&](Key k) { return formatter(k); },
[&](const GaussianConditional::shared_ptr &gf) -> std::string {
RedirectCout rd;
2022-09-16 22:48:08 +08:00
if (gf && !gf->empty()) {
2022-03-16 00:50:31 +08:00
gf->print("", formatter);
2022-09-16 22:48:08 +08:00
return rd.str();
} else {
return "nullptr";
}
});
2022-03-13 23:42:36 +08:00
}
/* ************************************************************************* */
/**
 * Parent keys of this conditional with every discrete key removed.
 *
 * @return Continuous parent keys, in their original order.
 */
KeyVector HybridGaussianConditional::continuousParents() const {
  const auto parentRange = parents();
  KeyVector result(parentRange.begin(), parentRange.end());

  const auto &dkeys = discreteKeys();
  auto isDiscreteKey = [&dkeys](Key key) {
    return std::any_of(dkeys.begin(), dkeys.end(),
                       [key](const DiscreteKey &dk) { return dk.first == key; });
  };
  // Single stable erase-remove pass instead of one pass per discrete key.
  result.erase(std::remove_if(result.begin(), result.end(), isDiscreteKey),
               result.end());
  return result;
}
/* ************************************************************************* */
2024-09-13 18:20:46 +08:00
bool HybridGaussianConditional::allFrontalsGiven(
const VectorValues &given) const {
for (auto &&kv : given) {
if (given.find(kv.first) == given.end()) {
return false;
}
}
return true;
}
/* ************************************************************************* */
2024-09-13 17:59:09 +08:00
std::shared_ptr<HybridGaussianFactor> HybridGaussianConditional::likelihood(
const VectorValues &given) const {
if (!allFrontalsGiven(given)) {
throw std::runtime_error(
2024-09-13 18:20:46 +08:00
"HybridGaussianConditional::likelihood: given values are missing some "
"frontals.");
}
const DiscreteKeys discreteParentKeys = discreteKeys();
const KeyVector continuousParentKeys = continuousParents();
const HybridGaussianFactor::FactorValuePairs likelihoods(
2024-09-14 05:37:29 +08:00
conditionals_,
[&](const GaussianConditional::shared_ptr &conditional)
-> GaussianFactorValuePair {
const auto likelihood_m = conditional->likelihood(given);
const double Cgm_Kgcm =
logConstant_ - conditional->logNormalizationConstant();
if (Cgm_Kgcm == 0.0) {
2024-09-14 05:37:29 +08:00
return {likelihood_m, 0.0};
} else {
// Add a constant to the likelihood in case the noise models
// are not all equal.
double c = 2.0 * Cgm_Kgcm;
return {likelihood_m, c};
}
});
2024-09-13 17:59:09 +08:00
return std::make_shared<HybridGaussianFactor>(
continuousParentKeys, discreteParentKeys, likelihoods);
}
2022-10-12 00:10:02 +08:00
/* ************************************************************************* */
2022-12-31 01:10:16 +08:00
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
2022-10-12 00:10:02 +08:00
std::set<DiscreteKey> s;
2022-12-31 01:10:16 +08:00
s.insert(discreteKeys.begin(), discreteKeys.end());
2022-10-12 00:10:02 +08:00
return s;
}
/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
*
* @param discreteProbs The probabilities of only discrete keys.
2022-10-12 00:10:02 +08:00
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
*/
std::function<GaussianConditional::shared_ptr(
const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
2022-10-12 00:35:58 +08:00
// Get the discrete keys as sets for the decision tree
// and the gaussian mixture.
auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
2022-10-12 00:35:58 +08:00
auto gaussianMixtureKeySet = DiscreteKeysAsSet(this->discreteKeys());
auto pruner = [discreteProbs, discreteProbsKeySet, gaussianMixtureKeySet](
2022-10-12 00:35:58 +08:00
const Assignment<Key> &choices,
const GaussianConditional::shared_ptr &conditional)
-> GaussianConditional::shared_ptr {
// typecast so we can use this to get probability value
2022-12-31 01:10:16 +08:00
const DiscreteValues values(choices);
2022-10-12 00:10:02 +08:00
// Case where the gaussian mixture has the same
// discrete keys as the decision tree.
if (gaussianMixtureKeySet == discreteProbsKeySet) {
if (discreteProbs(values) == 0.0) {
2022-10-12 00:10:02 +08:00
// empty aka null pointer
std::shared_ptr<GaussianConditional> null;
2022-10-12 00:10:02 +08:00
return null;
} else {
return conditional;
}
} else {
2022-10-12 00:10:02 +08:00
std::vector<DiscreteKey> set_diff;
std::set_difference(
discreteProbsKeySet.begin(), discreteProbsKeySet.end(),
gaussianMixtureKeySet.begin(), gaussianMixtureKeySet.end(),
std::back_inserter(set_diff));
2022-10-12 00:10:02 +08:00
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
DiscreteValues augmented_values(values);
2022-12-31 14:53:59 +08:00
augmented_values.insert(assignment);
2022-10-12 00:10:02 +08:00
// If any one of the sub-branches are non-zero,
// we need this conditional.
if (discreteProbs(augmented_values) > 0.0) {
2022-10-12 00:10:02 +08:00
return conditional;
}
}
// If we are here, it means that all the sub-branches are 0,
// so we prune.
return nullptr;
}
};
2022-10-12 00:10:02 +08:00
return pruner;
}
/* *******************************************************************************/
void HybridGaussianConditional::prune(const DecisionTreeFactor &discreteProbs) {
2022-10-12 00:10:02 +08:00
// Functional which loops over all assignments and create a set of
// GaussianConditionals
auto pruner = prunerFunc(discreteProbs);
auto pruned_conditionals = conditionals_.apply(pruner);
conditionals_.root_ = pruned_conditionals.root_;
}
2022-11-02 08:19:36 +08:00
/* *******************************************************************************/
/**
 * Log-probability of `continuousValues` under every component.
 *
 * @param continuousValues Values for the continuous variables.
 * @return Tree over the discrete assignments. Each leaf is the component's
 *         logProbability, or -1e20 for a pruned (null) component.
 */
AlgebraicDecisionTree<Key> HybridGaussianConditional::logProbability(
    const VectorValues &continuousValues) const {
  // Capture by reference: the functor is consumed while the tree below is
  // built and does not outlive this call, so copying the whole VectorValues
  // into the closure was pure overhead.
  auto probFunc =
      [&continuousValues](const GaussianConditional::shared_ptr &conditional) {
        if (conditional) {
          return conditional->logProbability(continuousValues);
        }
        // Return arbitrarily small logProbability if conditional is null.
        // Conditional is null if it is pruned out.
        return -1e20;
      };
  return DecisionTree<Key, double>(conditionals_, probFunc);
}
2024-09-05 03:18:27 +08:00
/* ************************************************************************* */
double HybridGaussianConditional::conditionalError(
2024-09-05 03:18:27 +08:00
const GaussianConditional::shared_ptr &conditional,
const VectorValues &continuousValues) const {
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
/* *******************************************************************************/
/**
 * Per-assignment error tree: conditionalError() applied to every leaf.
 *
 * @param continuousValues Values for the continuous variables.
 */
AlgebraicDecisionTree<Key> HybridGaussianConditional::errorTree(
    const VectorValues &continuousValues) const {
  return DecisionTree<Key, double>(
      conditionals_,
      [this, &continuousValues](
          const GaussianConditional::shared_ptr &conditional) {
        return this->conditionalError(conditional, continuousValues);
      });
}
2023-01-13 05:29:28 +08:00
/* *******************************************************************************/
double HybridGaussianConditional::error(const HybridValues &values) const {
2023-01-13 05:29:28 +08:00
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
2024-09-05 03:18:27 +08:00
return conditionalError(conditional, values.continuous());
2023-01-13 05:29:28 +08:00
}
2022-11-02 08:19:36 +08:00
/* *******************************************************************************/
2024-09-13 18:20:46 +08:00
double HybridGaussianConditional::logProbability(
const HybridValues &values) const {
2022-12-31 01:10:16 +08:00
auto conditional = conditionals_(values.discrete());
2023-01-11 13:55:18 +08:00
return conditional->logProbability(values.continuous());
2022-11-02 08:19:36 +08:00
}
2023-01-15 02:23:21 +08:00
/* *******************************************************************************/
double HybridGaussianConditional::evaluate(const HybridValues &values) const {
2023-01-15 02:23:21 +08:00
auto conditional = conditionals_(values.discrete());
return conditional->evaluate(values.continuous());
}
2022-03-26 07:14:00 +08:00
} // namespace gtsam