| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | """Mixins for reducing the amount of boilerplate in the main wrapper class.""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  | from typing import Any, Tuple, Union | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | import gtwrap.interface_parser as parser | 
					
						
							|  |  |  | import gtwrap.template_instantiator as instantiator | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class CheckMixin: | 
					
						
							|  |  |  |     """Mixin to provide various checks.""" | 
					
						
							|  |  |  |     # Data types that are primitive types | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     not_ptr_type: Tuple = ('int', 'double', 'bool', 'char', 'unsigned char', | 
					
						
							|  |  |  |                            'size_t') | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |     # Ignore the namespace for these datatypes | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     ignore_namespace: Tuple = ('Matrix', 'Vector', 'Point2', 'Point3') | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |     # Methods that should be ignored | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     ignore_methods: Tuple = ('pickle', ) | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |     # Methods that should not be wrapped directly | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     whitelist: Tuple = ('serializable', 'serialize') | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |     # Datatypes that do not need to be checked in methods | 
					
						
							|  |  |  |     not_check_type: list = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _has_serialization(self, cls): | 
					
						
							|  |  |  |         for m in cls.methods: | 
					
						
							|  |  |  |             if m.name in self.whitelist: | 
					
						
							|  |  |  |                 return True | 
					
						
							|  |  |  |         return False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-02-03 00:53:48 +08:00
										 |  |  |     def can_be_pointer(self, arg_type: parser.Type): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Determine if the `arg_type` can have a pointer to it. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         E.g. `Pose3` can have `Pose3*` but  | 
					
						
							|  |  |  |         `Matrix` should not have `Matrix*`. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         return (arg_type.typename.name not in self.not_ptr_type | 
					
						
							|  |  |  |                 and arg_type.typename.name not in self.ignore_namespace | 
					
						
							|  |  |  |                 and arg_type.typename.name != 'string') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     def is_shared_ptr(self, arg_type: parser.Type): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Determine if the `interface_parser.Type` should be treated as a | 
					
						
							|  |  |  |         shared pointer in the wrapper. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2022-02-03 00:53:48 +08:00
										 |  |  |         return arg_type.is_shared_ptr | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     def is_ptr(self, arg_type: parser.Type): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Determine if the `interface_parser.Type` should be treated as a | 
					
						
							|  |  |  |         raw pointer in the wrapper. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2022-02-03 00:53:48 +08:00
										 |  |  |         return arg_type.is_ptr | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     def is_ref(self, arg_type: parser.Type): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Determine if the `interface_parser.Type` should be treated as a | 
					
						
							|  |  |  |         reference in the wrapper. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         return arg_type.typename.name not in self.ignore_namespace and \ | 
					
						
							|  |  |  |                arg_type.typename.name not in self.not_ptr_type and \ | 
					
						
							|  |  |  |                arg_type.is_ref | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class FormatMixin: | 
					
						
							|  |  |  |     """Mixin to provide formatting utilities.""" | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     ignore_namespace: tuple | 
					
						
							|  |  |  |     data_type: Any | 
					
						
							|  |  |  |     data_type_param: Any | 
					
						
							|  |  |  |     _return_count: Any | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _clean_class_name(self, | 
					
						
							|  |  |  |                           instantiated_class: instantiator.InstantiatedClass): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """Reformatted the C++ class name to fit Matlab defined naming
 | 
					
						
							|  |  |  |         standards | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if len(instantiated_class.ctors) != 0: | 
					
						
							|  |  |  |             return instantiated_class.ctors[0].name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return instantiated_class.name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _format_type_name(self, | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |                           type_name: parser.Typename, | 
					
						
							|  |  |  |                           separator: str = '::', | 
					
						
							|  |  |  |                           include_namespace: bool = True, | 
					
						
							|  |  |  |                           is_constructor: bool = False, | 
					
						
							|  |  |  |                           is_method: bool = False): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             type_name: an interface_parser.Typename to reformat | 
					
						
							|  |  |  |             separator: the statement to add between namespaces and typename | 
					
						
							|  |  |  |             include_namespace: whether to include namespaces when reformatting | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |             is_constructor: if the typename will be in a constructor | 
					
						
							|  |  |  |             is_method: if the typename will be in a method | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         Raises: | 
					
						
							|  |  |  |             constructor and method cannot both be true | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |         if is_constructor and is_method: | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |             raise ValueError( | 
					
						
							|  |  |  |                 'Constructor and method parameters cannot both be True') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         formatted_type_name = '' | 
					
						
							|  |  |  |         name = type_name.name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if include_namespace: | 
					
						
							|  |  |  |             for namespace in type_name.namespaces: | 
					
						
							|  |  |  |                 if name not in self.ignore_namespace and namespace != '': | 
					
						
							|  |  |  |                     formatted_type_name += namespace + separator | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |         if is_constructor: | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |             formatted_type_name += self.data_type.get(name) or name | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |         elif is_method: | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |             formatted_type_name += self.data_type_param.get(name) or name | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2021-12-07 00:01:43 +08:00
										 |  |  |             formatted_type_name += str(name) | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if separator == "::":  # C++ | 
					
						
							|  |  |  |             templates = [] | 
					
						
							| 
									
										
										
										
											2021-10-26 00:46:06 +08:00
										 |  |  |             for idx, _ in enumerate(type_name.instantiations): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |                 template = '{}'.format( | 
					
						
							|  |  |  |                     self._format_type_name(type_name.instantiations[idx], | 
					
						
							|  |  |  |                                            include_namespace=include_namespace, | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |                                            is_constructor=is_constructor, | 
					
						
							|  |  |  |                                            is_method=is_method)) | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |                 templates.append(template) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if len(templates) > 0:  # If there are no templates | 
					
						
							|  |  |  |                 formatted_type_name += '<{}>'.format(','.join(templates)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2021-10-26 00:46:06 +08:00
										 |  |  |             for idx, _ in enumerate(type_name.instantiations): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |                 formatted_type_name += '{}'.format( | 
					
						
							|  |  |  |                     self._format_type_name(type_name.instantiations[idx], | 
					
						
							|  |  |  |                                            separator=separator, | 
					
						
							|  |  |  |                                            include_namespace=False, | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |                                            is_constructor=is_constructor, | 
					
						
							|  |  |  |                                            is_method=is_method)) | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return formatted_type_name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _format_return_type(self, | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |                             return_type: parser.function.ReturnType, | 
					
						
							|  |  |  |                             include_namespace: bool = False, | 
					
						
							|  |  |  |                             separator: str = "::"): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """Format return_type.
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             return_type: an interface_parser.ReturnType to reformat | 
					
						
							|  |  |  |             include_namespace: whether to include namespaces when reformatting | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         return_wrap = '' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if self._return_count(return_type) == 1: | 
					
						
							|  |  |  |             return_wrap = self._format_type_name( | 
					
						
							|  |  |  |                 return_type.type1.typename, | 
					
						
							|  |  |  |                 separator=separator, | 
					
						
							|  |  |  |                 include_namespace=include_namespace) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return_wrap = 'pair< {type1}, {type2} >'.format( | 
					
						
							|  |  |  |                 type1=self._format_type_name( | 
					
						
							|  |  |  |                     return_type.type1.typename, | 
					
						
							|  |  |  |                     separator=separator, | 
					
						
							|  |  |  |                     include_namespace=include_namespace), | 
					
						
							|  |  |  |                 type2=self._format_type_name( | 
					
						
							|  |  |  |                     return_type.type2.typename, | 
					
						
							|  |  |  |                     separator=separator, | 
					
						
							|  |  |  |                     include_namespace=include_namespace)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return return_wrap | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     def _format_class_name(self, | 
					
						
							|  |  |  |                            instantiated_class: instantiator.InstantiatedClass, | 
					
						
							|  |  |  |                            separator: str = ''): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """Format a template_instantiator.InstantiatedClass name.""" | 
					
						
							|  |  |  |         if instantiated_class.parent == '': | 
					
						
							|  |  |  |             parent_full_ns = [''] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             parent_full_ns = instantiated_class.parent.full_namespaces() | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         parentname = "".join([separator + x | 
					
						
							|  |  |  |                               for x in parent_full_ns]) + separator | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         class_name = parentname[2 * len(separator):] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         class_name += instantiated_class.name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return class_name | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     def _format_static_method(self, | 
					
						
							|  |  |  |                               static_method: parser.StaticMethod, | 
					
						
							|  |  |  |                               separator: str = ''): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Example: | 
					
						
							|  |  |  |                 gtsam.Point3.staticFunction() | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         method = '' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if isinstance(static_method, parser.StaticMethod): | 
					
						
							| 
									
										
										
										
											2021-12-07 00:01:43 +08:00
										 |  |  |             method += static_method.parent.to_cpp() + separator | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-07 00:01:43 +08:00
										 |  |  |         return method | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |     def _format_global_function(self, | 
					
						
							|  |  |  |                                 function: Union[parser.GlobalFunction, Any], | 
					
						
							|  |  |  |                                 separator: str = ''): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """Example:
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 gtsamPoint3.staticFunction | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         method = '' | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-22 23:51:22 +08:00
										 |  |  |         if isinstance(function, parser.GlobalFunction): | 
					
						
							|  |  |  |             method += "".join([separator + x for x in function.parent.full_namespaces()]) + \ | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |                       separator | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return method[2 * len(separator):] |