Asia example
							parent
							
								
									1f4d9bbd7e
								
							
						
					
					
						commit
						d879b156f8
					
				|  | @ -0,0 +1,61 @@ | ||||||
|  | /* ----------------------------------------------------------------------------
 | ||||||
|  | 
 | ||||||
|  |  * GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||||||
|  |  * Atlanta, Georgia 30332-0415 | ||||||
|  |  * All Rights Reserved | ||||||
|  |  * Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||||||
|  | 
 | ||||||
|  |  * See LICENSE for the license information | ||||||
|  | 
 | ||||||
|  |  * -------------------------------------------------------------------------- */ | ||||||
|  | 
 | ||||||
|  | /*
 | ||||||
|  |  * AsiaExample.h | ||||||
|  |  * | ||||||
|  |  *  @date Jan, 2025 | ||||||
|  |  *  @author Frank Dellaert | ||||||
|  |  */ | ||||||
|  | 
 | ||||||
|  | #include <gtsam/discrete/DiscreteBayesNet.h> | ||||||
|  | #include <gtsam/inference/Symbol.h> | ||||||
|  | 
 | ||||||
|  | namespace gtsam { | ||||||
|  | namespace asia_example { | ||||||
|  | 
 | ||||||
|  | static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), | ||||||
|  |                  B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), | ||||||
|  |                  S = Symbol('S', 7), A = Symbol('A', 8); | ||||||
|  | 
 | ||||||
|  | static const DiscreteKey Dyspnea(D, 2), XRay(X, 2), Either(E, 2), | ||||||
|  |     Bronchitis(B, 2), LungCancer(L, 2), Tuberculosis(T, 2), Smoking(S, 2), | ||||||
|  |     Asia(A, 2); | ||||||
|  | 
 | ||||||
|  | // Function to construct the incomplete Asia example
 | ||||||
|  | DiscreteBayesNet createPriors() { | ||||||
|  |   DiscreteBayesNet priors; | ||||||
|  |   priors.add(Smoking % "50/50"); | ||||||
|  |   priors.add(Asia, "99/1"); | ||||||
|  |   return priors; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Function to construct the incomplete Asia example
 | ||||||
|  | DiscreteBayesNet createFragment() { | ||||||
|  |   DiscreteBayesNet fragment; | ||||||
|  |   fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); | ||||||
|  |   fragment.add(LungCancer | Smoking = "99/1 90/10"); | ||||||
|  |   fragment.add(Tuberculosis | Asia = "99/1 95/5"); | ||||||
|  |   for (const auto& factor : createPriors()) fragment.push_back(factor); | ||||||
|  |   return fragment; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // Function to construct the Asia example
 | ||||||
|  | DiscreteBayesNet createAsiaExample() { | ||||||
|  |   DiscreteBayesNet asia; | ||||||
|  |   asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); | ||||||
|  |   asia.add(XRay | Either = "95/5 2/98"); | ||||||
|  |   asia.add(Bronchitis | Smoking = "70/30 40/60"); | ||||||
|  |   for (const auto& factor : createFragment()) asia.push_back(factor); | ||||||
|  |   return asia; | ||||||
|  | } | ||||||
|  | }  // namespace asia_example
 | ||||||
|  | }  // namespace gtsam
 | ||||||
|  | @ -29,40 +29,13 @@ | ||||||
| #include <string> | #include <string> | ||||||
| #include <vector> | #include <vector> | ||||||
| 
 | 
 | ||||||
| using namespace std; | #include "AsiaExample.h" | ||||||
|  | 
 | ||||||
| using namespace gtsam; | using namespace gtsam; | ||||||
| 
 | 
 | ||||||
