implement dim() for MixtureFactor

release/4.3a0
Varun Agrawal 2023-01-04 02:55:06 -05:00
parent 34daecd7a4
commit 7dd4bc990a
2 changed files with 25 additions and 6 deletions

View File

@ -162,14 +162,20 @@ class MixtureFactor : public HybridFactor {
}
/// Error for HybridValues is not provided for nonlinear hybrid factor.
double error(const HybridValues &values) const override {
double error(const HybridValues& values) const override {
throw std::runtime_error(
"MixtureFactor::error(HybridValues) not implemented.");
}
/**
* @brief Get the dimension of the factor (number of rows on linearization).
* Returns the dimension of the first component factor.
* @return size_t
*/
size_t dim() const {
// TODO(Varun)
throw std::runtime_error("MixtureFactor::dim not implemented.");
const auto assignments = DiscreteValues::CartesianProduct(discreteKeys_);
auto factor = factors_(assignments.at(0));
return factor->dim();
}
/// Testable

View File

@ -70,8 +70,7 @@ MixtureFactor
}
/* ************************************************************************* */
// Test the error of the MixtureFactor
TEST(MixtureFactor, Error) {
static MixtureFactor getMixtureFactor() {
DiscreteKey m1(1, 2);
double between0 = 0.0;
@ -86,7 +85,13 @@ TEST(MixtureFactor, Error) {
boost::make_shared<BetweenFactor<double>>(X(1), X(2), between1, model);
std::vector<NonlinearFactor::shared_ptr> factors{f0, f1};
MixtureFactor mixtureFactor({X(1), X(2)}, {m1}, factors);
return MixtureFactor({X(1), X(2)}, {m1}, factors);
}
/* ************************************************************************* */
// Test the error of the MixtureFactor
TEST(MixtureFactor, Error) {
auto mixtureFactor = getMixtureFactor();
Values continuousValues;
continuousValues.insert<double>(X(1), 0);
@ -94,6 +99,7 @@ TEST(MixtureFactor, Error) {
AlgebraicDecisionTree<Key> error_tree = mixtureFactor.error(continuousValues);
DiscreteKey m1(1, 2);
std::vector<DiscreteKey> discrete_keys = {m1};
std::vector<double> errors = {0.5, 0};
AlgebraicDecisionTree<Key> expected_error(discrete_keys, errors);
@ -101,6 +107,13 @@ TEST(MixtureFactor, Error) {
EXPECT(assert_equal(expected_error, error_tree));
}
/* ************************************************************************* */
// Test dim of the MixtureFactor
TEST(MixtureFactor, Dim) {
auto mixtureFactor = getMixtureFactor();
EXPECT_LONGS_EQUAL(1, mixtureFactor.dim());
}
/* ************************************************************************* */
int main() {
TestResult tr;