Merge pull request #1867 from borglab/normalize-potentials
commit 9a146eb942
@@ -182,6 +182,21 @@ namespace gtsam {
     this->root_ = DecisionTree<L, double>::convertFrom(other.root_, L_of_M, op);
   }
 
+  /**
+   * @brief Create from an arbitrary DecisionTree<L, X> by operating on it
+   * with a functional `f`.
+   *
+   * @tparam X The type of the leaf of the original DecisionTree
+   * @tparam Func Type signature of functional `f`.
+   * @param other The original DecisionTree from which the
+   * AlgebraicDecisionTree is constructed.
+   * @param f Functional used to operate on
+   * the leaves of the input DecisionTree.
+   */
+  template <typename X, typename Func>
+  AlgebraicDecisionTree(const DecisionTree<L, X>& other, Func f)
+      : Base(other, f) {}
+
   /** sum */
   AlgebraicDecisionTree operator+(const AlgebraicDecisionTree& g) const {
     return this->apply(g, &Ring::add);
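For context, a minimal standalone sketch of how the new functional constructor could be used. The key, leaf type, and values below are invented for illustration and are not part of the PR:

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Key.h>
#include <iostream>

using namespace gtsam;

int main() {
  // A DecisionTree over one (binary) key with integer leaves 2 and 8.
  DecisionTree<Key, int> counts(Key(0), 2, 8);

  // New constructor: apply the functional f to every leaf while converting
  // the leaf type from int to double.
  AlgebraicDecisionTree<Key> scaled(
      counts, [](int n) { return static_cast<double>(n) / 10.0; });

  Assignment<Key> choice;
  choice[Key(0)] = 1;
  std::cout << scaled(choice) << std::endl;  // expected: 0.8
  return 0;
}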
@@ -219,12 +234,9 @@ namespace gtsam {
    * @brief Helper method to perform normalization such that all leaves in the
    * tree sum to 1
    *
-   * @param sum
    * @return AlgebraicDecisionTree
    */
-  AlgebraicDecisionTree normalize(double sum) const {
-    return this->apply([&sum](const double& x) { return x / sum; });
-  }
+  AlgebraicDecisionTree normalize() const { return (*this) / this->sum(); }
 
   /// Find the minimum values amongst all leaves
   double min() const {
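A minimal sketch of the new call site: normalize() now computes the sum itself instead of taking it as an argument. The single-key tree below is invented for illustration and assumes the existing AlgebraicDecisionTree label/leaf constructor and sum():

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Key.h>
#include <iostream>

using namespace gtsam;

int main() {
  // Two leaves 1.0 and 3.0 on a single (binary) key; sum() is 4.0.
  AlgebraicDecisionTree<Key> p(Key(0), 1.0, 3.0);

  // New API: no explicit sum argument; normalize() divides by this->sum().
  AlgebraicDecisionTree<Key> q = p.normalize();

  Assignment<Key> choice;
  choice[Key(0)] = 1;
  std::cout << q(choice) << std::endl;  // expected: 0.75
  return 0;
}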
@@ -1,7 +1,6 @@
 # Install headers
 set(subdir discrete)
 file(GLOB discrete_headers "*.h")
-# FIXME: exclude headers
 install(FILES ${discrete_headers} DESTINATION include/gtsam/discrete)
 
 # Add all tests
@@ -562,7 +562,7 @@ TEST(ADT, Sum) {
 TEST(ADT, Normalize) {
   ADT a = exampleADT();
   double sum = a.sum();
-  auto actual = a.normalize(sum);
+  auto actual = a.normalize();
 
   DiscreteKey A(0, 2), B(1, 3), C(2, 2);
   DiscreteKeys keys = DiscreteKeys{A, B, C};
@@ -1,7 +1,6 @@
 # Install headers
 set(subdir hybrid)
 file(GLOB hybrid_headers "*.h")
-# FIXME: exclude headers
 install(FILES ${hybrid_headers} DESTINATION include/gtsam/hybrid)
 
 # Add all tests
@@ -189,8 +189,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
   auto errorFunc = [&continuousValues](const GaussianFactorValuePair& pair) {
     return PotentiallyPrunedComponentError(pair, continuousValues);
   };
-  DecisionTree<Key, double> error_tree(factors_, errorFunc);
-  return error_tree;
+  return {factors_, errorFunc};
 }
 
 /* *******************************************************************************/
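The return {factors_, errorFunc}; form brace-initializes the AlgebraicDecisionTree<Key> return value through the functional constructor added above, instead of building a named DecisionTree<Key, double> first. A hypothetical sketch of the same pattern, with an int-leaved tree standing in for factors_ and a made-up per-leaf error:

#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/inference/Key.h>

using namespace gtsam;

// Hypothetical stand-in for errorTree(): maps each leaf to a double "error"
// and returns the result via brace-initialization of the return type.
AlgebraicDecisionTree<Key> toErrors(const DecisionTree<Key, int>& leaves) {
  auto errorFunc = [](int id) { return 0.5 * id; };
  return {leaves, errorFunc};  // picks AlgebraicDecisionTree(other, f)
}

int main() {
  DecisionTree<Key, int> leaves(Key(0), 2, 4);
  AlgebraicDecisionTree<Key> errors = toErrors(leaves);  // leaves 1.0 and 2.0
  return 0;
}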
@@ -229,13 +229,18 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
 }
 
 /* ************************************************************************ */
-/// Take negative log-values, shift them so that the minimum value is 0, and
-/// then exponentiate to create a DecisionTreeFactor (not normalized yet!).
+/**
+ * @brief Take negative log-values, shift them so that the minimum value is 0,
+ * and then exponentiate to create a DecisionTreeFactor (not normalized yet!).
+ *
+ * @param errors DecisionTree of (unnormalized) errors.
+ * @return DecisionTreeFactor::shared_ptr
+ */
 static DecisionTreeFactor::shared_ptr DiscreteFactorFromErrors(
     const DiscreteKeys &discreteKeys,
     const AlgebraicDecisionTree<Key> &errors) {
   double min_log = errors.min();
-  AlgebraicDecisionTree<Key> potentials = DecisionTree<Key, double>(
+  AlgebraicDecisionTree<Key> potentials(
       errors, [&min_log](const double x) { return exp(-(x - min_log)); });
   return std::make_shared<DecisionTreeFactor>(discreteKeys, potentials);
 }
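Shifting by the minimum before exponentiating only rescales every potential by the common factor exp(min_log), which is harmless since the factor is normalized later; it keeps the largest potential at exactly 1 and avoids underflow when errors are large. A standalone numeric sketch of the idea (plain C++, not GTSAM code; the values are invented):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Unnormalized negative log-values ("errors"); the smallest is 100.0.
  std::vector<double> errors = {100.0, 101.0, 103.0};
  double min_log = 100.0;  // plays the role of errors.min() above

  // For much larger errors, exp(-error) would underflow to 0; shifting by the
  // minimum keeps the largest potential at exactly 1 and preserves the ratios.
  for (double e : errors)
    std::printf("%g ", std::exp(-(e - min_log)));  // prints: 1 0.367879 0.0497871
  std::printf("\n");
  return 0;
}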
@@ -258,7 +263,7 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
         if (!factor) return std::numeric_limits<double>::infinity();
         return scalar + factor->error(kEmpty);
       };
-      DecisionTree<Key, double> errors(gmf->factors(), calculateError);
+      AlgebraicDecisionTree<Key> errors(gmf->factors(), calculateError);
       dfg.push_back(DiscreteFactorFromErrors(gmf->discreteKeys(), errors));
 
     } else if (auto orphan = dynamic_pointer_cast<OrphanWrapper>(f)) {
@@ -307,7 +312,7 @@ static std::shared_ptr<Factor> createDiscreteFactor(
     }
   };
 
-  DecisionTree<Key, double> errors(eliminationResults, calculateError);
+  AlgebraicDecisionTree<Key> errors(eliminationResults, calculateError);
   return DiscreteFactorFromErrors(discreteSeparator, errors);
 }
 
@@ -100,8 +100,7 @@ AlgebraicDecisionTree<Key> HybridNonlinearFactor::errorTree(
     auto [factor, val] = f;
     return factor->error(continuousValues) + val;
   };
-  DecisionTree<Key, double> result(factors_, errorFunc);
-  return result;
+  return {factors_, errorFunc};
 }
 
 /* *******************************************************************************/
@@ -333,7 +333,7 @@ TEST(HybridBayesNet, Switching) {
   CHECK(phi_x1);
   EXPECT_LONGS_EQUAL(1, phi_x1->keys().size());  // m0
   // We can't really check the error of the decision tree factor phi_x1, because
-  // the continuous factor whose error(kEmpty) we need is not available..
+  // the continuous factor whose error(kEmpty) we need is not available.
 
   // Now test full elimination of the graph:
   auto hybridBayesNet = graph.eliminateSequential();
@@ -128,7 +128,10 @@ TEST(HybridGaussianProductFactor, AsProductFactor) {
   EXPECT(actual.first.at(0) == f10);
   EXPECT_DOUBLES_EQUAL(10, actual.second, 1e-9);
 
-  // TODO(Frank): when killed hiding, f11 should also be there
+  mode[m1.first] = 1;
+  actual = product(mode);
+  EXPECT(actual.first.at(0) == f11);
+  EXPECT_DOUBLES_EQUAL(11, actual.second, 1e-9);
 }
 
 /* ************************************************************************* */
@@ -145,7 +148,10 @@ TEST(HybridGaussianProductFactor, AddOne) {
   EXPECT(actual.first.at(0) == f10);
   EXPECT_DOUBLES_EQUAL(10, actual.second, 1e-9);
 
-  // TODO(Frank): when killed hiding, f11 should also be there
+  mode[m1.first] = 1;
+  actual = product(mode);
+  EXPECT(actual.first.at(0) == f11);
+  EXPECT_DOUBLES_EQUAL(11, actual.second, 1e-9);
 }
 
 /* ************************************************************************* */
@@ -166,9 +172,8 @@ TEST(HybridGaussianProductFactor, AddTwo) {
   EXPECT_DOUBLES_EQUAL(10 + 20, actual00.second, 1e-9);
 
   auto actual12 = product({{M(1), 1}, {M(2), 2}});
-  // TODO(Frank): when killed hiding, these should also equal:
-  // EXPECT(actual12.first.at(0) == f11);
-  // EXPECT(actual12.first.at(1) == f22);
+  EXPECT(actual12.first.at(0) == f11);
+  EXPECT(actual12.first.at(1) == f22);
   EXPECT_DOUBLES_EQUAL(11 + 22, actual12.second, 1e-9);
 }
 
@@ -973,8 +973,6 @@ TEST(HybridNonlinearFactorGraph, DifferentMeans) {
   VectorValues cont0 = bn->optimize(dv0);
   double error0 = bn->error(HybridValues(cont0, dv0));
 
-  // TODO(Varun) Perform importance sampling to estimate error?
-
   // regression
   EXPECT_DOUBLES_EQUAL(0.69314718056, error0, 1e-9);
 