| namespace keys { |  | ||||||
| static const Key D = Symbol('D', 1), X = Symbol('X', 2), E = Symbol('E', 3), |  | ||||||
|                  B = Symbol('B', 4), L = Symbol('L', 5), T = Symbol('T', 6), |  | ||||||
|                  S = Symbol('S', 7), A = Symbol('A', 8); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| static const DiscreteKey Dyspnea(keys::D, 2), XRay(keys::X, 2), |  | ||||||
|     Either(keys::E, 2), Bronchitis(keys::B, 2), LungCancer(keys::L, 2), |  | ||||||
|     Tuberculosis(keys::T, 2), Smoking(keys::S, 2), Asia(keys::A, 2); |  | ||||||
| 
 |  | ||||||
| using ADT = AlgebraicDecisionTree<Key>; |  | ||||||
| 
 |  | ||||||
| // Function to construct the Asia example
 |  | ||||||
| DiscreteBayesNet constructAsiaExample() { |  | ||||||
|   DiscreteBayesNet asia; |  | ||||||
| 
 |  | ||||||
|   // Add in topological sort order, parents last:
 |  | ||||||
|   asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9"); |  | ||||||
|   asia.add(XRay | Either = "95/5 2/98"); |  | ||||||
|   asia.add((Either | Tuberculosis, LungCancer) = "F T T T"); |  | ||||||
|   asia.add(Bronchitis | Smoking = "70/30 40/60"); |  | ||||||
|   asia.add(LungCancer | Smoking = "99/1 90/10"); |  | ||||||
|   asia.add(Tuberculosis | Asia = "99/1 95/5"); |  | ||||||
|   asia.add(Smoking % "50/50");  // Signature version
 |  | ||||||
|   asia.add(Asia, "99/1"); |  | ||||||
| 
 |  | ||||||
|   return asia; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| TEST(DiscreteBayesNet, bayesNet) { | TEST(DiscreteBayesNet, bayesNet) { | ||||||
|  |   using ADT = AlgebraicDecisionTree<Key>; | ||||||
|   DiscreteBayesNet bayesNet; |   DiscreteBayesNet bayesNet; | ||||||
|   DiscreteKey Parent(0, 2), Child(1, 2); |   DiscreteKey Parent(0, 2), Child(1, 2); | ||||||
| 
 | 
 | ||||||
|  | @ -92,7 +65,8 @@ TEST(DiscreteBayesNet, bayesNet) { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| TEST(DiscreteBayesNet, Asia) { | TEST(DiscreteBayesNet, Asia) { | ||||||
|   DiscreteBayesNet asia = constructAsiaExample(); |   using namespace asia_example; | ||||||
|  |   const DiscreteBayesNet asia = createAsiaExample(); | ||||||
| 
 | 
 | ||||||
|   // Convert to factor graph
 |   // Convert to factor graph
 | ||||||
|   DiscreteFactorGraph fg(asia); |   DiscreteFactorGraph fg(asia); | ||||||
|  | @ -105,8 +79,7 @@ TEST(DiscreteBayesNet, Asia) { | ||||||
|   EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); |   EXPECT(assert_equal(vs, marginals.marginalProbabilities(Smoking))); | ||||||
| 
 | 
 | ||||||
|   // Create solver and eliminate
 |   // Create solver and eliminate
 | ||||||
|   const Ordering ordering{keys::A, keys::D, keys::T, keys::X, |   const Ordering ordering{A, D, T, X, S, E, L, B}; | ||||||
|                           keys::S, keys::E, keys::L, keys::B}; |  | ||||||
|   DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); |   DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering); | ||||||
|   DiscreteConditional expected2(Bronchitis % "11/9"); |   DiscreteConditional expected2(Bronchitis % "11/9"); | ||||||
|   EXPECT(assert_equal(expected2, *chordal->back())); |   EXPECT(assert_equal(expected2, *chordal->back())); | ||||||
|  | @ -151,16 +124,10 @@ TEST(DiscreteBayesNet, Sugar) { | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| TEST(DiscreteBayesNet, Dot) { | TEST(DiscreteBayesNet, Dot) { | ||||||
|   DiscreteBayesNet fragment; |   using namespace asia_example; | ||||||
|   fragment.add(Asia % "99/1"); |   const DiscreteBayesNet fragment = createFragment(); | ||||||
|   fragment.add(Smoking % "50/50"); |  | ||||||
| 
 | 
 | ||||||
|   fragment.add(Tuberculosis | Asia = "99/1 95/5"); |   std::string expected = | ||||||
|   fragment.add(LungCancer | Smoking = "99/1 90/10"); |  | ||||||
|   fragment.add((Either | Tuberculosis, LungCancer) = "F T T T"); |  | ||||||
| 
 |  | ||||||
|   string actual = fragment.dot(); |  | ||||||
|   EXPECT(actual == |  | ||||||
|       "digraph {\n" |       "digraph {\n" | ||||||
|       "  size=\"5,5\";\n" |       "  size=\"5,5\";\n" | ||||||
|       "\n" |       "\n" | ||||||
|  | @ -170,300 +137,40 @@ TEST(DiscreteBayesNet, Dot) { | ||||||
|       "  var5980780305148018695[label=\"S7\"];\n" |       "  var5980780305148018695[label=\"S7\"];\n" | ||||||
|       "  var6052837899185946630[label=\"T6\"];\n" |       "  var6052837899185946630[label=\"T6\"];\n" | ||||||
|       "\n" |       "\n" | ||||||
|  |       "  var4683743612465315848->var6052837899185946630\n" | ||||||
|  |       "  var5980780305148018695->var5476377146882523141\n" | ||||||
|       "  var6052837899185946630->var4971973988617027587\n" |       "  var6052837899185946630->var4971973988617027587\n" | ||||||
|       "  var5476377146882523141->var4971973988617027587\n" |       "  var5476377146882523141->var4971973988617027587\n" | ||||||
|          "  var5980780305148018695->var5476377146882523141\n" |       "}"; | ||||||
|          "  var4683743612465315848->var6052837899185946630\n" |   std::string actual = fragment.dot(); | ||||||
|          "}"); |   EXPECT(actual.compare(expected) == 0); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| // Check markdown representation looks as expected.
 | // Check markdown representation looks as expected.
 | ||||||
| TEST(DiscreteBayesNet, markdown) { | TEST(DiscreteBayesNet, markdown) { | ||||||
|   DiscreteBayesNet fragment; |   using namespace asia_example; | ||||||
|   fragment.add(Asia % "99/1"); |   DiscreteBayesNet priors = createPriors(); | ||||||
|   fragment.add(Smoking | Asia = "8/2 7/3"); |  | ||||||
| 
 | 
 | ||||||
|   string expected = |   std::string expected = | ||||||
|       "`DiscreteBayesNet` of size 2\n" |       "`DiscreteBayesNet` of size 2\n" | ||||||
|       "\n" |       "\n" | ||||||
|  |       " *P(Smoking):*\n\n" | ||||||
|  |       "|Smoking|value|\n" | ||||||
|  |       "|:-:|:-:|\n" | ||||||
|  |       "|0|0.5|\n" | ||||||
|  |       "|1|0.5|\n" | ||||||
|  |       "\n" | ||||||
|       " *P(Asia):*\n\n" |       " *P(Asia):*\n\n" | ||||||
|       "|Asia|value|\n" |       "|Asia|value|\n" | ||||||
|       "|:-:|:-:|\n" |       "|:-:|:-:|\n" | ||||||
|       "|0|0.99|\n" |       "|0|0.99|\n" | ||||||
|       "|1|0.01|\n" |       "|1|0.01|\n\n"; | ||||||
|       "\n" |   auto formatter = [](Key key) { return key == A ? "Asia" : "Smoking"; }; | ||||||
|       " *P(Smoking|Asia):*\n\n" |   std::string actual = priors.markdown(formatter); | ||||||
|       "|*Asia*|0|1|\n" |  | ||||||
|       "|:-:|:-:|:-:|\n" |  | ||||||
|       "|0|0.8|0.2|\n" |  | ||||||
|       "|1|0.7|0.3|\n\n"; |  | ||||||
|   auto formatter = [](Key key) { return key == keys::A ? "Asia" : "Smoking"; }; |  | ||||||
|   string actual = fragment.markdown(formatter); |  | ||||||
|   EXPECT(actual == expected); |   EXPECT(actual == expected); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /* ************************************************************************* */ |  | ||||||
| #include <algorithm> |  | ||||||
| #include <cmath> |  | ||||||
| #include <iostream> |  | ||||||
| #include <map> |  | ||||||
| #include <queue> |  | ||||||
| #include <vector> |  | ||||||
| 
 |  | ||||||
| using Value = size_t; |  | ||||||
| 
 |  | ||||||
| // ----------------------------------------------------------------------------
 |  | ||||||
| // 1) SearchNode: store partial assignment and next factor to expand
 |  | ||||||
| // ----------------------------------------------------------------------------
 |  | ||||||
| struct SearchNode { |  | ||||||
|   DiscreteValues assignment; |  | ||||||
|   double error; |  | ||||||
|   double bound; |  | ||||||
|   int nextConditional;  // index into conditionals
 |  | ||||||
| 
 |  | ||||||
|   /// if nextConditional < 0, we've assigned everything.
 |  | ||||||
|   bool isComplete() const { return nextConditional < 0; } |  | ||||||
| 
 |  | ||||||
|   /// lower bound on final error for unassigned variables. Stub=0.
 |  | ||||||
|   double computeBound() const { |  | ||||||
|     // Real code might do partial factor analysis or heuristics.
 |  | ||||||
|     return 0.0; |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Expand this node by assigning the next variable
 |  | ||||||
|   SearchNode expand(const DiscreteConditional& conditional, |  | ||||||
|                     const DiscreteValues& fa) const { |  | ||||||
|     // Combine the new frontal assignment with the current partial assignment
 |  | ||||||
|     SearchNode child; |  | ||||||
|     child.assignment = assignment; |  | ||||||
|     for (auto& kv : fa) { |  | ||||||
|       child.assignment[kv.first] = kv.second; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Compute the incremental error for this factor
 |  | ||||||
|     child.error = error + conditional.error(child.assignment); |  | ||||||
| 
 |  | ||||||
|     // Compute new bound
 |  | ||||||
|     child.bound = child.error + computeBound(); |  | ||||||
| 
 |  | ||||||
|     // Next factor index
 |  | ||||||
|     child.nextConditional = nextConditional - 1; |  | ||||||
| 
 |  | ||||||
|     return child; |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   friend std::ostream& operator<<(std::ostream& os, const SearchNode& sn) { |  | ||||||
|     os << "[ error=" << sn.error << " bound=" << sn.bound |  | ||||||
|        << " nextConditional=" << sn.nextConditional << " assignment={" |  | ||||||
|        << sn.assignment << "}]"; |  | ||||||
|     return os; |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| // ----------------------------------------------------------------------------
 |  | ||||||
| // 2) Priority functor to make a min-heap by bound
 |  | ||||||
| // ----------------------------------------------------------------------------
 |  | ||||||
| struct CompareByBound { |  | ||||||
|   bool operator()(const SearchNode& a, const SearchNode& b) const { |  | ||||||
|     return a.bound > b.bound;  // smallest bound -> highest priority
 |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| // ----------------------------------------------------------------------------
 |  | ||||||
| // 4) A Solution
 |  | ||||||
| // ----------------------------------------------------------------------------
 |  | ||||||
| struct Solution { |  | ||||||
|   double error; |  | ||||||
|   DiscreteValues assignment; |  | ||||||
|   Solution(double err, const DiscreteValues& assign) |  | ||||||
|       : error(err), assignment(assign) {} |  | ||||||
|   friend std::ostream& operator<<(std::ostream& os, const Solution& sn) { |  | ||||||
|     os << "[ error=" << sn.error << " assignment={" << sn.assignment << "}]"; |  | ||||||
|     return os; |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| struct CompareByError { |  | ||||||
|   bool operator()(const Solution& a, const Solution& b) const { |  | ||||||
|     return a.error < b.error; |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| // Define the Solutions class
 |  | ||||||
| class Solutions { |  | ||||||
|  private: |  | ||||||
|   size_t maxSize_; |  | ||||||
|   std::priority_queue<Solution, std::vector<Solution>, CompareByError> pq_; |  | ||||||
| 
 |  | ||||||
|  public: |  | ||||||
|   Solutions(size_t maxSize) : maxSize_(maxSize) {} |  | ||||||
| 
 |  | ||||||
|   /// Add a solution to the priority queue, possibly evicting the worst one.
 |  | ||||||
|   /// Return true if we added the solution.
 |  | ||||||
|   bool maybeAdd(double error, const DiscreteValues& assignment) { |  | ||||||
|     const bool full = pq_.size() == maxSize_; |  | ||||||
|     if (full && error >= pq_.top().error) return false; |  | ||||||
|     if (full) pq_.pop(); |  | ||||||
|     pq_.emplace(error, assignment); |  | ||||||
|     return true; |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Check if we have any solutions
 |  | ||||||
|   bool empty() const { return pq_.empty(); } |  | ||||||
| 
 |  | ||||||
|   // Method to print all solutions
 |  | ||||||
|   void print() const { |  | ||||||
|     auto pq = pq_; |  | ||||||
|     while (!pq.empty()) { |  | ||||||
|       const Solution& best = pq.top(); |  | ||||||
|       std::cout << "Error: " << best.error << ", Values: " << best.assignment |  | ||||||
|                 << std::endl; |  | ||||||
|       pq.pop(); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Check if (partial) solution with given bound can be pruned. If we have
 |  | ||||||
|   /// room, we never prune. Otherwise, prune if lower bound on error is worse
 |  | ||||||
|   /// than our current worst error.
 |  | ||||||
|   bool prune(double bound) const { |  | ||||||
|     if (pq_.size() < maxSize_) return false; |  | ||||||
|     double worstError = pq_.top().error; |  | ||||||
|     return (bound >= worstError); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   // Method to extract solutions in ascending order of error
 |  | ||||||
|   std::vector<Solution> extractSolutions() { |  | ||||||
|     std::vector<Solution> result; |  | ||||||
|     while (!pq_.empty()) { |  | ||||||
|       result.push_back(pq_.top()); |  | ||||||
|       pq_.pop(); |  | ||||||
|     } |  | ||||||
|     std::sort( |  | ||||||
|         result.begin(), result.end(), |  | ||||||
|         [](const Solution& a, const Solution& b) { return a.error < b.error; }); |  | ||||||
|     return result; |  | ||||||
|   } |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| /**
 |  | ||||||
|  * BestKSearch: Search for the K best solutions. |  | ||||||
|  */ |  | ||||||
| class BestKSearch { |  | ||||||
|  public: |  | ||||||
|   /**
 |  | ||||||
|    * Construct from a DiscreteBayesNet and K. |  | ||||||
|    */ |  | ||||||
|   BestKSearch(const DiscreteBayesNet& bayesNet, size_t K) |  | ||||||
|       : bayesNet_(bayesNet), solutions_(K) { |  | ||||||
|     // Copy out the conditionals
 |  | ||||||
|     for (auto& factor : bayesNet_) { |  | ||||||
|       conditionals_.push_back(factor); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Create the root node: no variables assigned, nextConditional = last.
 |  | ||||||
|     SearchNode root{ |  | ||||||
|         .assignment = DiscreteValues(), |  | ||||||
|         .error = 0.0, |  | ||||||
|         .nextConditional = static_cast<int>(conditionals_.size()) - 1}; |  | ||||||
|     root.bound = root.computeBound(); |  | ||||||
|     std::cout << "Root: " << root << std::endl; |  | ||||||
|     expansions_.push(root); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /**
 |  | ||||||
|    * @brief Search for the K best solutions. |  | ||||||
|    * |  | ||||||
|    * This method performs a search to find the K best solutions for the given |  | ||||||
|    * DiscreteBayesNet. It uses a priority queue to manage the search nodes, |  | ||||||
|    * expanding nodes with the smallest bound first. The search continues until |  | ||||||
|    * all possible nodes have been expanded or pruned. |  | ||||||
|    * |  | ||||||
|    * @return A vector of the K best solutions found during the search. |  | ||||||
|    */ |  | ||||||
|   std::vector<Solution> run() { |  | ||||||
|     size_t numExpansions = 0; |  | ||||||
|     while (!expansions_.empty()) { |  | ||||||
|       expandNextNode(); |  | ||||||
|       numExpansions++; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     std::cout << "Expansions: " << numExpansions << std::endl; |  | ||||||
| 
 |  | ||||||
|     // Extract solutions from bestSolutions in ascending order of error
 |  | ||||||
|     return solutions_.extractSolutions(); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|  private: |  | ||||||
|   //
 |  | ||||||
|   void expandNextNode() { |  | ||||||
|     // Pop the partial assignment with the smallest bound
 |  | ||||||
|     SearchNode current = expansions_.top(); |  | ||||||
|     expansions_.pop(); |  | ||||||
|     std::cout << "Expanding: " << current << std::endl; |  | ||||||
| 
 |  | ||||||
|     // If we already have K solutions, prune if we cannot beat the worst one.
 |  | ||||||
|     if (solutions_.prune(current.bound)) { |  | ||||||
|       std::cout << "Pruning: bound=" << current.bound << std::endl; |  | ||||||
|       return; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Check if we have a complete assignment
 |  | ||||||
|     if (current.isComplete()) { |  | ||||||
|       const bool added = solutions_.maybeAdd(current.error, current.assignment); |  | ||||||
|       if (added) { |  | ||||||
|         std::cout << "Best solutions so far:" << std::endl; |  | ||||||
|         solutions_.print(); |  | ||||||
|       } |  | ||||||
|       return; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Expand on the next factor
 |  | ||||||
|     const auto& conditional = conditionals_[current.nextConditional]; |  | ||||||
| 
 |  | ||||||
|     for (auto& fa : conditional->frontalAssignments()) { |  | ||||||
|       std::cout << "Frontal assignment: " << fa << std::endl; |  | ||||||
|       auto childNode = current.expand(*conditional, fa); |  | ||||||
| 
 |  | ||||||
|       // Again, prune if we cannot beat the worst solution
 |  | ||||||
|       if (solutions_.prune(current.bound)) { |  | ||||||
|         std::cout << "Pruning: bound=" << childNode.bound << std::endl; |  | ||||||
|         continue; |  | ||||||
|       } |  | ||||||
| 
 |  | ||||||
|       expansions_.push(childNode); |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   const DiscreteBayesNet& bayesNet_; |  | ||||||
|   std::vector<std::shared_ptr<DiscreteConditional>> conditionals_; |  | ||||||
|   std::priority_queue<SearchNode, std::vector<SearchNode>, CompareByBound> |  | ||||||
|       expansions_; |  | ||||||
|   Solutions solutions_; |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| // ----------------------------------------------------------------------------
 |  | ||||||
| // Example “Unit Tests” (trivial stubs)
 |  | ||||||
| // ----------------------------------------------------------------------------
 |  | ||||||
| 
 |  | ||||||
| TEST(DiscreteBayesNet, EmptyKBest) { |  | ||||||
|   DiscreteBayesNet net;  // no factors
 |  | ||||||
|   BestKSearch search(net, 3); |  | ||||||
|   auto solutions = search.run(); |  | ||||||
|   // Expect one solution with empty assignment, error=0
 |  | ||||||
|   EXPECT_LONGS_EQUAL(1, solutions.size()); |  | ||||||
|   EXPECT_DOUBLES_EQUAL(0, std::fabs(solutions[0].error), 1e-9); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| TEST(DiscreteBayesNet, AsiaKBest) { |  | ||||||
|   DiscreteBayesNet asia = constructAsiaExample(); |  | ||||||
|   BestKSearch search(asia, 4); |  | ||||||
|   auto solutions = search.run(); |  | ||||||
|   EXPECT(!solutions.empty()); |  | ||||||
|   // Regression test: check the first solution
 |  | ||||||
|   EXPECT_DOUBLES_EQUAL(1.236627, std::fabs(solutions[0].error), 1e-5); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /* ************************************************************************* */ | /* ************************************************************************* */ | ||||||
| int main() { | int main() { | ||||||
|   TestResult tr; |   TestResult tr; | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue