gtsam/gtwrap/matlab_wrapper/mixins.py

223 lines
8.0 KiB
Python
Raw Normal View History

Squashed 'wrap/' changes from 07330d100..d9ae5ce03 d9ae5ce03 Merge pull request #118 from borglab/feature/matlab-multi-files 9adddf7dd update the main script for matlab wrapping 0b0398d46 remove debug statements since they aren't needed for now df064a364 support for parsing mutiple interface files for Matlab wrapping 1929e197c add test for parsing multiple interface files bac442056 Merge pull request #117 from borglab/fix/matlab-refactor 331f4a8ce update tests to remove redundant code 5426e3af4 generate all content from within the wrap function f78612bf9 make directory check common b7acd7a1f fixed import and test setup 88007b153 Merge pull request #116 from borglab/feature/matlab-refactor a074896e6 utils -> mixins 414557e00 structure 187100439 update gitignore adbc55aea don't use class attributes in matlab wrapper f45ba5b2d broke down some large functions into smaller ones 7756f0548 add mixin for checks and call method to wrap global functions a318e2a67 Merge pull request #115 from borglab/feature/multiple-modules b02b74c3d convert matlab_wrapper to a submodule be8641e83 improved function naming in tests 02ddbfbb0 update tests and docs dfbded2c7 small fixes e9ec5af07 update docs d124e2cfb wrap multiple files 7c7342f86 update cmake to take in new changes for multiple modules 54850f724 Merge pull request #114 from borglab/fix/remove-py35 71ee98321 add mypy annotations ccaea6294 remove support for python 3.5 git-subtree-dir: wrap git-subtree-split: d9ae5ce036c4315db3c28b12db9c73eae246f314
2021-07-11 23:10:35 +08:00
"""Mixins for reducing the amount of boilerplate in the main wrapper class."""
import gtwrap.interface_parser as parser
import gtwrap.template_instantiator as instantiator
class CheckMixin:
"""Mixin to provide various checks."""
# Data types that are primitive types
not_ptr_type = ['int', 'double', 'bool', 'char', 'unsigned char', 'size_t']
# Ignore the namespace for these datatypes
ignore_namespace = ['Matrix', 'Vector', 'Point2', 'Point3']
# Methods that should be ignored
ignore_methods = ['pickle']
# Methods that should not be wrapped directly
whitelist = ['serializable', 'serialize']
# Datatypes that do not need to be checked in methods
not_check_type: list = []
def _has_serialization(self, cls):
for m in cls.methods:
if m.name in self.whitelist:
return True
return False
def is_shared_ptr(self, arg_type):
"""
Determine if the `interface_parser.Type` should be treated as a
shared pointer in the wrapper.
"""
return arg_type.is_shared_ptr or (
arg_type.typename.name not in self.not_ptr_type
and arg_type.typename.name not in self.ignore_namespace
and arg_type.typename.name != 'string')
def is_ptr(self, arg_type):
"""
Determine if the `interface_parser.Type` should be treated as a
raw pointer in the wrapper.
"""
return arg_type.is_ptr or (
arg_type.typename.name not in self.not_ptr_type
and arg_type.typename.name not in self.ignore_namespace
and arg_type.typename.name != 'string')
def is_ref(self, arg_type):
"""
Determine if the `interface_parser.Type` should be treated as a
reference in the wrapper.
"""
return arg_type.typename.name not in self.ignore_namespace and \
arg_type.typename.name not in self.not_ptr_type and \
arg_type.is_ref
class FormatMixin:
"""Mixin to provide formatting utilities."""
def _clean_class_name(self, instantiated_class):
"""Reformatted the C++ class name to fit Matlab defined naming
standards
"""
if len(instantiated_class.ctors) != 0:
return instantiated_class.ctors[0].name
return instantiated_class.name
def _format_type_name(self,
type_name,
separator='::',
include_namespace=True,
constructor=False,
method=False):
"""
Args:
type_name: an interface_parser.Typename to reformat
separator: the statement to add between namespaces and typename
include_namespace: whether to include namespaces when reformatting
constructor: if the typename will be in a constructor
method: if the typename will be in a method
Raises:
constructor and method cannot both be true
"""
if constructor and method:
raise ValueError(
'Constructor and method parameters cannot both be True')
formatted_type_name = ''
name = type_name.name
if include_namespace:
for namespace in type_name.namespaces:
if name not in self.ignore_namespace and namespace != '':
formatted_type_name += namespace + separator
if constructor:
formatted_type_name += self.data_type.get(name) or name
elif method:
formatted_type_name += self.data_type_param.get(name) or name
else:
formatted_type_name += name
if separator == "::": # C++
templates = []
for idx in range(len(type_name.instantiations)):
template = '{}'.format(
self._format_type_name(type_name.instantiations[idx],
include_namespace=include_namespace,
constructor=constructor,
method=method))
templates.append(template)
if len(templates) > 0: # If there are no templates
formatted_type_name += '<{}>'.format(','.join(templates))
else:
for idx in range(len(type_name.instantiations)):
formatted_type_name += '{}'.format(
self._format_type_name(type_name.instantiations[idx],
separator=separator,
include_namespace=False,
constructor=constructor,
method=method))
return formatted_type_name
def _format_return_type(self,
return_type,
include_namespace=False,
separator="::"):
"""Format return_type.
Args:
return_type: an interface_parser.ReturnType to reformat
include_namespace: whether to include namespaces when reformatting
"""
return_wrap = ''
if self._return_count(return_type) == 1:
return_wrap = self._format_type_name(
return_type.type1.typename,
separator=separator,
include_namespace=include_namespace)
else:
return_wrap = 'pair< {type1}, {type2} >'.format(
type1=self._format_type_name(
return_type.type1.typename,
separator=separator,
include_namespace=include_namespace),
type2=self._format_type_name(
return_type.type2.typename,
separator=separator,
include_namespace=include_namespace))
return return_wrap
def _format_class_name(self, instantiated_class, separator=''):
"""Format a template_instantiator.InstantiatedClass name."""
if instantiated_class.parent == '':
parent_full_ns = ['']
else:
parent_full_ns = instantiated_class.parent.full_namespaces()
# class_name = instantiated_class.parent.name
#
# if class_name != '':
# class_name += separator
#
# class_name += instantiated_class.name
parentname = "".join([separator + x
for x in parent_full_ns]) + separator
class_name = parentname[2 * len(separator):]
class_name += instantiated_class.name
return class_name
def _format_static_method(self, static_method, separator=''):
"""Example:
gtsamPoint3.staticFunction
"""
method = ''
if isinstance(static_method, parser.StaticMethod):
method += "".join([separator + x for x in static_method.parent.namespaces()]) + \
separator + static_method.parent.name + separator
return method[2 * len(separator):]
def _format_instance_method(self, instance_method, separator=''):
"""Example:
gtsamPoint3.staticFunction
"""
method = ''
if isinstance(instance_method, instantiator.InstantiatedMethod):
method_list = [
separator + x
for x in instance_method.parent.parent.full_namespaces()
]
method += "".join(method_list) + separator
method += instance_method.parent.name + separator
method += instance_method.original.name
method += "<" + instance_method.instantiations.to_cpp() + ">"
return method[2 * len(separator):]
def _format_global_method(self, static_method, separator=''):
"""Example:
gtsamPoint3.staticFunction
"""
method = ''
if isinstance(static_method, parser.GlobalFunction):
method += "".join([separator + x for x in static_method.parent.full_namespaces()]) + \
separator
return method[2 * len(separator):]