commit
						3f6ae48dfb
					
				| 
						 | 
				
			
			@ -0,0 +1,246 @@
 | 
			
		|||
/* ----------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | 
			
		||||
 * Atlanta, Georgia 30332-0415
 | 
			
		||||
 * All Rights Reserved
 | 
			
		||||
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | 
			
		||||
 | 
			
		||||
 * See LICENSE for the license information
 | 
			
		||||
 | 
			
		||||
 * -------------------------------------------------------------------------- */
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
 * DiscreteSearch.cpp
 | 
			
		||||
 *
 | 
			
		||||
 * @date January, 2025
 | 
			
		||||
 * @author Frank Dellaert
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteSearch.h>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
using Solution = DiscreteSearch::Solution;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief Represents a node in the search tree for discrete search algorithms.
 | 
			
		||||
 *
 | 
			
		||||
 * @details Each SearchNode contains a partial assignment of discrete variables,
 | 
			
		||||
 * the current error, a bound on the final error, and the index of the next
 | 
			
		||||
 * conditional to be assigned.
 | 
			
		||||
 */
 | 
			
		||||
struct SearchNode {
 | 
			
		||||
  DiscreteValues assignment;  ///< Partial assignment of discrete variables.
 | 
			
		||||
  double error;               ///< Current error for the partial assignment.
 | 
			
		||||
  double bound;  ///< Lower bound on the final error for unassigned variables.
 | 
			
		||||
  int nextConditional;  ///< Index of the next conditional to be assigned.
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Construct the root node for the search.
 | 
			
		||||
   */
 | 
			
