From a68da21527760daff64161f3feffc6cc1d46d1b1 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sat, 7 Dec 2024 11:02:30 -0500 Subject: [PATCH] operator* version which accepts DiscreteFactor --- gtsam/discrete/DecisionTreeFactor.cpp | 11 +++++++++++ gtsam/discrete/DecisionTreeFactor.h | 5 ++++- gtsam/discrete/DiscreteFactor.h | 7 ++++--- gtsam/discrete/TableFactor.cpp | 9 +++++++-- gtsam/discrete/TableFactor.h | 5 +++-- 5 files changed, 29 insertions(+), 8 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index 9ec3b0ac5..e53f8cb90 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -82,6 +82,17 @@ namespace gtsam { ADT::print("", formatter); } + /* ************************************************************************ */ + DiscreteFactor::shared_ptr DecisionTreeFactor::operator*( + const DiscreteFactor::shared_ptr& f) const { + if (auto derived = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator*(*derived)); + } else { + throw std::runtime_error( + "Cannot convert DiscreteFactor to DecisionTreeFactor"); + } + } + /* ************************************************************************ */ DecisionTreeFactor DecisionTreeFactor::apply(ADT::Unary op) const { // apply operand diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index f417a38d7..7afbab0b0 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -144,10 +144,13 @@ namespace gtsam { double error(const DiscreteValues& values) const override; /// multiply two factors - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override { + DecisionTreeFactor operator*(const DecisionTreeFactor& f) const { return apply(f, ADT::Ring::mul); } + DiscreteFactor::shared_ptr operator*( + const DiscreteFactor::shared_ptr& f) const override; + static double safe_div(const double& a, const double& b); /// divide by factor f (safely) diff --git a/gtsam/discrete/DiscreteFactor.h b/gtsam/discrete/DiscreteFactor.h index 7d5047ec6..4c486dca8 100644 --- a/gtsam/discrete/DiscreteFactor.h +++ b/gtsam/discrete/DiscreteFactor.h @@ -107,9 +107,10 @@ class GTSAM_EXPORT DiscreteFactor : public Factor { /// Compute error for each assignment and return as a tree virtual AlgebraicDecisionTree errorTree() const; - /// Multiply in a DecisionTreeFactor and return the result as - /// DecisionTreeFactor - virtual DecisionTreeFactor operator*(const DecisionTreeFactor&) const = 0; + /// Multiply in a DiscreteFactor and return the result as + /// DiscreteFactor + virtual DiscreteFactor::shared_ptr operator*( + const DiscreteFactor::shared_ptr&) const = 0; virtual DecisionTreeFactor toDecisionTreeFactor() const = 0; diff --git a/gtsam/discrete/TableFactor.cpp b/gtsam/discrete/TableFactor.cpp index f4e023a4d..7cf520973 100644 --- a/gtsam/discrete/TableFactor.cpp +++ b/gtsam/discrete/TableFactor.cpp @@ -169,8 +169,13 @@ double TableFactor::error(const HybridValues& values) const { } /* ************************************************************************ */ -DecisionTreeFactor TableFactor::operator*(const DecisionTreeFactor& f) const { - return toDecisionTreeFactor() * f; +DiscreteFactor::shared_ptr TableFactor::operator*( + const DiscreteFactor::shared_ptr& f) const { + if (auto derived = std::dynamic_pointer_cast(f)) { + return std::make_shared(this->operator*(*derived)); + } else { + throw std::runtime_error("Cannot convert DiscreteFactor to TableFactor"); + } } /* ************************************************************************ */ diff --git a/gtsam/discrete/TableFactor.h b/gtsam/discrete/TableFactor.h index b988eebad..29cbd5e9b 100644 --- a/gtsam/discrete/TableFactor.h +++ b/gtsam/discrete/TableFactor.h @@ -186,8 +186,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor { return apply(f, Ring::mul); }; - /// multiply with DecisionTreeFactor - DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override; + /// multiply with DiscreteFactor + DiscreteFactor::shared_ptr operator*( + const DiscreteFactor::shared_ptr& f) const override; static double safe_div(const double& a, const double& b);