small improvements
parent
d3901be1c1
commit
5fa04d7622
|
@ -24,13 +24,13 @@
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cassert>
|
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using std::pair;
|
using std::pair;
|
||||||
|
@ -45,9 +45,7 @@ template class GTSAM_EXPORT
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
|
||||||
const DecisionTreeFactor& f)
|
const DecisionTreeFactor& f)
|
||||||
: BaseFactor(f / (*std::dynamic_pointer_cast<DecisionTreeFactor>(
|
: BaseFactor(f / f.sum(nrFrontals)), BaseConditional(nrFrontals) {}
|
||||||
f.sum(nrFrontals)))),
|
|
||||||
BaseConditional(nrFrontals) {}
|
|
||||||
|
|
||||||
/* ************************************************************************** */
|
/* ************************************************************************** */
|
||||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
|
||||||
|
|
|
@ -128,7 +128,7 @@ namespace gtsam {
|
||||||
auto denominator = product.max(product.size());
|
auto denominator = product.max(product.size());
|
||||||
|
|
||||||
// Normalize the product factor to prevent underflow.
|
// Normalize the product factor to prevent underflow.
|
||||||
product = product / (*denominator);
|
product = product / denominator;
|
||||||
|
|
||||||
return product;
|
return product;
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
|
||||||
|
|
||||||
// Normalize newFactor by max for comparison with expected
|
// Normalize newFactor by max for comparison with expected
|
||||||
auto normalizer = newFactor.max(newFactor.size());
|
auto denominator = newFactor.max(newFactor.size());
|
||||||
|
|
||||||
newFactor = newFactor / *normalizer;
|
newFactor = newFactor / denominator;
|
||||||
|
|
||||||
// Check Conditional
|
// Check Conditional
|
||||||
CHECK(conditional);
|
CHECK(conditional);
|
||||||
|
@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) {
|
||||||
CHECK(&newFactor);
|
CHECK(&newFactor);
|
||||||
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
|
||||||
// Normalize by max.
|
// Normalize by max.
|
||||||
normalizer = expectedFactor.max(expectedFactor.size());
|
denominator = expectedFactor.max(expectedFactor.size());
|
||||||
// Ensure normalizer is correct.
|
// Ensure denominator is correct.
|
||||||
expectedFactor = expectedFactor / *normalizer;
|
expectedFactor = expectedFactor / denominator;
|
||||||
EXPECT(assert_equal(expectedFactor, newFactor));
|
EXPECT(assert_equal(expectedFactor, newFactor));
|
||||||
|
|
||||||
// Test using elimination tree
|
// Test using elimination tree
|
||||||
|
|
|
@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
|
||||||
TEST(TableFactor, Empty) {
|
TEST(TableFactor, Empty) {
|
||||||
DiscreteKey X(1, 2);
|
DiscreteKey X(1, 2);
|
||||||
|
|
||||||
TableFactor single = *TableFactor({X}, "1 1").sum(1);
|
auto single = TableFactor({X}, "1 1").sum(1);
|
||||||
// Should not throw a segfault
|
// Should not throw a segfault
|
||||||
EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
|
auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
|
||||||
single.toDecisionTreeFactor()));
|
EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
|
||||||
|
single->toDecisionTreeFactor()));
|
||||||
|
|
||||||
TableFactor empty = *TableFactor({X}, "0 0").sum(1);
|
auto empty = TableFactor({X}, "0 0").sum(1);
|
||||||
// Should not throw a segfault
|
// Should not throw a segfault
|
||||||
EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
|
auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
|
||||||
empty.toDecisionTreeFactor()));
|
EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
|
||||||
|
empty->toDecisionTreeFactor()));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************* */
|
/* ************************************************************************* */
|
||||||
|
|
Loading…
Reference in New Issue