		||||
  static SearchNode Root(size_t numConditionals, double bound) {
 | 
			
		||||
    return {DiscreteValues(), 0.0, bound,
 | 
			
		||||
            static_cast<int>(numConditionals) - 1};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  struct Compare {
 | 
			
		||||
    bool operator()(const SearchNode& a, const SearchNode& b) const {
 | 
			
		||||
      return a.bound > b.bound;  // smallest bound -> highest priority
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Checks if the node represents a complete assignment.
 | 
			
		||||
   *
 | 
			
		||||
   * @return True if all variables have been assigned, false otherwise.
 | 
			
		||||
   */
 | 
			
		||||
  inline bool isComplete() const { return nextConditional < 0; }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Expands the node by assigning the next variable.
 | 
			
		||||
   *
 | 
			
		||||
   * @param conditional The discrete conditional representing the next variable
 | 
			
		||||
   * to be assigned.
 | 
			
		||||
   * @param fa The frontal assignment for the next variable.
 | 
			
		||||
   * @return A new SearchNode representing the expanded state.
 | 
			
		||||
   */
 | 
			
		||||
  SearchNode expand(const DiscreteConditional& conditional,
 | 
			
		||||
                    const DiscreteValues& fa) const {
 | 
			
		||||
    // Combine the new frontal assignment with the current partial assignment
 | 
			
		||||
    DiscreteValues newAssignment = assignment;
 | 
			
		||||
    for (auto& [key, value] : fa) {
 | 
			
		||||
      newAssignment[key] = value;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return {newAssignment, error + conditional.error(newAssignment), 0.0,
 | 
			
		||||
            nextConditional - 1};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Prints the SearchNode to an output stream.
 | 
			
		||||
   *
 | 
			
		||||
   * @param os The output stream.
 | 
			
		||||
   * @param node The SearchNode to be printed.
 | 
			
		||||
   * @return The output stream.
 | 
			
		||||
   */
 | 
			
		||||
  friend std::ostream& operator<<(std::ostream& os, const SearchNode& node) {
 | 
			
		||||
    os << "SearchNode(error=" << node.error << ", bound=" << node.bound << ")";
 | 
			
		||||
    return os;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct CompareSolution {
 | 
			
		||||
  bool operator()(const Solution& a, const Solution& b) const {
 | 
			
		||||
    return a.error < b.error;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Define the Solutions class
 | 
			
		||||
class Solutions {
 | 
			
		||||
 private:
 | 
			
		||||
  size_t maxSize_;
 | 
			
		||||
  std::priority_queue<Solution, std::vector<Solution>, CompareSolution> pq_;
 | 
			
		||||
 | 
			
		||||
 public:
 | 
			
		||||
  Solutions(size_t maxSize) : maxSize_(maxSize) {}
 | 
			
		||||
 | 
			
		||||
  /// Add a solution to the priority queue, possibly evicting the worst one.
 | 
			
		||||
  /// Return true if we added the solution.
 | 
			
		||||
  bool maybeAdd(double error, const DiscreteValues& assignment) {
 | 
			
		||||
    const bool full = pq_.size() == maxSize_;
 | 
			
		||||
    if (full && error >= pq_.top().error) return false;
 | 
			
		||||
    if (full) pq_.pop();
 | 
			
		||||
    pq_.emplace(error, assignment);
 | 
			
		||||
    return true;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Check if we have any solutions
 | 
			
		||||
  bool empty() const { return pq_.empty(); }
 | 
			
		||||
 | 
			
		||||
  // Method to print all solutions
 | 
			
		||||
  friend std::ostream& operator<<(std::ostream& os, const Solutions& sn) {
 | 
			
		||||
    os << "Solutions (top " << sn.pq_.size() << "):\n";
 | 
			
		||||
    auto pq = sn.pq_;
 | 
			
		||||
    while (!pq.empty()) {
 | 
			
		||||
      os << pq.top() << "\n";
 | 
			
		||||
      pq.pop();
 | 
			
		||||
    }
 | 
			
		||||
    return os;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /// Check if (partial) solution with given bound can be pruned. If we have
 | 
			
		||||
  /// room, we never prune. Otherwise, prune if lower bound on error is worse
 | 
			
		||||
  /// than our current worst error.
 | 
			
		||||
  bool prune(double bound) const {
 | 
			
		||||
    if (pq_.size() < maxSize_) return false;
 | 
			
		||||
    return bound >= pq_.top().error;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Method to extract solutions in ascending order of error
 | 
			
		||||
  std::vector<Solution> extractSolutions() {
 | 
			
		||||
    std::vector<Solution> result;
 | 
			
		||||
    while (!pq_.empty()) {
 | 
			
		||||
      result.push_back(pq_.top());
 | 
			
		||||
      pq_.pop();
 | 
			
		||||
    }
 | 
			
		||||
    std::sort(
 | 
			
		||||
        result.begin(), result.end(),
 | 
			
		||||
        [](const Solution& a, const Solution& b) { return a.error < b.error; });
 | 
			
		||||
    return result;
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesNet& bayesNet) {
 | 
			
		||||
  std::vector<DiscreteConditional::shared_ptr> conditionals;
 | 
			
		||||
  for (auto& factor : bayesNet) conditionals_.push_back(factor);
 | 
			
		||||
  costToGo_ = computeCostToGo(conditionals_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
DiscreteSearch::DiscreteSearch(const DiscreteBayesTree& bayesTree) {
 | 
			
		||||
  std::function<void(const DiscreteBayesTree::sharedClique&)>
 | 
			
		||||
      collectConditionals = [&](const auto& clique) {
 | 
			
		||||
        if (!clique) return;
 | 
			
		||||
        for (const auto& child : clique->children) collectConditionals(child);
 | 
			
		||||
        conditionals_.push_back(clique->conditional());
 | 
			
		||||
      };
 | 
			
		||||
  for (const auto& root : bayesTree.roots()) collectConditionals(root);
 | 
			
		||||
  costToGo_ = computeCostToGo(conditionals_);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct SearchNodeQueue
 | 
			
		||||
    : public std::priority_queue<SearchNode, std::vector<SearchNode>,
 | 
			
		||||
                                 SearchNode::Compare> {
 | 
			
		||||
  void expandNextNode(
 | 
			
		||||
      const std::vector<DiscreteConditional::shared_ptr>& conditionals,
 | 
			
		||||
      const std::vector<double>& costToGo, Solutions* solutions) {
 | 
			
		||||
    // Pop the partial assignment with the smallest bound
 | 
			
		||||
    SearchNode current = top();
 | 
			
		||||
    pop();
 | 
			
		||||
 | 
			
		||||
    // If we already have K solutions, prune if we cannot beat the worst one.
 | 
			
		||||
    if (solutions->prune(current.bound)) {
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Check if we have a complete assignment
 | 
			
		||||
    if (current.isComplete()) {
 | 
			
		||||
      solutions->maybeAdd(current.error, current.assignment);
 | 
			
		||||
      return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Expand on the next factor
 | 
			
		||||
    const auto& conditional = conditionals[current.nextConditional];
 | 
			
		||||
 | 
			
		||||
    for (auto& fa : conditional->frontalAssignments()) {
 | 
			
		||||
      auto childNode = current.expand(*conditional, fa);
 | 
			
		||||
      if (childNode.nextConditional >= 0)
 | 
			
		||||
        childNode.bound = childNode.error + costToGo[childNode.nextConditional];
 | 
			
		||||
 | 
			
		||||
      // Again, prune if we cannot beat the worst solution
 | 
			
		||||
      if (!solutions->prune(childNode.bound)) {
 | 
			
		||||
        emplace(childNode);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
std::vector<Solution> DiscreteSearch::run(size_t K) const {
 | 
			
		||||
  Solutions solutions(K);
 | 
			
		||||
  SearchNodeQueue expansions;
 | 
			
		||||
  expansions.push(SearchNode::Root(conditionals_.size(),
 | 
			
		||||
                                   costToGo_.empty() ? 0.0 : costToGo_.back()));
 | 
			
		||||
 | 
			
		||||
#ifdef DISCRETE_SEARCH_DEBUG
 | 
			
		||||
  size_t numExpansions = 0;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Perform the search
 | 
			
		||||
  while (!expansions.empty()) {
 | 
			
		||||
    expansions.expandNextNode(conditionals_, costToGo_, &solutions);
 | 
			
		||||
#ifdef DISCRETE_SEARCH_DEBUG
 | 
			
		||||
    ++numExpansions;
 | 
			
		||||
#endif
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
#ifdef DISCRETE_SEARCH_DEBUG
 | 
			
		||||
  std::cout << "Number of expansions: " << numExpansions << std::endl;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Extract solutions from bestSolutions in ascending order of error
 | 
			
		||||
  return solutions.extractSolutions();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::vector<double> DiscreteSearch::computeCostToGo(
 | 
			
		||||
    const std::vector<DiscreteConditional::shared_ptr>& conditionals) {
 | 
			
		||||
  std::vector<double> costToGo;
 | 
			
		||||
  double error = 0.0;
 | 
			
		||||
  for (const auto& conditional : conditionals) {
 | 
			
		||||
    Ordering ordering(conditional->begin(), conditional->end());
 | 
			
		||||
    auto maxx = conditional->max(ordering);
 | 
			
		||||
    error -= std::log(maxx->evaluate({}));
 | 
			
		||||
    costToGo.push_back(error);
 | 
			
		||||
  }
 | 
			
		||||
  return costToGo;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,78 @@
 | 
			
		|||
/* ----------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | 
			
		||||
 * Atlanta, Georgia 30332-0415
 | 
			
		||||
 * All Rights Reserved
 | 
			
		||||
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | 
			
		||||
 | 
			
		||||
 * See LICENSE for the license information
 | 
			
		||||
 | 
			
		||||
 * -------------------------------------------------------------------------- */
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
 * DiscreteSearch.cpp
 | 
			
		||||
 *
 | 
			
		||||
 * @date January, 2025
 | 
			
		||||
 * @author Frank Dellaert
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteBayesNet.h>
 | 
			
		||||
#include <gtsam/discrete/DiscreteBayesTree.h>
 | 
			
		||||
 | 
			
		||||
#include <queue>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * DiscreteSearch: Search for the K best solutions.
 | 
			
		||||
 */
 | 
			
		||||
class GTSAM_EXPORT DiscreteSearch {
 | 
			
		||||
 public:
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief A solution to a discrete search problem.
 | 
			
		||||
   */
 | 
			
		||||
  struct Solution {
 | 
			
		||||
    double error;
 | 
			
		||||
    DiscreteValues assignment;
 | 
			
		||||
    Solution(double err, const DiscreteValues& assign)
 | 
			
		||||
        : error(err), assignment(assign) {}
 | 
			
		||||
    friend std::ostream& operator<<(std::ostream& os, const Solution& sn) {
 | 
			
		||||
      os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]";
 | 
			
		||||
      return os;
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Construct from a DiscreteBayesNet and K.
 | 
			
		||||
   */
 | 
			
		||||
  DiscreteSearch(const DiscreteBayesNet& bayesNet);
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * Construct from a DiscreteBayesTree and K.
 | 
			
		||||
   */
 | 
			
		||||
  DiscreteSearch(const DiscreteBayesTree& bayesTree);
 | 
			
		||||
 | 
			
		||||
  /**
 | 
			
		||||
   * @brief Search for the K best solutions.
 | 
			
		||||
   *
 | 
			
		||||
   * This method performs a search to find the K best solutions for the given
 | 
			
		||||
   * DiscreteBayesNet. It uses a priority queue to manage the search nodes,
 | 
			
		||||
   * expanding nodes with the smallest bound first. The search continues until
 | 
			
		||||
   * all possible nodes have been expanded or pruned.
 | 
			
		||||
   *
 | 
			
		||||
   * @return A vector of the K best solutions found during the search.
 | 
			
		||||
   */
 | 
			
		||||
  std::vector<Solution> run(size_t K = 1) const;
 | 
			
		||||
 | 
			
		||||
 private:
 | 
			
		||||
  /// Compute the cumulative cost-to-go for each conditional slot.
 | 
			
		||||
  static std::vector<double> computeCostToGo(
 | 
			
		||||
      const std::vector<DiscreteConditional::shared_ptr>& conditionals);
 | 
			
		||||
 | 
			
		||||
  /// Expand the next node in the search tree.
 | 
			
		||||
  void expandNextNode() const;
 | 
			
		||||
 | 
			
		||||
  std::vector<DiscreteConditional::shared_ptr> conditionals_;
 | 
			
		||||
  std::vector<double> costToGo_;
 | 
			
		||||
};
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			@ -26,12 +26,24 @@ using std::stringstream;
 | 
			
		|||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
static void stream(std::ostream& os, const DiscreteValues& x,
 | 
			
		||||
                   const KeyFormatter& keyFormatter) {
 | 
			
		||||
  for (const auto& kv : x)
 | 
			
		||||
    os << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
std::ostream& operator<<(std::ostream& os, const DiscreteValues& x) {
 | 
			
		||||
  stream(os, x, DefaultKeyFormatter);
 | 
			
		||||
  return os;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************ */
 | 
			
		||||
void DiscreteValues::print(const string& s,
 | 
			
		||||
                           const KeyFormatter& keyFormatter) const {
 | 
			
		||||
  cout << s << ": ";
 | 
			
		||||
  for (auto&& kv : *this)
 | 
			
		||||
    cout << "(" << keyFormatter(kv.first) << ", " << kv.second << ")";
 | 
			
		||||
  stream(cout, *this, keyFormatter);
 | 
			
		||||
  cout << endl;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -64,6 +64,9 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
 | 
			
		|||
  /// @name Standard Interface
 | 
			
		||||
  /// @{
 | 
			
		||||
 | 
			
		||||
  /// ostream operator:
 | 
			
		||||
  friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x);
 | 
			
		||||
 | 
			
		||||
  // insert in base class;
 | 
			
		||||
  std::pair<iterator, bool> insert( const value_type& value ){
 | 
			
		||||
    return Base::insert(value);
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,61 @@
 | 
			
		|||
/* ----------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | 
			
		||||
 * Atlanta, Georgia 30332-0415
 | 
			
		||||
 * All Rights Reserved
 | 
			
		||||
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | 
			
		||||
 | 
			
		||||
 * See LICENSE for the license information
 | 
			
		||||
 | 
			
		||||
 * -------------------------------------------------------------------------- */
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
 * AsiaExample.h
 | 
			
		||||
 *
 | 
			
		||||
 *  @date Jan, 2025
 | 
			
		||||
 *  @author Frank Dellaert
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <gtsam/discrete/DiscreteBayesNet.h>
 | 
			
		||||
#include <gtsam/inference/Symbol.h>
 | 
			
		||||
 | 
			
		||||
namespace gtsam {
 | 
			
		||||
namespace asia_example {
 | 
			
		||||
 | 
			
		||||
static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3),
 | 
			
		||||
                 B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6),
 | 
			
		||||
                 S = Symbol('S', 7), A = Symbol('A', 8);
 | 
			
		||||
 | 
			
		||||
static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2),
 | 
			
		||||
    Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2),
 | 
			
		||||
    Asia(A, 2);
 | 
			
		||||
 | 
			
		||||
// Function to construct the Asia priors
 | 
			
		||||
DiscreteBayesNet createPriors() {
 | 
			
		||||
  DiscreteBayesNet priors;
 | 
			
		||||
  priors.add(Smoking % "50/50");
 | 
			
		||||
  priors.add(Asia, "99/1");
 | 
			
		||||
  return priors;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Function to construct the incomplete Asia example
 | 
			
		||||
DiscreteBayesNet createFragment() {
 | 
			
		||||
  DiscreteBayesNet fragment;
 | 
			
		||||
  fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
 | 
			
		||||
  fragment.add(LungCancer | Smoking = "99/1 90/10");
 | 
			
		||||
  fragment.add(Tuberculosis | Asia = "99/1 95/5");
 | 
			
		||||
  for (const auto& factor : createPriors()) fragment.push_back(factor);
 | 
			
		||||
  return fragment;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Function to construct the Asia example
 | 
			
		||||
DiscreteBayesNet createAsiaExample() {
 | 
			
		||||
  DiscreteBayesNet asia;
 | 
			
		||||
  asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
 | 
			
		||||
  asia.add(XRay | Either = "95/5 2/98");
 | 
			
		||||
  asia.add(Bronchitis | Smoking = "70/30 40/60");
 | 
			
		||||
  for (const auto& factor : createFragment()) asia.push_back(factor);
 | 
			
		||||
  return asia;
 | 
			
		||||
}
 | 
			
		||||
}  // namespace asia_example
 | 
			
		||||
}  // namespace gtsam
 | 
			
		||||
| 
						 | 
				
			
			@ -23,40 +23,19 @@
 | 
			
		|||
#include <gtsam/discrete/DiscreteBayesNet.h>
 | 
			
		||||
#include <gtsam/discrete/DiscreteFactorGraph.h>
 | 
			
		||||
#include <gtsam/discrete/DiscreteMarginals.h>
 | 
			
		||||
#include <gtsam/inference/Symbol.h>
 | 
			
		||||
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
using namespace std;
 | 
			
		||||
#include "AsiaExample.h"
 | 
			
		||||
 | 
			
		||||
using namespace gtsam;
 | 
			
		||||
 | 
			
		||||
static const DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2),
 | 
			
		||||
    LungCancer(6, 2), Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
 | 
			
		||||
 | 
			
		||||
using ADT = AlgebraicDecisionTree<Key>;
 | 
			
		||||
 | 
			
		||||
// Function to construct the Asia example
 | 
			
		||||
DiscreteBayesNet constructAsiaExample() {
 | 
			
		||||
  DiscreteBayesNet asia;
 | 
			
		||||
 | 
			
		||||
  asia.add(Asia, "99/1");
 | 
			
		||||
  asia.add(Smoking % "50/50");  // Signature version
 | 
			
		||||
 | 
			
		||||
  asia.add(Tuberculosis | Asia = "99/1 95/5");
 | 
			
		||||
  asia.add(LungCancer | Smoking = "99/1 90/10");
 | 
			
		||||
  asia.add(Bronchitis | Smoking = "70/30 40/60");
 | 
			
		||||
 | 
			
		||||
  asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
 | 
			
		||||
 | 
			
		||||
  asia.add(XRay | Either = "95/5 2/98");
 | 
			
		||||
  asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
 | 
			
		||||
 | 
			
		||||
  return asia;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(DiscreteBayesNet, bayesNet) {
 | 
			
		||||
  using ADT = AlgebraicDecisionTree<Key>;
 | 
			
		||||
  DiscreteBayesNet bayesNet;
 | 
			
		||||
  DiscreteKey Parent(0, 2), Child(1, 2);
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -86,11 +65,12 @@ TEST(DiscreteBayesNet, bayesNet) {
 | 
			
		|||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(DiscreteBayesNet, Asia) {
 | 
			
		||||
  DiscreteBayesNet asia = constructAsiaExample();
 | 
			
		||||
  using namespace asia_example;
 | 
			
		||||
  const DiscreteBayesNet asia = createAsiaExample();
 | 
			
		||||
 | 
			
		||||
  // Convert to factor graph
 | 
			
		||||
  DiscreteFactorGraph fg(asia);
 | 
			
		||||
  LONGS_EQUAL(3, fg.back()->size());
 | 
			
		||||
  LONGS_EQUAL(1, fg.back()->size());
 | 
			
		||||
 | 
			
		||||
  // Check the marginals we know (of the parent-less nodes)
 | 
			
		||||
  DiscreteMarginals marginals(fg);
 | 
			
		||||
| 
						 | 
				
			
			@ -99,7 +79,7 @@ TEST(DiscreteBayesNet, Asia) {
 | 
			
		|||
  EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking)));
 | 
			
		||||
 | 
			
		||||
  // Create solver and eliminate
 | 
			
		||||
  const Ordering ordering{0, 1, 2, 3, 4, 5, 6, 7};
 | 
			
		||||
  const Ordering ordering{A, D, T, X, S, E, L, B};
 | 
			
		||||
  DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
 | 
			
		||||
  DiscreteConditional expected2(Bronchitis % "11/9");
 | 
			
		||||
  EXPECT(assert_equal(expected2, *chordal->back()));
 | 
			
		||||
| 
						 | 
				
			
			@ -144,55 +124,50 @@ TEST(DiscreteBayesNet, Sugar) {
 | 
			
		|||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(DiscreteBayesNet, Dot) {
 | 
			
		||||
  DiscreteBayesNet fragment;
 | 
			
		||||
  fragment.add(Asia % "99/1");
 | 
			
		||||
  fragment.add(Smoking % "50/50");
 | 
			
		||||
  using namespace asia_example;
 | 
			
		||||
  const DiscreteBayesNet fragment = createFragment();
 | 
			
		||||
 | 
			
		||||
  fragment.add(Tuberculosis | Asia = "99/1 95/5");
 | 
			
		||||
  fragment.add(LungCancer | Smoking = "99/1 90/10");
 | 
			
		||||
  fragment.add((Either | Tuberculosis, LungCancer) = "F T T T");
 | 
			
		||||
 | 
			
		||||
  string actual = fragment.dot();
 | 
			
		||||
  EXPECT(actual ==
 | 
			
		||||
         "digraph {\n"
 | 
			
		||||
         "  size=\"5,5\";\n"
 | 
			
		||||
         "\n"
 | 
			
		||||
         "  var0[label=\"0\"];\n"
 | 
			
		||||
         "  var3[label=\"3\"];\n"
 | 
			
		||||
         "  var4[label=\"4\"];\n"
 | 
			
		||||
         "  var5[label=\"5\"];\n"
 | 
			
		||||
         "  var6[label=\"6\"];\n"
 | 
			
		||||
         "\n"
 | 
			
		||||
         "  var3->var5\n"
 | 
			
		||||
         "  var6->var5\n"
 | 
			
		||||
         "  var4->var6\n"
 | 
			
		||||
         "  var0->var3\n"
 | 
			
		||||
         "}");
 | 
			
		||||
  std::string expected =
 | 
			
		||||
      "digraph {\n"
 | 
			
		||||
      "  size=\"5,5\";\n"
 | 
			
		||||
      "\n"
 | 
			
		||||
      "  var4683743612465315848[label=\"A8\"];\n"
 | 
			
		||||
      "  var4971973988617027587[label=\"E3\"];\n"
 | 
			
		||||
      "  var5476377146882523141[label=\"L5\"];\n"
 | 
			
		||||
      "  var5980780305148018695[label=\"S7\"];\n"
 | 
			
		||||
      "  var6052837899185946630[label=\"T6\"];\n"
 | 
			
		||||
      "\n"
 | 
			
		||||
      "  var4683743612465315848->var6052837899185946630\n"
 | 
			
		||||
      "  var5980780305148018695->var5476377146882523141\n"
 | 
			
		||||
      "  var6052837899185946630->var4971973988617027587\n"
 | 
			
		||||
      "  var5476377146882523141->var4971973988617027587\n"
 | 
			
		||||
      "}";
 | 
			
		||||
  std::string actual = fragment.dot();
 | 
			
		||||
  EXPECT(actual.compare(expected) == 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
// Check markdown representation looks as expected.
 | 
			
		||||
TEST(DiscreteBayesNet, markdown) {
 | 
			
		||||
  DiscreteBayesNet fragment;
 | 
			
		||||
  fragment.add(Asia % "99/1");
 | 
			
		||||
  fragment.add(Smoking | Asia = "8/2 7/3");
 | 
			
		||||
  using namespace asia_example;
 | 
			
		||||
  DiscreteBayesNet priors = createPriors();
 | 
			
		||||
 | 
			
		||||
  string expected =
 | 
			
		||||
  std::string expected =
 | 
			
		||||
      "`DiscreteBayesNet` of size 2\n"
 | 
			
		||||
      "\n"
 | 
			
		||||
      " *P(Smoking):*\n\n"
 | 
			
		||||
      "|Smoking|value|\n"
 | 
			
		||||
      "|:-:|:-:|\n"
 | 
			
		||||
      "|0|0.5|\n"
 | 
			
		||||
      "|1|0.5|\n"
 | 
			
		||||
      "\n"
 | 
			
		||||
      " *P(Asia):*\n\n"
 | 
			
		||||
      "|Asia|value|\n"
 | 
			
		||||
      "|:-:|:-:|\n"
 | 
			
		||||
      "|0|0.99|\n"
 | 
			
		||||
      "|1|0.01|\n"
 | 
			
		||||
      "\n"
 | 
			
		||||
      " *P(Smoking|Asia):*\n\n"
 | 
			
		||||
      "|*Asia*|0|1|\n"
 | 
			
		||||
      "|:-:|:-:|:-:|\n"
 | 
			
		||||
      "|0|0.8|0.2|\n"
 | 
			
		||||
      "|1|0.7|0.3|\n\n";
 | 
			
		||||
  auto formatter = [](Key key) { return key == 0 ? "Asia" : "Smoking"; };
 | 
			
		||||
  string actual = fragment.markdown(formatter);
 | 
			
		||||
      "|1|0.01|\n\n";
 | 
			
		||||
  auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; };
 | 
			
		||||
  std::string actual = priors.markdown(formatter);
 | 
			
		||||
  EXPECT(actual == expected);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,111 @@
 | 
			
		|||
/* ----------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 | 
			
		||||
 * Atlanta, Georgia 30332-0415
 | 
			
		||||
 * All Rights Reserved
 | 
			
		||||
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 | 
			
		||||
 | 
			
		||||
 * See LICENSE for the license information
 | 
			
		||||
 | 
			
		||||
 * -------------------------------------------------------------------------- */
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
 * testDiscreteSearch.cpp
 | 
			
		||||
 *
 | 
			
		||||
 *  @date January, 2025
 | 
			
		||||
 *  @author Frank Dellaert
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
#include <CppUnitLite/TestHarness.h>
 | 
			
		||||
#include <gtsam/base/Testable.h>
 | 
			
		||||
#include <gtsam/discrete/DiscreteSearch.h>
 | 
			
		||||
 | 
			
		||||
#include "AsiaExample.h"
 | 
			
		||||
 | 
			
		||||
using namespace gtsam;
 | 
			
		||||
 | 
			
		||||
// Create Asia Bayes net, FG, and Bayes tree once
 | 
			
		||||
namespace asia {
 | 
			
		||||
using namespace asia_example;
 | 
			
		||||
static const DiscreteBayesNet bayesNet = createAsiaExample();
 | 
			
		||||
static const DiscreteFactorGraph factorGraph(bayesNet);
 | 
			
		||||
static const DiscreteValues mpe = factorGraph.optimize();
 | 
			
		||||
static const Ordering ordering{D, X, B, E, L, T, S, A};
 | 
			
		||||
static const DiscreteBayesTree bayesTree =
 | 
			
		||||
    *factorGraph.eliminateMultifrontal(ordering);
 | 
			
		||||
}  // namespace asia
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(DiscreteBayesNet, EmptyKBest) {
 | 
			
		||||
  DiscreteBayesNet net;  // no factors
 | 
			
		||||
  DiscreteSearch search(net);
 | 
			
		||||
  auto solutions = search.run(3);
 | 
			
		||||
  // Expect one solution with empty assignment, error=0
 | 
			
		||||
  EXPECT_LONGS_EQUAL(1, solutions.size());
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(DiscreteBayesNet, AsiaKBest) {
 | 
			
		||||
  const DiscreteSearch search(asia::bayesNet);
 | 
			
		||||
 | 
			
		||||
  // Ask for the MPE
 | 
			
		||||
  auto mpe = search.run();
 | 
			
		||||
 | 
			
		||||
  EXPECT_LONGS_EQUAL(1, mpe.size());
 | 
			
		||||
  // Regression test: check the MPE solution
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
 | 
			
		||||
 | 
			
		||||
  // Check it is equal to MPE via inference
 | 
			
		||||
  EXPECT(assert_equal(asia::mpe, mpe[0].assignment));
 | 
			
		||||
 | 
			
		||||
  // Ask for top 4 solutions
 | 
			
		||||
  auto solutions = search.run(4);
 | 
			
		||||
 | 
			
		||||
  EXPECT_LONGS_EQUAL(4, solutions.size());
 | 
			
		||||
  // Regression test: check the first and last solution
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(DiscreteBayesTree, EmptyTree) {
 | 
			
		||||
  DiscreteBayesTree bt;
 | 
			
		||||
 | 
			
		||||
  DiscreteSearch search(bt);
 | 
			
		||||
  auto solutions = search.run(3);
 | 
			
		||||
 | 
			
		||||
  // We expect exactly 1 solution with error = 0.0 (the empty assignment).
 | 
			
		||||
  EXPECT_LONGS_EQUAL(1, solutions.size());
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
TEST(DiscreteBayesTree, AsiaTreeKBest) {
 | 
			
		||||
  DiscreteSearch search(asia::bayesTree);
 | 
			
		||||
 | 
			
		||||
  // Ask for MPE
 | 
			
		||||
  auto mpe = search.run();
 | 
			
		||||
 | 
			
		||||
  EXPECT_LONGS_EQUAL(1, mpe.size());
 | 
			
		||||
  // Regression test: check the MPE solution
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(mpe[0].error), 1e-5);
 | 
			
		||||
 | 
			
		||||
  // Check it is equal to MPE via inference
 | 
			
		||||
  EXPECT(assert_equal(asia::mpe, mpe[0].assignment));
 | 
			
		||||
 | 
			
		||||
  // Ask for top 4 solutions
 | 
			
		||||
  auto solutions = search.run(4);
 | 
			
		||||
 | 
			
		||||
  EXPECT_LONGS_EQUAL(4, solutions.size());
 | 
			
		||||
  // Regression test: check the first and last solution
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5);
 | 
			
		||||
  EXPECT_DOUBLES_EQUAL(2.201708, std::fabs(solutions[3].error), 1e-5);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
int main() {
 | 
			
		||||
  TestResult tr;
 | 
			
		||||
  return TestRegistry::runAllTests(tr);
 | 
			
		||||
}
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
		Loading…
	
		Reference in New Issue