small improvements
parent
d3901be1c1
commit
5fa04d7622
|
@ -24,13 +24,13 @@
|
|||
#include <gtsam/hybrid/HybridValues.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
|
||||
using namespace std;
|
||||
using std::pair;
|
||||
|
@ -45,9 +45,7 @@ template class GTSAM_EXPORT
|
|||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||
const DecisionTreeFactor& f)
|
||||
: BaseFactor(f / (*std::dynamic_pointer_cast<DecisionTreeFactor>(
|
||||
f.sum(nrFrontals)))),
|
||||
BaseConditional(nrFrontals) {}
|
||||
: BaseFactor(f / f.sum(nrFrontals)), BaseConditional(nrFrontals) {}
|
||||
|
||||
/* ************************************************************************** */
|
||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||
|
|
|
@ -128,7 +128,7 @@ namespace gtsam {
|
|||
auto denominator = product.max(product.size());
|
||||
|
||||
// Normalize the product factor to prevent underflow.
|
||||
product = product / (*denominator);
|
||||
product = product / denominator;
|
||||
|
||||
return product;
|
||||
}
|
||||
|
|
|
@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
|
|||
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||
|
||||
// Normalize newFactor by max for comparison with expected
|
||||
auto normalizer = newFactor.max(newFactor.size());
|
||||
auto denominator = newFactor.max(newFactor.size());
|
||||
|
||||
newFactor = newFactor / *normalizer;
|
||||
newFactor = newFactor / denominator;
|
||||
|
||||
// Check Conditional
|
||||
CHECK(conditional);
|
||||
|
@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) {
|
|||
CHECK(&newFactor);
|
||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||
// Normalize by max.
|
||||
normalizer = expectedFactor.max(expectedFactor.size());
|
||||
// Ensure normalizer is correct.
|
||||
expectedFactor = expectedFactor / *normalizer;
|
||||
denominator = expectedFactor.max(expectedFactor.size());
|
||||
// Ensure denominator is correct.
|
||||
expectedFactor = expectedFactor / denominator;
|
||||
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||
|
||||
// Test using elimination tree
|
||||
|
|
|
@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
|
|||
TEST(TableFactor, Empty) {
|
||||
DiscreteKey X(1, 2);
|
||||
|
||||
TableFactor single = *TableFactor({X}, "1 1").sum(1);
|
||||
auto single = TableFactor({X}, "1 1").sum(1);
|
||||
// Should not throw a segfault
|
||||
EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
|
||||
single.toDecisionTreeFactor()));
|
||||
auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
|
||||
EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
|
||||
single->toDecisionTreeFactor()));
|
||||
|
||||
TableFactor empty = *TableFactor({X}, "0 0").sum(1);
|
||||
auto empty = TableFactor({X}, "0 0").sum(1);
|
||||
// Should not throw a segfault
|
||||
EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
|
||||
empty.toDecisionTreeFactor()));
|
||||
auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
|
||||
EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
|
||||
empty->toDecisionTreeFactor()));
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
Loading…
Reference in New Issue