Merge pull request #2006 from borglab/feature/k_best_fg

Search for k-best in factor graph
release/4.3a0
Frank Dellaert 2025-01-28 17:06:08 -05:00 committed by GitHub
commit 1daca1946d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 385 additions and 222 deletions

View File

@ -16,19 +16,35 @@
* @author Richard Roberts
*/
#include <gtsam/inference/JunctionTree-inst.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/inference/JunctionTree-inst.h>
namespace gtsam {
// Instantiate base classes
template class EliminatableClusterTree<DiscreteBayesTree, DiscreteFactorGraph>;
template class JunctionTree<DiscreteBayesTree, DiscreteFactorGraph>;
// Instantiate base classes
template class EliminatableClusterTree<DiscreteBayesTree, DiscreteFactorGraph>;
template class JunctionTree<DiscreteBayesTree, DiscreteFactorGraph>;
/* ************************************************************************* */
DiscreteJunctionTree::DiscreteJunctionTree(
const DiscreteEliminationTree& eliminationTree) :
Base(eliminationTree) {}
/* ************************************************************************* */
DiscreteJunctionTree::DiscreteJunctionTree(
const DiscreteEliminationTree& eliminationTree)
: Base(eliminationTree) {}
/* ************************************************************************* */
void DiscreteJunctionTree::print(const std::string& s,
const KeyFormatter& keyFormatter) const {
auto visitor = [&keyFormatter](
const std::shared_ptr<DiscreteJunctionTree::Cluster>& node,
const std::string& parentString) {
// Print the current node
node->print(parentString + "-", keyFormatter);
node->factors.print(parentString + "-", keyFormatter);
std::cout << std::endl;
return parentString + "| "; // Increment the indentation
};
std::string parentString = s;
treeTraversal::DepthFirstForest(*this, parentString, visitor);
}
} // namespace gtsam

View File

@ -18,54 +18,71 @@
#pragma once
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteBayesTree.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/JunctionTree.h>
namespace gtsam {
// Forward declarations
class DiscreteEliminationTree;
// Forward declarations
class DiscreteEliminationTree;
/**
* An EliminatableClusterTree, i.e., a set of variable clusters with factors,
* arranged in a tree, with the additional property that it represents the
* clique tree associated with a Bayes net.
*
* In GTSAM a junction tree is an intermediate data structure in multifrontal
* variable elimination. Each node is a cluster of factors, along with a
* clique of variables that are eliminated all at once. In detail, every node k
* represents a clique (maximal fully connected subset) of an associated chordal
* graph, such as a chordal Bayes net resulting from elimination.
*
* The difference with the BayesTree is that a JunctionTree stores factors,
* whereas a BayesTree stores conditionals, that are the product of eliminating
* the factors in the corresponding JunctionTree cliques.
*
* The tree structure and elimination method are exactly analogous to the
* EliminationTree, except that in the JunctionTree, at each node multiple
* variables are eliminated at a time.
*
* \ingroup Multifrontal
* @ingroup discrete
* \nosubgrouping
*/
class GTSAM_EXPORT DiscreteJunctionTree
: public JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> {
public:
typedef JunctionTree<DiscreteBayesTree, DiscreteFactorGraph>
Base; ///< Base class
typedef DiscreteJunctionTree This; ///< This class
typedef std::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
/// @name Constructors
/// @{
/**
* An EliminatableClusterTree, i.e., a set of variable clusters with factors, arranged in a tree,
* with the additional property that it represents the clique tree associated with a Bayes net.
*
* In GTSAM a junction tree is an intermediate data structure in multifrontal
* variable elimination. Each node is a cluster of factors, along with a
* clique of variables that are eliminated all at once. In detail, every node k represents
* a clique (maximal fully connected subset) of an associated chordal graph, such as a
* chordal Bayes net resulting from elimination.
*
* The difference with the BayesTree is that a JunctionTree stores factors, whereas a
* BayesTree stores conditionals, that are the product of eliminating the factors in the
* corresponding JunctionTree cliques.
*
* The tree structure and elimination method are exactly analogous to the EliminationTree,
* except that in the JunctionTree, at each node multiple variables are eliminated at a time.
*
* \ingroup Multifrontal
* @ingroup discrete
* \nosubgrouping
* Build the elimination tree of a factor graph using precomputed column
* structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is
* not precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
class GTSAM_EXPORT DiscreteJunctionTree :
public JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> {
public:
typedef JunctionTree<DiscreteBayesTree, DiscreteFactorGraph> Base; ///< Base class
typedef DiscreteJunctionTree This; ///< This class
typedef std::shared_ptr<This> shared_ptr; ///< Shared pointer to this class
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
/**
* Build the elimination tree of a factor graph using precomputed column structure.
* @param factorGraph The factor graph for which to build the elimination tree
* @param structure The set of factors involving each variable. If this is not
* precomputed, you can call the Create(const FactorGraph<DERIVEDFACTOR>&)
* named constructor instead.
* @return The elimination tree
*/
DiscreteJunctionTree(const DiscreteEliminationTree& eliminationTree);
};
/// @}
/// @name Testable
/// @{
/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
}
/** Print the tree to cout */
void print(const std::string& name = "DiscreteJunctionTree: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
/// @}
};
/// typedef for wrapper:
using DiscreteCluster = DiscreteJunctionTree::Cluster;
} // namespace gtsam

