small improvements

release/4.3a0
Varun Agrawal 2025-01-05 09:08:57 -05:00
parent d3901be1c1
commit 5fa04d7622
4 changed files with 16 additions and 16 deletions

View File

@ -24,13 +24,13 @@
#include <gtsam/hybrid/HybridValues.h>
#include <algorithm>
#include <cassert>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <cassert>
using namespace std;
using std::pair;
@ -45,9 +45,7 @@ template class GTSAM_EXPORT
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
: BaseFactor(f / (*std::dynamic_pointer_cast<DecisionTreeFactor>(
f.sum(nrFrontals)))),
BaseConditional(nrFrontals) {}
: BaseFactor(f / f.sum(nrFrontals)), BaseConditional(nrFrontals) {}
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,

View File

@ -128,7 +128,7 @@ namespace gtsam {
auto denominator = product.max(product.size());
// Normalize the product factor to prevent underflow.
product = product / (*denominator);
product = product / denominator;
return product;
}

View File

@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
*std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
// Normalize newFactor by max for comparison with expected
auto normalizer = newFactor.max(newFactor.size());
auto denominator = newFactor.max(newFactor.size());
newFactor = newFactor / *normalizer;
newFactor = newFactor / denominator;
// Check Conditional
CHECK(conditional);
@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) {
CHECK(&newFactor);
DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
// Normalize by max.
normalizer = expectedFactor.max(expectedFactor.size());
// Ensure normalizer is correct.
expectedFactor = expectedFactor / *normalizer;
denominator = expectedFactor.max(expectedFactor.size());
// Ensure denominator is correct.
expectedFactor = expectedFactor / denominator;
EXPECT(assert_equal(expectedFactor, newFactor));
// Test using elimination tree

View File

@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
TEST(TableFactor, Empty) {
DiscreteKey X(1, 2);
TableFactor single = *TableFactor({X}, "1 1").sum(1);
auto single = TableFactor({X}, "1 1").sum(1);
// Should not throw a segfault
EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
single.toDecisionTreeFactor()));
auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
single->toDecisionTreeFactor()));
TableFactor empty = *TableFactor({X}, "0 0").sum(1);
auto empty = TableFactor({X}, "0 0").sum(1);
// Should not throw a segfault
EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
empty.toDecisionTreeFactor()));
auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
empty->toDecisionTreeFactor()));
}
/* ************************************************************************* */