operator* version which accepts DiscreteFactor

release/4.3a0
Varun Agrawal 2024-12-07 11:02:30 -05:00
parent d1d440ad34
commit a68da21527
5 changed files with 29 additions and 8 deletions

View File

@ -82,6 +82,17 @@ namespace gtsam {
ADT::print("", formatter); ADT::print("", formatter);
} }
/* ************************************************************************ */
DiscreteFactor::shared_ptr DecisionTreeFactor::operator*(
const DiscreteFactor::shared_ptr& f) const {
if (auto derived = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return std::make_shared<DecisionTreeFactor>(this->operator*(*derived));
} else {
throw std::runtime_error(
"Cannot convert DiscreteFactor to DecisionTreeFactor");
}
}
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const { DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const {
// apply operand // apply operand

View File

@ -144,10 +144,13 @@ namespace gtsam {
double error(const DiscreteValues& values) const override; double error(const DiscreteValues& values) const override;
/// multiply two factors /// multiply two factors
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { DecisionTreeFactor operator*(const DecisionTreeFactor& f) const {
return apply(f, ADT::Ring::mul); return apply(f, ADT::Ring::mul);
} }
DiscreteFactor::shared_ptr operator*(
const DiscreteFactor::shared_ptr& f) const override;
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);
/// divide by factor f (safely) /// divide by factor f (safely)

View File

@ -107,9 +107,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {
/// Compute error for each assignment and return as a tree /// Compute error for each assignment and return as a tree
virtual AlgebraicDecisionTree<Key> errorTree() const; virtual AlgebraicDecisionTree<Key> errorTree() const;
/// Multiply in a DecisionTreeFactor and return the result as /// Multiply in a DiscreteFactor and return the result as
/// DecisionTreeFactor /// DiscreteFactor
virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; virtual DiscreteFactor::shared_ptr operator*(
const DiscreteFactor::shared_ptr&) const = 0;
virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

View File

@ -169,8 +169,13 @@ double TableFactor::error(const HybridValues& values) const {
} }
/* ************************************************************************ */ /* ************************************************************************ */
DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { DiscreteFactor::shared_ptr TableFactor::operator*(
return toDecisionTreeFactor() * f; const DiscreteFactor::shared_ptr& f) const {
if (auto derived = std::dynamic_pointer_cast<TableFactor>(f)) {
return std::make_shared<TableFactor>(this->operator*(*derived));
} else {
throw std::runtime_error("Cannot convert DiscreteFactor to TableFactor");
}
} }
/* ************************************************************************ */ /* ************************************************************************ */

View File

@ -186,8 +186,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, Ring::mul); return apply(f, Ring::mul);
}; };
/// multiply with DecisionTreeFactor /// multiply with DiscreteFactor
DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; DiscreteFactor::shared_ptr operator*(
const DiscreteFactor::shared_ptr& f) const override;
static double safe_div(const double& a, const double& b); static double safe_div(const double& a, const double& b);