View File

@ -9,38 +9,37 @@
* -------------------------------------------------------------------------- */
/*
/**
* DiscreteSearch.cpp
*
* @date January, 2025
* @author Frank Dellaert
*/
#include <gtsam/discrete/DiscreteEliminationTree.h>
#include <gtsam/discrete/DiscreteJunctionTree.h>
#include <gtsam/discrete/DiscreteSearch.h>
namespace gtsam {
using Slot = DiscreteSearch::Slot;
using Solution = DiscreteSearch::Solution;
/**
* @brief Represents a node in the search tree for discrete search algorithms.
*
* @details Each SearchNode contains a partial assignment of discrete variables,
* the current error, a bound on the final error, and the index of the next
* conditional to be assigned.
/*
* A SearchNode represents a node in the search tree for the search algorithm.
* Each SearchNode contains a partial assignment of discrete variables, the
* current error, a bound on the final error, and the index of the next
* slot to be assigned.
*/
struct SearchNode {
DiscreteValues assignment; ///< Partial assignment of discrete variables.
double error; ///< Current error for the partial assignment.
double bound; ///< Lower bound on the final error for unassigned variables.
int nextConditional; ///< Index of the next conditional to be assigned.
DiscreteValues assignment; // Partial assignment of discrete variables.
double error; // Current error for the partial assignment.
double bound; // Lower bound on the final error
std::optional<size_t> next; // Index of the next slot to be assigned.
/**
* @brief Construct the root node for the search.
*/
static SearchNode Root(size_t numConditionals, double bound) {
return {DiscreteValues(), 0.0, bound,
static_cast<int>(numConditionals) - 1};
// Construct the root node for the search.
static SearchNode Root(size_t numSlots, double bound) {
return {DiscreteValues(), 0.0, bound, 0};
}
struct Compare {
@ -49,40 +48,22 @@ struct SearchNode {
}
};
/**
* @brief Checks if the node represents a complete assignment.
*
* @return True if all variables have been assigned, false otherwise.
*/
inline bool isComplete() const { return nextConditional < 0; }
// Checks if the node represents a complete assignment.
inline bool isComplete() const { return !next; }
/**
* @brief Expands the node by assigning the next variable.
*
* @param conditional The discrete conditional representing the next variable
* to be assigned.
* @param fa The frontal assignment for the next variable.
* @return A new SearchNode representing the expanded state.
*/
SearchNode expand(const DiscreteConditional& conditional,
const DiscreteValues& fa) const {
// Expands the node by assigning the next variable(s).
SearchNode expand(const DiscreteValues& fa, const Slot& slot,
std::optional<size_t> nextSlot) const {
// Combine the new frontal assignment with the current partial assignment
DiscreteValues newAssignment = assignment;
for (auto& [key, value] : fa) {
newAssignment[key] = value;
}
return {newAssignment, error + conditional.error(newAssignment), 0.0,
nextConditional - 1};
double errorSoFar = error + slot.factor->error(newAssignment);
return {newAssignment, errorSoFar, errorSoFar + slot.heuristic, nextSlot};
}
/**
* @brief Prints the SearchNode to an output stream.
*
* @param os The output stream.
* @param node The SearchNode to be printed.
* @return The output stream.
*/
// Prints the SearchNode to an output stream.
friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
return os;
@ -95,17 +76,20 @@ struct CompareSolution {
}
};
// Define the Solutions class
/*
* A Solutions object maintains a priority queue of the best solutions found
* during the search. The priority queue is limited to a maximum size, and
* solutions are only added if they are better than the worst solution.
*/
class Solutions {
private:
size_t maxSize_;
size_t maxSize_; // Maximum number of solutions to keep
std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
public:
Solutions(size_t maxSize) : maxSize_(maxSize) {}
/// Add a solution to the priority queue, possibly evicting the worst one.
/// Return true if we added the solution.
// Add a solution to the priority queue, possibly evicting the worst one.
// Return true if we added the solution.
bool maybeAdd(double error, const DiscreteValues& assignment) {
const bool full = pq_.size() == maxSize_;
if (full && error >= pq_.top().error) return false;
@ -114,7 +98,7 @@ class Solutions {
return true;
}
/// Check if we have any solutions
// Check if we have any solutions
bool empty() const { return pq_.empty(); }
// Method to print all solutions
@ -128,9 +112,9 @@ class Solutions {
return os;
}
/// Check if (partial) solution with given bound can be pruned. If we have
/// room, we never prune. Otherwise, prune if lower bound on error is worse
/// than our current worst error.
// Check if (partial) solution with given bound can be pruned. If we have
// room, we never prune. Otherwise, prune if lower bound on error is worse
// than our current worst error.
bool prune(double bound) const {
if (pq_.size() < maxSize_) return false;
return bound >= pq_.top().error;
@ -150,97 +134,155 @@ class Solutions {
}
};
// Get the factor associated with a node, possibly product of factors.
template <typename NodeType>
static DiscreteFactor::shared_ptr getFactor(const NodeType& node) {
const auto& factors = node->factors;
return factors.size() == 1 ? factors.back()
: DiscreteFactorGraph(factors).product();
}
DiscreteSearch::DiscreteSearch(const DiscreteEliminationTree& etree) {
using NodePtr = std::shared_ptr<DiscreteEliminationTree::Node>;
auto visitor = [this](const NodePtr& node, int data) {
const DiscreteFactor::shared_ptr factor = getFactor(node);
const size_t cardinality = factor->cardinality(node->key);
std::vector<std::pair<Key, size_t>> pairs{{node->key, cardinality}};
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
slots_.emplace_back(std::move(slot));
return data + 1;
};
int data = 0; // unused
treeTraversal::DepthFirstForest(etree, data, visitor);
lowerBound_ = computeHeuristic();
}
DiscreteSearch::DiscreteSearch(const DiscreteJunctionTree& junctionTree) {
using NodePtr = std::shared_ptr<DiscreteJunctionTree::Cluster>;
auto visitor = [this](const NodePtr& cluster, int data) {
const auto factor = getFactor(cluster);
std::vector<std::pair<Key, size_t>> pairs;
for (Key key : cluster->orderedFrontalKeys) {
pairs.emplace_back(key, factor->cardinality(key));
}
const Slot slot{factor, DiscreteValues::CartesianProduct(pairs), 0.0};
slots_.emplace_back(std::move(slot));
return data + 1;
};
int data = 0; // unused
treeTraversal::DepthFirstForest(junctionTree, data, visitor);
lowerBound_ = computeHeuristic();
}
DiscreteSearch DiscreteSearch::FromFactorGraph(
const DiscreteFactorGraph& factorGraph, const Ordering& ordering,
bool buildJunctionTree) {
const DiscreteEliminationTree etree(factorGraph, ordering);
if (buildJunctionTree) {
const DiscreteJunctionTree junctionTree(etree);
return DiscreteSearch(junctionTree);
} else {
return DiscreteSearch(etree);
}
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
std::vector<DiscreteConditional::shared_ptr> conditionals;
for (auto& factor : bayesNet) conditionals_.push_back(factor);
costToGo_ = computeCostToGo(conditionals_);
slots_.reserve(bayesNet.size());
for (auto& conditional : bayesNet) {
const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
slots_.emplace_back(std::move(slot));
}
std::reverse(slots_.begin(), slots_.end());
lowerBound_ = computeHeuristic();
}
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
std::function<void(const DiscreteBayesTree::sharedClique&)>
collectConditionals = [&](const auto& clique) {
if (!clique) return;
for (const auto& child : clique->children) collectConditionals(child);
conditionals_.push_back(clique->conditional());
};
for (const auto& root : bayesTree.roots()) collectConditionals(root);
costToGo_ = computeCostToGo(conditionals_);
using NodePtr = DiscreteBayesTree::sharedClique;
auto visitor = [this](const NodePtr& clique, int data) {
auto conditional = clique->conditional();
const Slot slot{conditional, conditional->frontalAssignments(), 0.0};
slots_.emplace_back(std::move(slot));
return data + 1;
};
int data = 0; // unused
treeTraversal::DepthFirstForest(bayesTree, data, visitor);
lowerBound_ = computeHeuristic();
}
struct SearchNodeQueue
: public std::priority_queue<SearchNode, std::vector<SearchNode>,
SearchNode::Compare> {
void expandNextNode(
const std::vector<DiscreteConditional::shared_ptr>& conditionals,
const std::vector<double>& costToGo, Solutions* solutions) {
void DiscreteSearch::print(const std::string& name,
const KeyFormatter& formatter) const {
std::cout << name << " with " << slots_.size() << " slots:\n";
for (size_t i = 0; i < slots_.size(); ++i) {
std::cout << i << ": " << slots_[i] << std::endl;
}
}
using SearchNodeQueue = std::priority_queue<SearchNode, std::vector<SearchNode>,
SearchNode::Compare>;
std::vector<Solution> DiscreteSearch::run(size_t K) const {
if (slots_.empty()) {
return {Solution(0.0, DiscreteValues())};
}
Solutions solutions(K);
SearchNodeQueue expansions;
expansions.push(SearchNode::Root(slots_.size(), lowerBound_));
// Perform the search
while (!expansions.empty()) {
// Pop the partial assignment with the smallest bound
SearchNode current = top();
pop();
SearchNode current = expansions.top();
expansions.pop();
// If we already have K solutions, prune if we cannot beat the worst one.
if (solutions->prune(current.bound)) {
return;
if (solutions.prune(current.bound)) {
continue;
}
// Check if we have a complete assignment
if (current.isComplete()) {
solutions->maybeAdd(current.error, current.assignment);
return;
solutions.maybeAdd(current.error, current.assignment);
continue;
}
// Expand on the next factor
const auto& conditional = conditionals[current.nextConditional];
for (auto& fa : conditional->frontalAssignments()) {
auto childNode = current.expand(*conditional, fa);
if (childNode.nextConditional >= 0)
childNode.bound = childNode.error + costToGo[childNode.nextConditional];
// Get the next slot to expand
const auto& slot = slots_[*current.next];
std::optional<size_t> nextSlot = *current.next + 1;
if (nextSlot == slots_.size()) nextSlot.reset();
for (auto& fa : slot.assignments) {
auto childNode = current.expand(fa, slot, nextSlot);
// Again, prune if we cannot beat the worst solution
if (!solutions->prune(childNode.bound)) {
emplace(childNode);
if (!solutions.prune(childNode.bound)) {
expansions.emplace(childNode);
}
}
}
};
std::vector<Solution> DiscreteSearch::run(size_t K) const {
Solutions solutions(K);
SearchNodeQueue expansions;
expansions.push(SearchNode::Root(conditionals_.size(),
costToGo_.empty() ? 0.0 : costToGo_.back()));
#ifdef DISCRETE_SEARCH_DEBUG
size_t numExpansions = 0;
#endif
// Perform the search
while (!expansions.empty()) {
expansions.expandNextNode(conditionals_, costToGo_, &solutions);
#ifdef DISCRETE_SEARCH_DEBUG
++numExpansions;
#endif
}
#ifdef DISCRETE_SEARCH_DEBUG
std::cout << "Number of expansions: " << numExpansions << std::endl;
#endif
// Extract solutions from bestSolutions in ascending order of error
return solutions.extractSolutions();
}
std::vector<double> DiscreteSearch::computeCostToGo(
const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
std::vector<double> costToGo;
/*
* We have a number of factors, each with a max value, and we want to compute
* a lower-bound on the cost-to-go for each slot, *not* including this factor.
* For the last slot[n-1], this is 0.0, as the cost after that is zero.
* For the second-to-last slot, it is h = -log(max(factor[n-1])), because after
* we assign slot[n-2] we still need to assign slot[n-1], which will cost *at
* least* h. We return the estimated lower bound of the cost for *all* slots.
*/
double DiscreteSearch::computeHeuristic() {
double error = 0.0;
for (const auto& conditional : conditionals) {
Ordering ordering(conditional->begin(), conditional->end());
auto maxx = conditional->max(ordering);
for (auto it = slots_.rbegin(); it != slots_.rend(); ++it) {
it->heuristic = error;
Ordering ordering(it->factor->begin(), it->factor->end());
auto maxx = it->factor->max(ordering);
error -= std::log(maxx->evaluate({}));
costToGo.push_back(error);
}
return costToGo;
return error;
}
} // namespace gtsam

View File

@ -9,8 +9,12 @@
* -------------------------------------------------------------------------- */
/*
* DiscreteSearch.cpp
/**
* @file DiscreteSearch.h
* @brief Defines the DiscreteSearch class for discrete search algorithms.
*
* @details This file contains the definition of the DiscreteSearch class, which
* is used in discrete search algorithms to find the K best solutions.
*
* @date January, 2025
* @author Frank Dellaert
@ -24,12 +28,53 @@
namespace gtsam {
/**
* DiscreteSearch: Search for the K best solutions.
* @brief DiscreteSearch: Search for the K best solutions.
*
* This class is used to search for the K best solutions in a DiscreteBayesNet.
* This is implemented with a modified A* search algorithm that uses a priority
* queue to manage the search nodes. That machinery is defined in the .cpp file.
* The heuristic we use is the sum of the log-probabilities of the
* maximum-probability assignments for each slot, for all slots to the right of
* the current slot.
*
* TODO: The heuristic could be refined by using the partial assignment in
* search node to refine the max-probability assignment for the remaining slots.
* This would incur more computation but will lead to fewer expansions.
*/
class GTSAM_EXPORT DiscreteSearch {
public:
/**
* @brief A solution to a discrete search problem.
* We structure the search as a set of slots, each with a factor and
* a set of variable assignments that need to be chosen. In addition, each
* slot has a heuristic associated with it.
*
* Example:
* The factors in the search problem (always parents before descendents!):
* [P(A), P(B|A), P(C|A,B)]
* The assignments for each factor.
* [[A0,A1], [B0,B1], [C0,C1,C2]]
* A lower bound on the cost-to-go after each slot, e.g.,
* [-log(max_B P(B|A)) -log(max_C P(C|A,B)), -log(max_C P(C|A,B)), 0.0]
* Note that these decrease as we move from right to left.
* We keep the global lower bound as lowerBound_. In the example, it is:
* -log(max_B P(B|A)) -log(max_C P(C|A,B)) -log(max_C P(C|A,B))
*/
struct Slot {
DiscreteFactor::shared_ptr factor;
std::vector<DiscreteValues> assignments;
double heuristic;
friend std::ostream& operator<<(std::ostream& os, const Slot& slot) {
os << "Slot with " << slot.assignments.size()
<< " assignments, heuristic=" << slot.heuristic;
os << ", factor:\n" << slot.factor->markdown() << std::endl;
return os;
}
};
/**
* A solution is a set of assignments, covering all the slots.
* as well as an associated error = -log(probability)
*/
struct Solution {
double error;
@ -42,16 +87,56 @@ class GTSAM_EXPORT DiscreteSearch {
}
};
/**
* Construct from a DiscreteBayesNet and K.
*/
DiscreteSearch(const DiscreteBayesNet& bayesNet);
public:
/// @name Standard Constructors
/// @{
/**
* Construct from a DiscreteBayesTree and K.
* Construct from a DiscreteFactorGraph.
*
* Internally creates either an elimination tree or a junction tree. The
* latter incurs more up-front computation but the search itself might be
* faster. Then again, for the elimination tree, the heuristic will be more
* fine-grained (more slots).
*
* @param factorGraph The factor graph to search over.
* @param ordering The ordering used to create etree (and maybe jtree).
* @param buildJunctionTree Whether to build a junction tree or not.
*/
static DiscreteSearch FromFactorGraph(const DiscreteFactorGraph& factorGraph,
const Ordering& ordering,
bool buildJunctionTree = false);
/// Construct from a DiscreteEliminationTree.
DiscreteSearch(const DiscreteEliminationTree& etree);
/// Construct from a DiscreteJunctionTree.
DiscreteSearch(const DiscreteJunctionTree& junctionTree);
//// Construct from a DiscreteBayesNet.
DiscreteSearch(const DiscreteBayesNet& bayesNet);
/// Construct from a DiscreteBayesTree.
DiscreteSearch(const DiscreteBayesTree& bayesTree);
/// @}
/// @name Testable
/// @{
/** Print the tree to cout */
void print(const std::string& name = "DiscreteSearch: ",
const KeyFormatter& formatter = DefaultKeyFormatter) const;
/// @}
/// @name Standard API
/// @{
/// Return lower bound on the cost-to-go for the entire search
double lowerBound() const { return lowerBound_; }
/// Read access to the slots
const std::vector<Slot>& slots() const { return slots_; }
/**
* @brief Search for the K best solutions.
*
@ -64,15 +149,16 @@ class GTSAM_EXPORT DiscreteSearch {
*/
std::vector<Solution> run(size_t K = 1) const;
/// @}
private:
/// Compute the cumulative cost-to-go for each conditional slot.
static std::vector<double> computeCostToGo(
const std::vector<DiscreteConditional::shared_ptr>& conditionals);
/**
* Compute the cumulative lower-bound cost-to-go after each slot is filled.
* @return the estimated lower bound of the cost for *all* slots.
*/
double computeHeuristic();
/// Expand the next node in the search tree.
void expandNextNode() const;
std::vector<DiscreteConditional::shared_ptr> conditionals_;
std::vector<double> costToGo_;
double lowerBound_; ///< Lower bound on the cost-to-go for the entire search.
std::vector<Slot> slots_; ///< The slots to fill in the search.
};
} // namespace gtsam

View File

@ -58,4 +58,4 @@ DiscreteBayesNet createAsiaExample() {
return asia;
}
} // namespace asia_example
} // namespace gtsam
} // namespace gtsam

View File

@ -28,9 +28,15 @@ using namespace gtsam;
namespace asia {
using namespace asia_example;
static const DiscreteBayesNet bayesNet = createAsiaExample();
// Create factor graph and optimize with max-product for MPE
static const DiscreteFactorGraph factorGraph(bayesNet);
static const DiscreteValues mpe = factorGraph.optimize();
// Create ordering
static const Ordering ordering{D, X, B, E, L, T, S, A};
// Create Bayes tree
static const DiscreteBayesTree bayesTree =
*factorGraph.eliminateMultifrontal(ordering);
} // namespace asia
@ -45,29 +51,6 @@ TEST(DiscreteBayesNet, EmptyKBest) {
EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
}
/* ************************************************************************* */
TEST(DiscreteBayesNet, AsiaKBest) {
const DiscreteSearch search(asia::bayesNet);
// Ask for the MPE
auto mpe = search.run();
EXPECT_LONGS_EQUAL(1, mpe.size());
// Regression test: check the MPE solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
// Check it is equal to MPE via inference
EXPECT(assert_equal(asia::mpe, mpe[0].assignment));
// Ask for top 4 solutions
auto solutions = search.run(4);
EXPECT_LONGS_EQUAL(4, solutions.size());
// Regression test: check the first and last solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
}
/* ************************************************************************* */
TEST(DiscreteBayesTree, EmptyTree) {
DiscreteBayesTree bt;
@ -81,26 +64,45 @@ TEST(DiscreteBayesTree, EmptyTree) {
}
/* ************************************************************************* */
TEST(DiscreteBayesTree, AsiaTreeKBest) {
DiscreteSearch search(asia::bayesTree);
TEST(DiscreteBayesNet, AsiaKBest) {
auto fromETree =
DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering);
auto fromJunctionTree =
DiscreteSearch::FromFactorGraph(asia::factorGraph, asia::ordering, true);
const DiscreteSearch fromBayesNet(asia::bayesNet);
const DiscreteSearch fromBayesTree(asia::bayesTree);
// Ask for MPE
auto mpe = search.run();
for (auto& search :
{fromETree, fromJunctionTree, fromBayesNet, fromBayesTree}) {
// Ask for the MPE
auto mpe = search.run();
EXPECT_LONGS_EQUAL(1, mpe.size());
// Regression test: check the MPE solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
// Regression on error lower bound
EXPECT_DOUBLES_EQUAL(1.205536, search.lowerBound(), 1e-5);
// Check it is equal to MPE via inference
EXPECT(assert_equal(asia::mpe, mpe[0].assignment));
// Check that the cost-to-go heuristic decreases from there
auto slots = search.slots();
double previousHeuristic = search.lowerBound();
for (auto&& slot : slots) {
EXPECT(slot.heuristic <= previousHeuristic);
previousHeuristic = slot.heuristic;
}
// Ask for top 4 solutions
auto solutions = search.run(4);
EXPECT_LONGS_EQUAL(1, mpe.size());
// Regression test: check the MPE solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
EXPECT_LONGS_EQUAL(4, solutions.size());
// Regression test: check the first and last solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
// Check it is equal to MPE via inference
EXPECT(assert_equal(asia::mpe, mpe[0].assignment));
// Ask for top 4 solutions
auto solutions = search.run(4);
EXPECT_LONGS_EQUAL(4, solutions.size());
// Regression test: check the first and last solution
EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
}
}
/* ************************************************************************* */