Make compose and convertFrom static
parent
e6dfa7be99
commit
f98b9223e8
|
@ -580,7 +580,7 @@ namespace gtsam {
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
template <typename Iterator>
|
template <typename Iterator>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::compose(
|
||||||
Iterator begin, Iterator end, const L& label) const {
|
Iterator begin, Iterator end, const L& label) {
|
||||||
// find highest label among branches
|
// find highest label among branches
|
||||||
std::optional<L> highestLabel;
|
std::optional<L> highestLabel;
|
||||||
size_t nrChoices = 0;
|
size_t nrChoices = 0;
|
||||||
|
@ -703,12 +703,9 @@ namespace gtsam {
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
typename DecisionTree<L, Y>::NodePtr DecisionTree<L, Y>::convertFrom(
|
||||||
const typename DecisionTree<M, X>::NodePtr& f,
|
const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<L(const M&)> L_of_M,
|
std::function<L(const M&)> L_of_M, std::function<Y(const X&)> Y_of_X) {
|
||||||
std::function<Y(const X&)> Y_of_X) const {
|
|
||||||
using LY = DecisionTree<L, Y>;
|
using LY = DecisionTree<L, Y>;
|
||||||
|
|
||||||
// Ugliness below because apparently we can't have templated virtual
|
|
||||||
// functions.
|
|
||||||
// If leaf, apply unary conversion "op" and create a unique leaf.
|
// If leaf, apply unary conversion "op" and create a unique leaf.
|
||||||
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
using MXLeaf = typename DecisionTree<M, X>::Leaf;
|
||||||
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
|
if (auto leaf = std::dynamic_pointer_cast<const MXLeaf>(f)) {
|
||||||
|
@ -718,19 +715,27 @@ namespace gtsam {
|
||||||
// Check if Choice
|
// Check if Choice
|
||||||
using MXChoice = typename DecisionTree<M, X>::Choice;
|
using MXChoice = typename DecisionTree<M, X>::Choice;
|
||||||
auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
|
auto choice = std::dynamic_pointer_cast<const MXChoice>(f);
|
||||||
if (!choice) throw std::invalid_argument(
|
if (!choice)
|
||||||
"DecisionTree::convertFrom: Invalid NodePtr");
|
throw std::invalid_argument("DecisionTree::convertFrom: Invalid NodePtr");
|
||||||
|
|
||||||
// get new label
|
// get new label
|
||||||
const M oldLabel = choice->label();
|
const M oldLabel = choice->label();
|
||||||
const L newLabel = L_of_M(oldLabel);
|
const L newLabel = L_of_M(oldLabel);
|
||||||
|
|
||||||
// put together via Shannon expansion otherwise not sorted.
|
// Shannon expansion in this context involves:
|
||||||
|
// 1. Creating separate subtrees (functions) for each possible value of the new label.
|
||||||
|
// 2. Combining these subtrees using the 'compose' method, which implements the expansion.
|
||||||
|
// This approach guarantees that the resulting tree maintains the correct variable ordering
|
||||||
|
// based on the new labels (L) after translation from the old labels (M).
|
||||||
|
// Simply creating a Choice node here would not work because it wouldn't account for the
|
||||||
|
// potentially new ordering of variables resulting from the label translation,
|
||||||
|
// which is crucial for maintaining consistency and efficiency in the converted tree.
|
||||||
std::vector<LY> functions;
|
std::vector<LY> functions;
|
||||||
for (auto&& branch : choice->branches()) {
|
for (auto&& branch : choice->branches()) {
|
||||||
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
functions.emplace_back(convertFrom<M, X>(branch, L_of_M, Y_of_X));
|
||||||
}
|
}
|
||||||
return Choice::Unique(LY::compose(functions.begin(), functions.end(), newLabel));
|
return Choice::Unique(
|
||||||
|
LY::compose(functions.begin(), functions.end(), newLabel));
|
||||||
}
|
}
|
||||||
|
|
||||||
/****************************************************************************/
|
/****************************************************************************/
|
||||||
|
@ -740,9 +745,8 @@ namespace gtsam {
|
||||||
*
|
*
|
||||||
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
|
* NOTE: We differentiate between leaves and assignments. Concretely, a 3
|
||||||
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
|
* binary variable tree will have 2^3=8 assignments, but based on pruning, it
|
||||||
* can have less than 8 leaves. For example, if a tree has all assignment
|
* can have <8 leaves. For example, if a tree has all assignment values as 1,
|
||||||
* values as 1, then pruning will cause the tree to have only 1 leaf yet 8
|
* then pruning will cause the tree to have only 1 leaf yet 8 assignments.
|
||||||
* assignments.
|
|
||||||
*/
|
*/
|
||||||
template <typename L, typename Y>
|
template <typename L, typename Y>
|
||||||
struct Visit {
|
struct Visit {
|
||||||
|
|
|
@ -176,9 +176,9 @@ namespace gtsam {
|
||||||
* @return NodePtr
|
* @return NodePtr
|
||||||
*/
|
*/
|
||||||
template <typename M, typename X>
|
template <typename M, typename X>
|
||||||
NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
|
static NodePtr convertFrom(const typename DecisionTree<M, X>::NodePtr& f,
|
||||||
std::function<L(const M&)> L_of_M,
|
std::function<L(const M&)> L_of_M,
|
||||||
std::function<Y(const X&)> Y_of_X) const;
|
std::function<Y(const X&)> Y_of_X);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
|
@ -402,7 +402,7 @@ namespace gtsam {
|
||||||
|
|
||||||
// internal use only
|
// internal use only
|
||||||
template<typename Iterator> NodePtr
|
template<typename Iterator> NodePtr
|
||||||
compose(Iterator begin, Iterator end, const L& label) const;
|
static compose(Iterator begin, Iterator end, const L& label);
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue