Make compose and convertFrom static

release/4.3a0
Frank Dellaert 2024-10-15 15:43:03 +09:00
parent e6dfa7be99
commit f98b9223e8
2 changed files with 20 additions and 16 deletions

View File

@ -580,7 +580,7 @@ namespace gtsam {
template <typename L, typename Y> template <typename L, typename Y>
template <typename Iterator> template <typename Iterator>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
Iterator begin, Iterator end, const L& label) const { Iterator begin, Iterator end, const L& label) {
// find highest label among branches // find highest label among branches
std::optional<L> highestLabel; std::optional<L> highestLabel;
size_t nrChoices = 0; size_t nrChoices = 0;
@ -703,12 +703,9 @@ namespace gtsam {
template <typename M, typename X> template <typename M, typename X>
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom( typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
const typename DecisionTree<M, X>::NodePtr& f, const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M, std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) {
std::function<Y(const X&)> Y_of_X) const {
using LY = DecisionTree<L, Y>; using LY = DecisionTree<L, Y>;
// Ugliness below because apparently we can't have templated virtual
// functions.
// If leaf, apply unary conversion "op" and create a unique leaf. // If leaf, apply unary conversion "op" and create a unique leaf.
using MXLeaf = typename DecisionTree<M, X>::Leaf; using MXLeaf = typename DecisionTree<M, X>::Leaf;
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) { if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
@ -718,19 +715,27 @@ namespace gtsam {
// Check if Choice // Check if Choice
using MXChoice = typename DecisionTree<M, X>::Choice; using MXChoice = typename DecisionTree<M, X>::Choice;
auto choice = std::dynamic_pointer_cast<const MXChoice>(f); auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
if (!choice) throw std::invalid_argument( if (!choice)
"DecisionTree::convertFrom: Invalid NodePtr"); throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr");
// get new label // get new label
const M oldLabel = choice->label(); const M oldLabel = choice->label();
const L newLabel = L_of_M(oldLabel); const L newLabel = L_of_M(oldLabel);
// put together via Shannon expansion otherwise not sorted. // Shannon expansion in this context involves:
// 1. Creating separate subtrees (functions) for each possible value of the new label.
// 2. Combining these subtrees using the 'compose' method, which implements the expansion.
// This approach guarantees that the resulting tree maintains the correct variable ordering
// based on the new labels (L) after translation from the old labels (M).
// Simply creating a Choice node here would not work because it wouldn't account for the
// potentially new ordering of variables resulting from the label translation,
// which is crucial for maintaining consistency and efficiency in the converted tree.
std::vector<LY> functions; std::vector<LY> functions;
for (auto&& branch : choice->branches()) { for (auto&& branch : choice->branches()) {
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X)); functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
} }
return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel)); return Choice::Unique(
LY::compose(functions.begin(), functions.end(), newLabel));
} }
/****************************************************************************/ /****************************************************************************/
@ -740,9 +745,8 @@ namespace gtsam {
* *
* NOTE: We differentiate between leaves and assignments. Concretely, a 3 * NOTE: We differentiate between leaves and assignments. Concretely, a 3
* binary variable tree will have 2^3=8 assignments, but based on pruning, it * binary variable tree will have 2^3=8 assignments, but based on pruning, it
* can have less than 8 leaves. For example, if a tree has all assignment * can have <8 leaves. For example, if a tree has all assignment values as 1,
* values as 1, then pruning will cause the tree to have only 1 leaf yet 8 * then pruning will cause the tree to have only 1 leaf yet 8 assignments.
* assignments.
*/ */
template <typename L, typename Y> template <typename L, typename Y>
struct Visit { struct Visit {

View File

@ -176,9 +176,9 @@ namespace gtsam {
* @return NodePtr * @return NodePtr
*/ */
template <typename M, typename X> template <typename M, typename X>
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f, static NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
std::function<L(const M&)> L_of_M, std::function<L(const M&)> L_of_M,
std::function<Y(const X&)> Y_of_X) const; std::function<Y(const X&)> Y_of_X);
public: public:
/// @name Standard Constructors /// @name Standard Constructors
@ -402,7 +402,7 @@ namespace gtsam {
// internal use only // internal use only
template<typename Iterator> NodePtr template<typename Iterator> NodePtr
compose(Iterator begin, Iterator end, const L& label) const; static compose(Iterator begin, Iterator end, const L& label);
/// @} /// @}