1729 lines
		
	
	
		
			73 KiB
		
	
	
	
		
			Python
		
	
	
		
		
			
		
	
	
			1729 lines
		
	
	
		
			73 KiB
		
	
	
	
		
			Python
		
	
	
|  | import os | ||
|  | import argparse | ||
|  | import textwrap | ||
|  | 
 | ||
|  | import gtwrap.interface_parser as parser | ||
|  | import gtwrap.template_instantiator as instantiator | ||
|  | 
 | ||
|  | from functools import reduce | ||
|  | from functools import partial | ||
|  | 
 | ||
|  | 
 | ||
|  | class MatlabWrapper(object): | ||
|  |     """ Wrap the given C++ code into Matlab.
 | ||
|  | 
 | ||
|  |     Attributes | ||
|  |         module: the C++ module being wrapped | ||
|  |         module_name: name of the C++ module being wrapped | ||
|  |         top_module_namespace: C++ namespace for the top module (default '') | ||
|  |         ignore_classes: A list of classes to ignore (default []) | ||
|  |     """
 | ||
|  |     """Map the data type to its Matlab class.
 | ||
|  |     Found in Argument.cpp in old wrapper | ||
|  |     """
 | ||
|  |     data_type = { | ||
|  |         'string': 'char', | ||
|  |         'char': 'char', | ||
|  |         'unsigned char': 'unsigned char', | ||
|  |         'Vector': 'double', | ||
|  |         'Matrix': 'double', | ||
|  |         'int': 'numeric', | ||
|  |         'size_t': 'numeric', | ||
|  |         'bool': 'logical' | ||
|  |     } | ||
|  |     """Map the data type into the type used in Matlab methods.
 | ||
|  |     Found in matlab.h in old wrapper | ||
|  |     """
 | ||
|  |     data_type_param = { | ||
|  |         'string': 'char', | ||
|  |         'char': 'char', | ||
|  |         'unsigned char': 'unsigned char', | ||
|  |         'size_t': 'int', | ||
|  |         'int': 'int', | ||
|  |         'double': 'double', | ||
|  |         'Point2': 'double', | ||
|  |         'Point3': 'double', | ||
|  |         'Vector': 'double', | ||
|  |         'Matrix': 'double', | ||
|  |         'bool': 'bool' | ||
|  |     } | ||
|  |     """Methods that should not be wrapped directly""" | ||
|  |     whitelist = ['serializable', 'serialize'] | ||
|  |     """Datatypes that do not need to be checked in methods""" | ||
|  |     not_check_type = [] | ||
|  |     """Data types that are primitive types""" | ||
|  |     not_ptr_type = ['int', 'double', 'bool', 'char', 'unsigned char', 'size_t'] | ||
|  |     """Ignore the namespace for these datatypes""" | ||
|  |     ignore_namespace = ['Matrix', 'Vector', 'Point2', 'Point3'] | ||
|  |     """The amount of times the wrapper has created a call to
 | ||
|  |     geometry_wrapper | ||
|  |     """
 | ||
|  |     wrapper_id = 0 | ||
|  |     """Map each wrapper id to what its collector function namespace, class,
 | ||
|  |     type, and string format"""
 | ||
|  |     wrapper_map = {} | ||
|  |     """Set of all the includes in the namespace""" | ||
|  |     includes = {} | ||
|  |     """Set of all classes in the namespace""" | ||
|  |     classes = [] | ||
|  |     classes_elems = {} | ||
|  |     """Id for ordering global functions in the wrapper""" | ||
|  |     global_function_id = 0 | ||
|  |     """Files and their content""" | ||
|  |     content = [] | ||
|  | 
 | ||
|  |     def __init__(self, module, module_name, top_module_namespace='', ignore_classes=[]): | ||
|  |         self.module = module | ||
|  |         self.module_name = module_name | ||
|  |         self.top_module_namespace = top_module_namespace | ||
|  |         self.ignore_classes = ignore_classes | ||
|  |         self.verbose = False | ||
|  | 
 | ||
|  |     def _debug(self, message): | ||
|  |         if not self.verbose: | ||
|  |             return | ||
|  |         import sys | ||
|  |         print(message, file=sys.stderr) | ||
|  | 
 | ||
|  |     def _add_include(self, include): | ||
|  |         self.includes[include] = 0 | ||
|  | 
 | ||
|  |     def _add_class(self, instantiated_class): | ||
|  |         if self.classes_elems.get(instantiated_class) is None: | ||
|  |             self.classes_elems[instantiated_class] = 0 | ||
|  |             self.classes.append(instantiated_class) | ||
|  | 
 | ||
|  |     def _update_wrapper_id(self, collector_function=None, id_diff=0): | ||
|  |         """Get and define wrapper ids.
 | ||
|  | 
 | ||
|  |         Generates the map of id -> collector function. | ||
|  | 
 | ||
|  |         Args: | ||
|  |             collector_function: tuple storing info about the wrapper function | ||
|  |                 (namespace, class instance, function type, function name, | ||
|  |                 extra) | ||
|  |             id_diff: constant to add to the id in the map | ||
|  | 
 | ||
|  |         Returns: | ||
|  |             the current wrapper id | ||
|  |         """
 | ||
|  |         if collector_function is not None: | ||
|  |             is_instantiated_class = isinstance(collector_function[1], instantiator.InstantiatedClass) | ||
|  | 
 | ||
|  |             if is_instantiated_class: | ||
|  |                 function_name = collector_function[0] + \ | ||
|  |                                 collector_function[1].name + '_' + collector_function[2] | ||
|  |             else: | ||
|  |                 function_name = collector_function[1].name | ||
|  | 
 | ||
|  |             self.wrapper_map[self.wrapper_id] = (collector_function[0], collector_function[1], collector_function[2], | ||
|  |                                                  function_name + '_' + str(self.wrapper_id + id_diff), | ||
|  |                                                  collector_function[3]) | ||
|  | 
 | ||
|  |         self.wrapper_id += 1 | ||
|  | 
 | ||
|  |         return self.wrapper_id - 1 | ||
|  | 
 | ||
|  |     def _qualified_name(self, names): | ||
|  |         return 'handle' if names == '' else names | ||
|  | 
 | ||
|  |     def _insert_spaces(self, x, y): | ||
|  |         """Insert spaces at the beginning of each line
 | ||
|  | 
 | ||
|  |         Args: | ||
|  |             x: the statement currently generated | ||
|  |             y: the addition to add to the statement | ||
|  |         """
 | ||
|  |         return x + '\n' + ('' if y == '' else '  ') + y | ||
|  | 
 | ||
|  |     def _is_ptr(self, arg_type): | ||
|  |         """Determine if the interface_parser.Type should be treated as a
 | ||
|  |         pointer in the wrapper. | ||
|  |         """
 | ||
|  |         return arg_type.is_ptr or (arg_type.typename.name not in self.not_ptr_type | ||
|  |                                    and arg_type.typename.name not in self.ignore_namespace | ||
|  |                                    and arg_type.typename.name != 'string') | ||
|  | 
 | ||
|  |     def _is_ref(self, arg_type): | ||
|  |         """Determine if the interface_parser.Type should be treated as a
 | ||
|  |         reference in the wrapper. | ||
|  |         """
 | ||
|  |         return arg_type.typename.name not in self.ignore_namespace and \ | ||
|  |                arg_type.typename.name not in self.not_ptr_type and \ | ||
|  |                arg_type.is_ref | ||
|  | 
 | ||
|  |     def _group_methods(self, methods): | ||
|  |         """Group overloaded methods together""" | ||
|  |         method_map = {} | ||
|  |         method_out = [] | ||
|  | 
 | ||
|  |         for method in methods: | ||
|  |             method_index = method_map.get(method.name) | ||
|  | 
 | ||
|  |             if method_index is None: | ||
|  |                 method_map[method.name] = len(method_out) | ||
|  |                 method_out.append([method]) | ||
|  |             else: | ||
|  |                 self._debug("[_group_methods] Merging {} with {}".format(method_index, method.name)) | ||
|  |                 method_out[method_index].append(method) | ||
|  | 
 | ||
|  |         return method_out | ||
|  | 
 | ||
|  |     def _clean_class_name(self, instantiated_class): | ||
|  |         """Reformatted the C++ class name to fit Matlab defined naming
 | ||
|  |         standards | ||
|  |         """
 | ||
|  |         if len(instantiated_class.ctors) != 0: | ||
|  |             return instantiated_class.ctors[0].name | ||
|  | 
 | ||
|  |         return instantiated_class.name | ||
|  | 
 | ||
|  |     @classmethod | ||
|  |     def _format_type_name(cls, type_name, separator='::', include_namespace=True, constructor=False, method=False): | ||
|  |         """
 | ||
|  |         Args: | ||
|  |             type_name: an interface_parser.Typename to reformat | ||
|  |             separator: the statement to add between namespaces and typename | ||
|  |             include_namespace: whether to include namespaces when reformatting | ||
|  |             constructor: if the typename will be in a constructor | ||
|  |             method: if the typename will be in a method | ||
|  | 
 | ||
|  |         Raises: | ||
|  |             constructor and method cannot both be true | ||
|  |         """
 | ||
|  |         if constructor and method: | ||
|  |             raise Exception('Constructor and method parameters cannot both be True') | ||
|  | 
 | ||
|  |         formatted_type_name = '' | ||
|  |         name = type_name.name | ||
|  | 
 | ||
|  |         if include_namespace: | ||
|  |             for namespace in type_name.namespaces: | ||
|  |                 if name not in cls.ignore_namespace and namespace != '': | ||
|  |                     formatted_type_name += namespace + separator | ||
|  | 
 | ||
|  |             #self._debug("formatted_ns: {}, ns: {}".format(formatted_type_name, type_name.namespaces)) | ||
|  |         if constructor: | ||
|  |             formatted_type_name += cls.data_type.get(name) or name | ||
|  |         elif method: | ||
|  |             formatted_type_name += cls.data_type_param.get(name) or name | ||
|  |         else: | ||
|  |             formatted_type_name += name | ||
|  | 
 | ||
|  |         if separator == "::":  # C++ | ||
|  |             templates = [] | ||
|  |             for idx in range(len(type_name.instantiations)): | ||
|  |                 template = '{}'.format(cls._format_type_name(type_name.instantiations[idx], | ||
|  |                                                              include_namespace=include_namespace, | ||
|  |                                                              constructor=constructor, method=method)) | ||
|  |                 templates.append(template) | ||
|  | 
 | ||
|  |             if len(templates) > 0:  # If there are no templates | ||
|  |                 formatted_type_name += '<{}>'.format(','.join(templates)) | ||
|  | 
 | ||
|  |         else: | ||
|  |             for idx in range(len(type_name.instantiations)): | ||
|  |                 formatted_type_name += '{}'.format(cls._format_type_name( | ||
|  |                     type_name.instantiations[idx], | ||
|  |                     separator=separator, | ||
|  |                     include_namespace=False, | ||
|  |                     constructor=constructor, method=method | ||
|  |                 )) | ||
|  | 
 | ||
|  |         return formatted_type_name | ||
|  | 
 | ||
|  |     @classmethod | ||
|  |     def _format_return_type(cls, return_type, include_namespace=False, separator="::"): | ||
|  |         """Format return_type.
 | ||
|  | 
 | ||
|  |         Args: | ||
|  |             return_type: an interface_parser.ReturnType to reformat | ||
|  |             include_namespace: whether to include namespaces when reformatting | ||
|  |         """
 | ||
|  |         return_wrap = '' | ||
|  | 
 | ||
|  |         if cls._return_count(return_type) == 1: | ||
|  |             return_wrap = cls._format_type_name( | ||
|  |                 return_type.type1.typename, | ||
|  |                 separator=separator, | ||
|  |                 include_namespace=include_namespace | ||
|  |             ) | ||
|  |         else: | ||
|  |             return_wrap = 'pair< {type1}, {type2} >'.format( | ||
|  |                 type1=cls._format_type_name( | ||
|  |                     return_type.type1.typename, separator=separator, include_namespace=include_namespace | ||
|  |                 ), | ||
|  |                 type2=cls._format_type_name( | ||
|  |                     return_type.type2.typename, separator=separator, include_namespace=include_namespace | ||
|  |                 )) | ||
|  | 
 | ||
|  |         return return_wrap | ||
|  | 
 | ||
|  |     def _format_class_name(self, instantiated_class, separator=''): | ||
|  |         """Format a template_instantiator.InstantiatedClass name.""" | ||
|  |         if instantiated_class.parent == '': | ||
|  |             parent_full_ns = [''] | ||
|  |         else: | ||
|  |             parent_full_ns = instantiated_class.parent.full_namespaces() | ||
|  |         # class_name = instantiated_class.parent.name | ||
|  |         # | ||
|  |         # if class_name != '': | ||
|  |         #     class_name += separator | ||
|  |         # | ||
|  |         # class_name += instantiated_class.name | ||
|  |         parentname = "".join([separator + x for x in parent_full_ns]) + separator | ||
|  | 
 | ||
|  |         class_name = parentname[2 * len(separator):] | ||
|  | 
 | ||
|  |         class_name += instantiated_class.name | ||
|  | 
 | ||
|  |         return class_name | ||
|  | 
 | ||
|  |     def _format_static_method(self, static_method, separator=''): | ||
|  |         """Example:
 | ||
|  | 
 | ||
|  |                 gtsamPoint3.staticFunction | ||
|  |         """
 | ||
|  |         method = '' | ||
|  | 
 | ||
|  |         if isinstance(static_method, parser.StaticMethod): | ||
|  |             method += "".join([separator + x for x in static_method.parent.namespaces()]) + \ | ||
|  |                       separator + static_method.parent.name + separator | ||
|  | 
 | ||
|  |         return method[2 * len(separator):] | ||
|  | 
 | ||
|  |     def _format_instance_method(self, instance_method, separator=''): | ||
|  |         """Example:
 | ||
|  | 
 | ||
|  |                 gtsamPoint3.staticFunction | ||
|  |         """
 | ||
|  |         method = '' | ||
|  | 
 | ||
|  |         if isinstance(instance_method, instantiator.InstantiatedMethod): | ||
|  |             method += "".join([separator + x for x in instance_method.parent.parent.full_namespaces()]) + \ | ||
|  |                       separator | ||
|  | 
 | ||
|  |             method += instance_method.parent.name + separator + \ | ||
|  |                       instance_method.original.name + "<" + instance_method.instantiation.to_cpp() + ">" | ||
|  |         return method[2 * len(separator):] | ||
|  | 
 | ||
|  |     def _format_global_method(self, static_method, separator=''): | ||
|  |         """Example:
 | ||
|  | 
 | ||
|  |                 gtsamPoint3.staticFunction | ||
|  |         """
 | ||
|  |         method = '' | ||
|  | 
 | ||
|  |         if isinstance(static_method, parser.GlobalFunction): | ||
|  |             method += "".join([separator + x for x in static_method.parent.full_namespaces()]) + \ | ||
|  |                       separator | ||
|  | 
 | ||
|  |         return method[2 * len(separator):] | ||
|  | 
 | ||
|  |     def _wrap_args(self, args): | ||
|  |         """Wrap an interface_parser.ArgumentList into a list of arguments.
 | ||
|  | 
 | ||
|  |         Returns: | ||
|  |             A string representation of the arguments. For example: | ||
|  |                 'int x, double y' | ||
|  |         """
 | ||
|  |         arg_wrap = '' | ||
|  | 
 | ||
|  |         for i, arg in enumerate(args.args_list, 1): | ||
|  |             c_type = self._format_type_name(arg.ctype.typename, include_namespace=False) | ||
|  | 
 | ||
|  |             arg_wrap += '{c_type} {arg_name}{comma}'.format(c_type=c_type, | ||
|  |                                                             arg_name=arg.name, | ||
|  |                                                             comma='' if i == len(args.args_list) else ', ') | ||
|  | 
 | ||
|  |         return arg_wrap | ||
|  | 
 | ||
|  |     def _wrap_variable_arguments(self, args, wrap_datatypes=True): | ||
|  |         """ Wrap an interface_parser.ArgumentList into a statement of argument
 | ||
|  |         checks. | ||
|  | 
 | ||
|  |         Returns: | ||
|  |             A string representation of a variable arguments for an if | ||
|  |             statement. For example: | ||
|  |                 ' && isa(varargin{1},'double') && isa(varargin{2},'numeric')' | ||
|  |         """
 | ||
|  |         var_arg_wrap = '' | ||
|  | 
 | ||
|  |         for i, arg in enumerate(args.args_list, 1): | ||
|  |             name = arg.ctype.typename.name | ||
|  |             if name in self.not_check_type: | ||
|  |                 continue | ||
|  | 
 | ||
|  |             check_type = self.data_type_param.get(name) | ||
|  | 
 | ||
|  |             if self.data_type.get(check_type): | ||
|  |                 check_type = self.data_type[check_type] | ||
|  | 
 | ||
|  |             if check_type is None: | ||
|  |                 check_type = self._format_type_name(arg.ctype.typename, separator='.', constructor=not wrap_datatypes) | ||
|  | 
 | ||
|  |             var_arg_wrap += " && isa(varargin{{{num}}},'{data_type}')".format(num=i, data_type=check_type) | ||
|  |             if name == 'Vector': | ||
|  |                 var_arg_wrap += ' && size(varargin{{{num}}},2)==1'.format(num=i) | ||
|  |             if name == 'Point2': | ||
|  |                 var_arg_wrap += ' && size(varargin{{{num}}},1)==2'.format(num=i) | ||
|  |                 var_arg_wrap += ' && size(varargin{{{num}}},2)==1'.format(num=i) | ||
|  |             if name == 'Point3': | ||
|  |                 var_arg_wrap += ' && size(varargin{{{num}}},1)==3'.format(num=i) | ||
|  |                 var_arg_wrap += ' && size(varargin{{{num}}},2)==1'.format(num=i) | ||
|  | 
 | ||
|  |         return var_arg_wrap | ||
|  | 
 | ||
|  |     def _wrap_list_variable_arguments(self, args): | ||
|  |         """ Wrap an interface_parser.ArgumentList into a list of argument
 | ||
|  |         variables. | ||
|  | 
 | ||
|  |         Returns: | ||
|  |             A string representation of a list of variable arguments. | ||
|  |             For example: | ||
|  |                 'varargin{1}, varargin{2}, varargin{3}' | ||
|  |         """
 | ||
|  |         var_list_wrap = '' | ||
|  |         first = True | ||
|  | 
 | ||
|  |         for i in range(1, len(args.args_list) + 1): | ||
|  |             if first: | ||
|  |                 var_list_wrap += 'varargin{{{}}}'.format(i) | ||
|  |                 first = False | ||
|  |             else: | ||
|  |                 var_list_wrap += ', varargin{{{}}}'.format(i) | ||
|  | 
 | ||
|  |         return var_list_wrap | ||
|  | 
 | ||
|  |     def _wrap_method_check_statement(self, args): | ||
|  |         """Wrap the given arguments into either just a varargout call or a
 | ||
|  |         call in an if statement that checks if the parameters are accurate. | ||
|  |         """
 | ||
|  |         check_statement = '' | ||
|  |         arg_id = 1 | ||
|  | 
 | ||
|  |         if check_statement == '': | ||
|  |             check_statement = \ | ||
|  |                 'if length(varargin) == {param_count}'.format( | ||
|  |                     param_count=len(args.args_list)) | ||
|  | 
 | ||
|  |         for i, arg in enumerate(args.args_list): | ||
|  |             name = arg.ctype.typename.name | ||
|  | 
 | ||
|  |             if name in self.not_check_type: | ||
|  |                 arg_id += 1 | ||
|  |                 continue | ||
|  | 
 | ||
|  |             check_type = self.data_type_param.get(name) | ||
|  | 
 | ||
|  |             if self.data_type.get(check_type): | ||
|  |                 check_type = self.data_type[check_type] | ||
|  | 
 | ||
|  |             if check_type is None: | ||
|  |                 check_type = self._format_type_name(arg.ctype.typename, separator='.') | ||
|  | 
 | ||
|  |             check_statement += " && isa(varargin{{{id}}},'{ctype}')".format(id=arg_id, ctype=check_type) | ||
|  | 
 | ||
|  |             if name == 'Vector': | ||
|  |                 check_statement += ' && size(varargin{{{num}}},2)==1'.format(num=arg_id) | ||
|  |             if name == 'Point2': | ||
|  |                 check_statement += ' && size(varargin{{{num}}},1)==2'.format(num=arg_id) | ||
|  |                 check_statement += ' && size(varargin{{{num}}},2)==1'.format(num=arg_id) | ||
|  |             if name == 'Point3': | ||
|  |                 check_statement += ' && size(varargin{{{num}}},1)==3'.format(num=arg_id) | ||
|  |                 check_statement += ' && size(varargin{{{num}}},2)==1'.format(num=arg_id) | ||
|  | 
 | ||
|  |             arg_id += 1 | ||
|  | 
 | ||
|  |         check_statement = check_statement \ | ||
|  |             if check_statement == '' \ | ||
|  |             else check_statement + '\n' | ||
|  | 
 | ||
|  |         return check_statement | ||
|  | 
 | ||
|  |     def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False): | ||
|  |         """Format the interface_parser.Arguments.
 | ||
|  | 
 | ||
|  |         Examples: | ||
|  |             ((a), unsigned char a = unwrap< unsigned char >(in[1]);), | ||
|  |             ((a), Test& t = *unwrap_shared_ptr< Test >(in[1], "ptr_Test");), | ||
|  |             ((a), std::shared_ptr<Test> p1 = unwrap_shared_ptr< Test >(in[1], "ptr_Test");) | ||
|  |         """
 | ||
|  |         params = '' | ||
|  |         body_args = '' | ||
|  | 
 | ||
|  |         for arg in args.args_list: | ||
|  |             if params != '': | ||
|  |                 params += ',' | ||
|  | 
 | ||
|  |             # import sys | ||
|  |             # print("type: {}, is_ref: {}, is_ref: {}".format(arg.ctype, self._is_ref(arg.ctype), arg.ctype.is_ref), file=sys.stderr) | ||
|  |             if self._is_ref(arg.ctype):  # and not constructor: | ||
|  |                 ctype_camel = self._format_type_name(arg.ctype.typename, separator='') | ||
|  |                 body_args += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                    {ctype}& {name} = *unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}"); | ||
|  |                 '''.format(ctype=self._format_type_name(arg.ctype.typename),
 | ||
|  |                            ctype_camel=ctype_camel, | ||
|  |                            name=arg.name, | ||
|  |                            id=arg_id)), | ||
|  |                                              prefix='  ') | ||
|  |             elif self._is_ptr(arg.ctype) and \ | ||
|  |                     arg.ctype.typename.name not in self.ignore_namespace: | ||
|  |                 call_type = arg.ctype.is_ptr | ||
|  |                 body_args += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                     {std_boost}::shared_ptr<{ctype_sep}> {name} = unwrap_shared_ptr< {ctype_sep} >(in[{id}], "ptr_{ctype}"); | ||
|  |                 '''.format(std_boost='boost' if constructor else 'boost',
 | ||
|  |                            ctype_sep=self._format_type_name(arg.ctype.typename), | ||
|  |                            ctype=self._format_type_name(arg.ctype.typename, separator=''), | ||
|  |                            name=arg.name, | ||
|  |                            id=arg_id)), | ||
|  |                                              prefix='  ') | ||
|  |                 if call_type == "": | ||
|  |                     params += "*" | ||
|  |             else: | ||
|  |                 body_args += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                     {ctype} {name} = unwrap< {ctype} >(in[{id}]); | ||
|  |                 '''.format(ctype=arg.ctype.typename.name, name=arg.name, id=arg_id)),
 | ||
|  |                                              prefix='  ') | ||
|  | 
 | ||
|  |             params += arg.name | ||
|  | 
 | ||
|  |             arg_id += 1 | ||
|  | 
 | ||
|  |         return params, body_args | ||
|  | 
 | ||
|  |     @staticmethod | ||
|  |     def _return_count(return_type): | ||
|  |         """The amount of objects returned by the given
 | ||
|  |         interface_parser.ReturnType. | ||
|  |         """
 | ||
|  |         return 1 if return_type.type2 == '' else 2 | ||
|  | 
 | ||
|  |     def _wrapper_name(self): | ||
|  |         """Determine the name of wrapper function.""" | ||
|  |         return self.module_name + '_wrapper' | ||
|  | 
 | ||
|  |     def class_serialize_comment(self, class_name, static_methods): | ||
|  |         """Generate comments for serialize methods.""" | ||
|  |         comment_wrap = '' | ||
|  |         static_methods = sorted(static_methods, key=lambda name: name.name) | ||
|  | 
 | ||
|  |         for static_method in static_methods: | ||
|  |             if comment_wrap == '': | ||
|  |                 comment_wrap = '%-------Static Methods-------\n' | ||
|  | 
 | ||
|  |             comment_wrap += '%{name}({args}) : returns {return_type}\n'.format(name=static_method.name, | ||
|  |                                                                                args=self._wrap_args(static_method.args), | ||
|  |                                                                                return_type=self._format_return_type( | ||
|  |                                                                                    static_method.return_type, | ||
|  |                                                                                    include_namespace=True)) | ||
|  | 
 | ||
|  |         comment_wrap += textwrap.dedent('''\
 | ||
|  |             % | ||
|  |             %-------Serialization Interface------- | ||
|  |             %string_serialize() : returns string | ||
|  |             %string_deserialize(string serialized) : returns {class_name} | ||
|  |             % | ||
|  |         ''').format(class_name=class_name)
 | ||
|  | 
 | ||
|  |         return comment_wrap | ||
|  | 
 | ||
|  |     def class_comment(self, instantiated_class): | ||
|  |         """Generate comments for the given class in Matlab.
 | ||
|  | 
 | ||
|  |         Args | ||
|  |             instantiated_class: the class being wrapped | ||
|  |             ctors: a list of the constructors in the class | ||
|  |             methods: a list of the methods in the class | ||
|  |         """
 | ||
|  |         class_name = instantiated_class.name | ||
|  |         ctors = instantiated_class.ctors | ||
|  |         methods = instantiated_class.methods | ||
|  |         static_methods = instantiated_class.static_methods | ||
|  | 
 | ||
|  |         comment = textwrap.dedent('''\
 | ||
|  |             %class {class_name}, see Doxygen page for details | ||
|  |             %at https://gtsam.org/doxygen/ | ||
|  |         ''').format(class_name=class_name)
 | ||
|  | 
 | ||
|  |         if len(ctors) != 0: | ||
|  |             comment += '%\n%-------Constructors-------\n' | ||
|  | 
 | ||
|  |         # Write constructors | ||
|  |         for ctor in ctors: | ||
|  |             comment += '%{ctor_name}({args})\n'.format(ctor_name=ctor.name, args=self._wrap_args(ctor.args)) | ||
|  | 
 | ||
|  |         if len(methods) != 0: | ||
|  |             comment += '%\n' \ | ||
|  |                        '%-------Methods-------\n' | ||
|  | 
 | ||
|  |         methods = sorted(methods, key=lambda name: name.name) | ||
|  | 
 | ||
|  |         # Write methods | ||
|  |         for method in methods: | ||
|  |             if method.name in self.whitelist: | ||
|  |                 continue | ||
|  | 
 | ||
|  |             comment += '%{name}({args})'.format(name=method.name, args=self._wrap_args(method.args)) | ||
|  | 
 | ||
|  |             if method.return_type.type2 == '': | ||
|  |                 return_type = self._format_type_name(method.return_type.type1.typename) | ||
|  |             else: | ||
|  |                 return_type = 'pair< {type1}, {type2} >'.format( | ||
|  |                     type1=self._format_type_name(method.return_type.type1.typename), | ||
|  |                     type2=self._format_type_name(method.return_type.type2.typename)) | ||
|  | 
 | ||
|  |             comment += ' : returns {return_type}\n'.format(return_type=return_type) | ||
|  | 
 | ||
|  |         comment += '%\n' | ||
|  | 
 | ||
|  |         if len(static_methods) != 0: | ||
|  |             comment += self.class_serialize_comment(class_name, static_methods) | ||
|  | 
 | ||
|  |         return comment | ||
|  | 
 | ||
|  |     def generate_matlab_wrapper(self): | ||
|  |         """Generate the C++ file for the wrapper.""" | ||
|  |         file_name = self._wrapper_name() + '.cpp' | ||
|  | 
 | ||
|  |         wrapper_file = textwrap.dedent('''\
 | ||
|  |             # include <wrap/matlab.h> | ||
|  |             # include <map> | ||
|  |         ''')
 | ||
|  | 
 | ||
|  |         return file_name, wrapper_file | ||
|  | 
 | ||
|  |     def wrap_method(self, methods): | ||
|  |         """Wrap methods in the body of a class.""" | ||
|  |         if not isinstance(methods, list): | ||
|  |             methods = [methods] | ||
|  | 
 | ||
|  |         for method in methods: | ||
|  |             output = '' | ||
|  | 
 | ||
|  |         return '' | ||
|  | 
 | ||
|  |     def wrap_methods(self, methods, globals=False, global_ns=None): | ||
|  |         """Wrap a sequence of methods. Groups methods with the same names
 | ||
|  |         together. If globals is True then output every method into its own | ||
|  |         file. | ||
|  |         """
 | ||
|  |         output = '' | ||
|  |         methods = self._group_methods(methods) | ||
|  | 
 | ||
|  |         for method in methods: | ||
|  |             if globals: | ||
|  |                 self._debug("[wrap_methods] wrapping: {}..{}={}".format(method[0].parent.name, method[0].name, | ||
|  |                                                                         type(method[0].parent.name))) | ||
|  | 
 | ||
|  |                 method_text = self.wrap_global_function(method) | ||
|  |                 self.content.append( | ||
|  |                     ("".join(['+' + x + '/' | ||
|  |                               for x in global_ns.full_namespaces()[1:]])[:-1], [(method[0].name + '.m', method_text)])) | ||
|  |             else: | ||
|  |                 method_text = self.wrap_method(method) | ||
|  |                 output += '' | ||
|  | 
 | ||
|  |         return output | ||
|  | 
 | ||
|  |     def wrap_global_function(self, function): | ||
|  |         """Wrap the given global function.""" | ||
|  |         if not isinstance(function, list): | ||
|  |             function = [function] | ||
|  | 
 | ||
|  |         function_name = function[0].name | ||
|  | 
 | ||
|  |         # Get all combinations of parameters | ||
|  |         param_wrap = '' | ||
|  | 
 | ||
|  |         for i, overload in enumerate(function): | ||
|  |             param_wrap += '      if' if i == 0 else '      elseif' | ||
|  |             param_wrap += ' length(varargin) == ' | ||
|  | 
 | ||
|  |             if len(overload.args.args_list) == 0: | ||
|  |                 param_wrap += '0\n' | ||
|  |             else: | ||
|  |                 param_wrap += str(len(overload.args.args_list)) \ | ||
|  |                               + self._wrap_variable_arguments(overload.args, False) + '\n' | ||
|  | 
 | ||
|  |             # Determine format of return and varargout statements | ||
|  |             return_type_formatted = self._format_return_type( | ||
|  |                 overload.return_type, | ||
|  |                 include_namespace=True, | ||
|  |                 separator="." | ||
|  |             ) | ||
|  |             varargout = self._format_varargout(overload.return_type, return_type_formatted) | ||
|  | 
 | ||
|  |             param_wrap += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                 {varargout}{module_name}_wrapper({num}, varargin{{:}}); | ||
|  |             ''').format(varargout=varargout, module_name=self.module_name,
 | ||
|  |                         num=self._update_wrapper_id(collector_function=(function[0].parent.name, function[i], | ||
|  |                                                                         'global_function', None))), | ||
|  |                                           prefix='        ') | ||
|  | 
 | ||
|  |         param_wrap += textwrap.indent(textwrap.dedent('''\
 | ||
|  |             else | ||
|  |               error('Arguments do not match any overload of function {func_name}'); | ||
|  |         ''').format(func_name=function_name),
 | ||
|  |                                       prefix='      ') | ||
|  | 
 | ||
|  |         global_function = textwrap.indent(textwrap.dedent('''\
 | ||
|  |             function varargout = {m_method}(varargin) | ||
|  |             {statements}      end | ||
|  |         ''').format(m_method=function_name, statements=param_wrap),
 | ||
|  |                                           prefix='') | ||
|  | 
 | ||
|  |         return global_function | ||
|  | 
 | ||
|  |     def wrap_class_constructors(self, namespace_name, inst_class, parent_name, ctors, is_virtual): | ||
|  |         """Wrap class constructor.
 | ||
|  | 
 | ||
|  |         Args: | ||
|  |             namespace_name: the name of the namespace ('' if it does not exist) | ||
|  |             inst_class: instance of the class | ||
|  |             parent_name: the name of the parent class if it exists | ||
|  |             ctors: the interface_parser.Constructor in the class | ||
|  |             is_virtual: whether the class is part of a virtual inheritance | ||
|  |                 chain | ||
|  |         """
 | ||
|  |         has_parent = parent_name != '' | ||
|  |         class_name = inst_class.name | ||
|  |         if has_parent: | ||
|  |             parent_name = self._format_type_name(parent_name, separator=".") | ||
|  |         if type(ctors) != list: | ||
|  |             ctors = [ctors] | ||
|  | 
 | ||
|  |         # import sys | ||
|  |         # if class_name: | ||
|  |         #     print("[Constructor] class: {} ns: {}" | ||
|  |         #           .format(class_name, | ||
|  |         #                   inst_class.namespaces()) | ||
|  |         #           , file=sys.stderr) | ||
|  |         # full_name = "".join(obj_type.full_namespaces()) + obj_type.name | ||
|  | 
 | ||
|  |         methods_wrap = textwrap.indent(textwrap.dedent("""\
 | ||
|  |             methods | ||
|  |               function obj = {class_name}(varargin) | ||
|  |         """).format(class_name=class_name),
 | ||
|  |                                        prefix='') | ||
|  | 
 | ||
|  |         if is_virtual: | ||
|  |             methods_wrap += "    if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void')))" | ||
|  |         else: | ||
|  |             methods_wrap += '    if nargin == 2' | ||
|  | 
 | ||
|  |         methods_wrap += " && isa(varargin{1}, 'uint64') && varargin{1} == uint64(5139824614673773682)\n" | ||
|  | 
 | ||
|  |         if is_virtual: | ||
|  |             methods_wrap += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                 if nargin == 2 | ||
|  |                   my_ptr = varargin{{2}}; | ||
|  |                 else | ||
|  |                   my_ptr = {wrapper_name}({id}, varargin{{2}}); | ||
|  |                 end | ||
|  |             ''').format(wrapper_name=self._wrapper_name(), id=self._update_wrapper_id() + 1),
 | ||
|  |                                             prefix='      ') | ||
|  |         else: | ||
|  |             methods_wrap += '      my_ptr = varargin{2};\n' | ||
|  | 
 | ||
|  |         collector_base_id = self._update_wrapper_id((namespace_name, inst_class, 'collectorInsertAndMakeBase', None), | ||
|  |                                                     id_diff=-1 if is_virtual else 0) | ||
|  | 
 | ||
|  |         methods_wrap += '      {ptr}{wrapper_name}({id}, my_ptr);\n' \ | ||
|  |             .format( | ||
|  |             ptr='base_ptr = ' if has_parent else '', | ||
|  |             wrapper_name=self._wrapper_name(), | ||
|  |             id=collector_base_id - (1 if is_virtual else 0)) | ||
|  | 
 | ||
|  |         for ctor in ctors: | ||
|  |             wrapper_return = '[ my_ptr, base_ptr ] = ' \ | ||
|  |                 if has_parent \ | ||
|  |                 else 'my_ptr = ' | ||
|  | 
 | ||
|  |             methods_wrap += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                 elseif nargin == {len}{varargin} | ||
|  |                   {ptr}{wrapper}({num}{comma}{var_arg}); | ||
|  |             ''').format(len=len(ctor.args.args_list),
 | ||
|  |                         varargin=self._wrap_variable_arguments(ctor.args, False), | ||
|  |                         ptr=wrapper_return, | ||
|  |                         wrapper=self._wrapper_name(), | ||
|  |                         num=self._update_wrapper_id((namespace_name, inst_class, 'constructor', ctor)), | ||
|  |                         comma='' if len(ctor.args.args_list) == 0 else ', ', | ||
|  |                         var_arg=self._wrap_list_variable_arguments(ctor.args)), | ||
|  |                                             prefix='    ') | ||
|  | 
 | ||
|  |         base_obj = '' | ||
|  | 
 | ||
|  |         if has_parent: | ||
|  |             self._debug("class: {} ns: {}".format(parent_name, self._format_class_name(inst_class.parent, | ||
|  |                                                                                        separator="."))) | ||
|  | 
 | ||
|  |         if has_parent: | ||
|  |             base_obj = '  obj = obj@{parent_name}(uint64(5139824614673773682), base_ptr);'.format( | ||
|  |                 parent_name=parent_name) | ||
|  | 
 | ||
|  |         if base_obj: | ||
|  |             base_obj = '\n' + base_obj | ||
|  | 
 | ||
|  |         self._debug("class: {}, name: {}".format(inst_class.name, self._format_class_name(inst_class, separator="."))) | ||
|  |         methods_wrap += textwrap.indent(textwrap.dedent('''\
 | ||
|  |               else | ||
|  |                 error('Arguments do not match any overload of {class_name_doc} constructor'); | ||
|  |               end{base_obj} | ||
|  |               obj.ptr_{class_name} = my_ptr; | ||
|  |             end\n | ||
|  |         ''').format(namespace=namespace_name,
 | ||
|  |                     d='' if namespace_name == '' else '.', | ||
|  |                     class_name_doc=self._format_class_name(inst_class, separator="."), | ||
|  |                     class_name=self._format_class_name(inst_class, separator=""), | ||
|  |                     base_obj=base_obj), | ||
|  |                                         prefix='  ') | ||
|  | 
 | ||
|  |         return methods_wrap | ||
|  | 
 | ||
|  |     def wrap_class_properties(self, class_name): | ||
|  |         """Generate properties of class.""" | ||
|  |         return textwrap.dedent('''\
 | ||
|  |             properties | ||
|  |               ptr_{} = 0 | ||
|  |             end | ||
|  |         ''').format(class_name)
 | ||
|  | 
 | ||
|  |     def wrap_class_deconstructor(self, namespace_name, inst_class): | ||
|  |         """Generate the delete function for the Matlab class.""" | ||
|  |         class_name = inst_class.name | ||
|  | 
 | ||
|  |         methods_text = textwrap.indent(textwrap.dedent("""\
 | ||
|  |             function delete(obj) | ||
|  |               {wrapper}({num}, obj.ptr_{class_name}); | ||
|  |             end\n | ||
|  |         """).format(num=self._update_wrapper_id((namespace_name, inst_class, 'deconstructor', None)),
 | ||
|  |                     wrapper=self._wrapper_name(), | ||
|  |                     class_name="".join(inst_class.parent.full_namespaces()) + class_name), | ||
|  |                                        prefix='  ') | ||
|  | 
 | ||
|  |         return methods_text | ||
|  | 
 | ||
|  |     def wrap_class_display(self): | ||
|  |         """Generate the display function for the Matlab class.""" | ||
|  |         return textwrap.indent(textwrap.dedent("""\
 | ||
|  |             function display(obj), obj.print(''); end | ||
|  |             %DISPLAY Calls print on the object | ||
|  |             function disp(obj), obj.display; end | ||
|  |             %DISP Calls print on the object | ||
|  |         """),
 | ||
|  |                                prefix='  ') | ||
|  | 
 | ||
|  |     def _group_class_methods(self, methods): | ||
|  |         """Group overloaded methods together""" | ||
|  |         method_map = {} | ||
|  |         method_out = [] | ||
|  | 
 | ||
|  |         for method in methods: | ||
|  |             method_index = method_map.get(method.name) | ||
|  | 
 | ||
|  |             if method_index is None: | ||
|  |                 method_map[method.name] = len(method_out) | ||
|  |                 method_out.append([method]) | ||
|  |             else: | ||
|  |                 # import sys | ||
|  |                 # print("[_group_methods] Merging {} with {}".format(method_index, method.name)) | ||
|  |                 method_out[method_index].append(method) | ||
|  | 
 | ||
|  |         return method_out | ||
|  | 
 | ||
|  |     @classmethod | ||
|  |     def _format_varargout(cls, return_type, return_type_formatted): | ||
|  |         """Determine format of return and varargout statements""" | ||
|  |         if cls._return_count(return_type) == 1: | ||
|  |             varargout = '' \ | ||
|  |                 if return_type_formatted == 'void' \ | ||
|  |                 else 'varargout{1} = ' | ||
|  |         else: | ||
|  |             varargout = '[ varargout{1} varargout{2} ] = ' | ||
|  | 
 | ||
|  |         return varargout | ||
|  | 
 | ||
|  |     def wrap_class_methods(self, namespace_name, inst_class, methods, serialize=[False]): | ||
|  |         """Wrap the methods in the class.
 | ||
|  | 
 | ||
|  |         Args: | ||
|  |             namespace_name: the name of the class's namespace | ||
|  |             inst_class: the instantiated class whose methods to wrap | ||
|  |             methods: the methods to wrap in the order to wrap them | ||
|  |             serialize: mutable param storing if one of the methods is serialize | ||
|  |         """
 | ||
|  |         method_text = '' | ||
|  | 
 | ||
|  |         methods = self._group_class_methods(methods) | ||
|  | 
 | ||
|  |         for method in methods: | ||
|  |             method_name = method[0].name | ||
|  |             if method_name in self.whitelist and method_name != 'serialize': | ||
|  |                 continue | ||
|  | 
 | ||
|  |             if method_name == 'serialize': | ||
|  |                 serialize[0] = True | ||
|  |                 method_text += self.wrap_class_serialize_method(namespace_name, inst_class) | ||
|  |             else: | ||
|  |                 # Generate method code | ||
|  |                 method_text += textwrap.indent(textwrap.dedent("""\
 | ||
|  |                     function varargout = {method_name}(this, varargin) | ||
|  |                     """).format(caps_name=method_name.upper(), method_name=method_name),
 | ||
|  |                                                prefix='') | ||
|  | 
 | ||
|  |                 for overload in method: | ||
|  |                     method_text += textwrap.indent(textwrap.dedent("""\
 | ||
|  |                     % {caps_name} usage: {method_name}(""").format(caps_name=method_name.upper(),
 | ||
|  |                                                                    method_name=method_name), | ||
|  |                                                    prefix='  ') | ||
|  | 
 | ||
|  |                     # Determine format of return and varargout statements | ||
|  |                     return_type_formatted = self._format_return_type( | ||
|  |                         overload.return_type, | ||
|  |                         include_namespace=True, | ||
|  |                         separator="." | ||
|  |                     ) | ||
|  |                     varargout = self._format_varargout(overload.return_type, return_type_formatted) | ||
|  | 
 | ||
|  |                     check_statement = self._wrap_method_check_statement(overload.args) | ||
|  |                     class_name = namespace_name + ('' if namespace_name == '' else '.') + inst_class.name | ||
|  | 
 | ||
|  |                     end_statement = '' \ | ||
|  |                         if check_statement == '' \ | ||
|  |                         else textwrap.indent(textwrap.dedent("""\
 | ||
|  |                               return | ||
|  |                             end | ||
|  |                         """).format(
 | ||
|  |                         class_name=class_name, | ||
|  |                         method_name=overload.original.name), prefix='  ') | ||
|  | 
 | ||
|  |                     method_text += textwrap.dedent("""\
 | ||
|  |                         {method_args}) : returns {return_type} | ||
|  |                           % Doxygen can be found at https://gtsam.org/doxygen/ | ||
|  |                           {check_statement}{spacing}{varargout}{wrapper}({num}, this, varargin{{:}}); | ||
|  |                         {end_statement}""").format(method_args=self._wrap_args(overload.args),
 | ||
|  |                                                    return_type=return_type_formatted, | ||
|  |                                                    num=self._update_wrapper_id( | ||
|  |                                                        (namespace_name, inst_class, overload.original.name, overload)), | ||
|  |                                                    check_statement=check_statement, | ||
|  |                                                    spacing='' if check_statement == '' else '    ', | ||
|  |                                                    varargout=varargout, | ||
|  |                                                    wrapper=self._wrapper_name(), | ||
|  |                                                    end_statement=end_statement) | ||
|  | 
 | ||
|  |                 final_statement = textwrap.indent(textwrap.dedent("""\
 | ||
|  |                     error('Arguments do not match any overload of function {class_name}.{method_name}'); | ||
|  |                 """.format(class_name=class_name, method_name=method_name)),
 | ||
|  |                                                   prefix='  ') | ||
|  |                 method_text += final_statement + 'end\n\n' | ||
|  | 
 | ||
|  |         return method_text | ||
|  | 
 | ||
|  |     def wrap_static_methods(self, namespace_name, instantiated_class, serialize): | ||
|  |         class_name = instantiated_class.name | ||
|  | 
 | ||
|  |         method_text = 'methods(Static = true)\n' | ||
|  |         static_methods = sorted(instantiated_class.static_methods, key=lambda name: name.name) | ||
|  | 
 | ||
|  |         static_methods = self._group_class_methods(static_methods) | ||
|  | 
 | ||
|  |         for static_method in static_methods: | ||
|  |             format_name = list(static_method[0].name) | ||
|  |             format_name[0] = format_name[0].upper() | ||
|  | 
 | ||
|  |             method_text += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                     function varargout = {name}(varargin) | ||
|  |                     '''.format(name=''.join(format_name))),
 | ||
|  |                                            prefix="  ") | ||
|  | 
 | ||
|  |             for static_overload in static_method: | ||
|  |                 check_statement = self._wrap_method_check_statement(static_overload.args) | ||
|  | 
 | ||
|  |                 end_statement = '' \ | ||
|  |                     if check_statement == '' \ | ||
|  |                     else textwrap.indent(textwrap.dedent("""
 | ||
|  |                               return | ||
|  |                             end | ||
|  |                             """), prefix='')
 | ||
|  |                 method_text += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                       % {name_caps} usage: {name_upper_case}({args}) : returns {return_type} | ||
|  |                       % Doxygen can be found at https://gtsam.org/doxygen/ | ||
|  |                       {check_statement}{spacing}varargout{{1}} = {wrapper}({id}, varargin{{:}});{end_statement} | ||
|  |                       ''').format(name=''.join(format_name),
 | ||
|  |                                   name_caps=static_overload.name.upper(), | ||
|  |                                   name_upper_case=static_overload.name, | ||
|  |                                   args=self._wrap_args(static_overload.args), | ||
|  |                                   return_type=self._format_return_type(static_overload.return_type, | ||
|  |                                                                        include_namespace=True, separator="."), | ||
|  |                                   length=len(static_overload.args.args_list), | ||
|  |                                   var_args_list=self._wrap_variable_arguments(static_overload.args), | ||
|  |                                   check_statement=check_statement, | ||
|  |                                   spacing='' if check_statement == '' else '  ', | ||
|  |                                   wrapper=self._wrapper_name(), | ||
|  |                                   id=self._update_wrapper_id( | ||
|  |                                       (namespace_name, instantiated_class, static_overload.name, static_overload)), | ||
|  |                                   class_name=instantiated_class.name, | ||
|  |                                   end_statement=end_statement), | ||
|  |                                                prefix='    ') | ||
|  | 
 | ||
|  |             method_text += textwrap.indent(textwrap.dedent("""\
 | ||
|  |                     error('Arguments do not match any overload of function {class_name}.{method_name}'); | ||
|  |                 """.format(class_name=class_name, method_name=static_overload.name)),
 | ||
|  |                                            prefix='    ') | ||
|  |             method_text += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                                     end\n | ||
|  |                 '''.format(name=''.join(format_name))),
 | ||
|  |                                            prefix="  ") | ||
|  | 
 | ||
|  |         if serialize: | ||
|  |             method_text += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                 function varargout = string_deserialize(varargin) | ||
|  |                   % STRING_DESERIALIZE usage: string_deserialize() : returns {class_name} | ||
|  |                   % Doxygen can be found at https://gtsam.org/doxygen/ | ||
|  |                   if length(varargin) == 1 | ||
|  |                     varargout{{1}} = {wrapper}({id}, varargin{{:}}); | ||
|  |                   else | ||
|  |                     error('Arguments do not match any overload of function {class_name}.string_deserialize'); | ||
|  |                   end | ||
|  |                 end\n | ||
|  |                 function obj = loadobj(sobj) | ||
|  |                   % LOADOBJ Saves the object to a matlab-readable format | ||
|  |                   obj = {class_name}.string_deserialize(sobj); | ||
|  |                 end | ||
|  |             ''').format(class_name=namespace_name + '.' + instantiated_class.name,
 | ||
|  |                         wrapper=self._wrapper_name(), | ||
|  |                         id=self._update_wrapper_id( | ||
|  |                             (namespace_name, instantiated_class, 'string_deserialize', 'deserialize'))), | ||
|  |                                            prefix='  ') | ||
|  | 
 | ||
|  |         return method_text | ||
|  | 
 | ||
|  |     def wrap_instantiated_class(self, instantiated_class, namespace_name=''): | ||
|  |         """Generate comments and code for given class.
 | ||
|  | 
 | ||
|  |         Args: | ||
|  |             instantiated_class: template_instantiator.InstantiatedClass | ||
|  |                 instance storing the class to wrap | ||
|  |             namespace_name: the name of the namespace if there is one | ||
|  |         """
 | ||
|  |         file_name = self._clean_class_name(instantiated_class) | ||
|  |         namespace_file_name = namespace_name + file_name | ||
|  | 
 | ||
|  |         if instantiated_class.cpp_class() in self.ignore_classes: | ||
|  |             return None | ||
|  | 
 | ||
|  |         # Class comment | ||
|  |         content_text = self.class_comment(instantiated_class) | ||
|  |         content_text += self.wrap_methods(instantiated_class.methods) | ||
|  | 
 | ||
|  |         # Class definition | ||
|  |         # import sys | ||
|  |         # if namespace_name: | ||
|  |         #     print("nsname: {}, file_name_: {}, filename: {}" | ||
|  |         #           .format(namespace_name, | ||
|  |         #                   self._clean_class_name(instantiated_class), file_name) | ||
|  |         #           , file=sys.stderr) | ||
|  |         content_text += 'classdef {class_name} < {parent}\n'.format( | ||
|  |             class_name=file_name, parent=str(self._qualified_name(instantiated_class.parent_class)).replace("::", ".")) | ||
|  | 
 | ||
|  |         # Class properties | ||
|  |         content_text += '  ' + reduce(self._insert_spaces, | ||
|  |                                       self.wrap_class_properties(namespace_file_name).splitlines()) + '\n' | ||
|  | 
 | ||
|  |         # Class constructor | ||
|  |         content_text += '  ' + reduce( | ||
|  |             self._insert_spaces, | ||
|  |             self.wrap_class_constructors( | ||
|  |                 namespace_name, | ||
|  |                 instantiated_class, | ||
|  |                 instantiated_class.parent_class, | ||
|  |                 instantiated_class.ctors, | ||
|  |                 instantiated_class.is_virtual, | ||
|  |             ).splitlines()) + '\n' | ||
|  | 
 | ||
|  |         # Delete function | ||
|  |         content_text += '  ' + reduce(self._insert_spaces, | ||
|  |                                       self.wrap_class_deconstructor(namespace_name, | ||
|  |                                                                     instantiated_class).splitlines()) + '\n' | ||
|  | 
 | ||
|  |         # Display function | ||
|  |         content_text += '  ' + reduce(self._insert_spaces, self.wrap_class_display().splitlines()) + '\n' | ||
|  | 
 | ||
|  |         # Class methods | ||
|  |         serialize = [False] | ||
|  | 
 | ||
|  |         if len(instantiated_class.methods) != 0: | ||
|  |             methods = sorted(instantiated_class.methods, key=lambda name: name.name) | ||
|  |             class_methods_wrapped = self.wrap_class_methods(namespace_name, | ||
|  |                                                             instantiated_class, | ||
|  |                                                             methods, | ||
|  |                                                             serialize=serialize).splitlines() | ||
|  |             if len(class_methods_wrapped) > 0: | ||
|  |                 content_text += '    ' + reduce(lambda x, y: x + '\n' + | ||
|  |                                                              ('' if y == '' else '    ') + y, | ||
|  |                                                 class_methods_wrapped) + '\n' | ||
|  | 
 | ||
|  |         # Static class methods | ||
|  |         content_text += '  end\n\n  ' + reduce( | ||
|  |             self._insert_spaces, | ||
|  |             self.wrap_static_methods(namespace_name, instantiated_class, serialize[0]).splitlines()) + '\n' | ||
|  | 
 | ||
|  |         content_text += textwrap.dedent('''\
 | ||
|  |               end | ||
|  |             end | ||
|  |         ''')
 | ||
|  | 
 | ||
|  |         return file_name + '.m', content_text | ||
|  | 
 | ||
|  |     def wrap_namespace(self, namespace, parent=[]): | ||
|  |         """Wrap a namespace by wrapping all of its components.
 | ||
|  | 
 | ||
|  |         Args: | ||
|  |             namespace: the interface_parser.namespace instance of the namespace | ||
|  |             parent: parent namespace | ||
|  |         """
 | ||
|  |         test_output = '' | ||
|  |         namespaces = namespace.full_namespaces() | ||
|  |         inner_namespace = namespace.name != '' | ||
|  |         wrapped = [] | ||
|  |         self._debug("wrapping ns: {}, parent: {}".format(namespace.full_namespaces(), parent)) | ||
|  | 
 | ||
|  |         matlab_wrapper = self.generate_matlab_wrapper() | ||
|  |         self.content.append((matlab_wrapper[0], matlab_wrapper[1])) | ||
|  | 
 | ||
|  |         current_scope = [] | ||
|  |         namespace_scope = [] | ||
|  | 
 | ||
|  |         for element in namespace.content: | ||
|  |             if isinstance(element, parser.Include): | ||
|  |                 self._add_include(element) | ||
|  |             elif isinstance(element, parser.Namespace): | ||
|  |                 self.wrap_namespace(element, namespaces) | ||
|  |             elif isinstance(element, instantiator.InstantiatedClass): | ||
|  |                 self._add_class(element) | ||
|  | 
 | ||
|  |                 if inner_namespace: | ||
|  |                     class_text = self.wrap_instantiated_class(element, "".join(namespace.full_namespaces())) | ||
|  | 
 | ||
|  |                     if not class_text is None: | ||
|  |                         namespace_scope.append(("".join(['+' + x + '/' for x in namespace.full_namespaces()[1:]])[:-1], | ||
|  |                                                 [(class_text[0], class_text[1])])) | ||
|  |                 else: | ||
|  |                     class_text = self.wrap_instantiated_class(element) | ||
|  |                     current_scope.append((class_text[0], class_text[1])) | ||
|  | 
 | ||
|  |         self.content.extend(current_scope) | ||
|  | 
 | ||
|  |         if inner_namespace: | ||
|  |             self.content.append(namespace_scope) | ||
|  | 
 | ||
|  |         # Global functions | ||
|  |         all_funcs = [func for func in namespace.content if isinstance(func, parser.GlobalFunction)] | ||
|  | 
 | ||
|  |         test_output += self.wrap_methods(all_funcs, True, global_ns=namespace) | ||
|  | 
 | ||
|  |         return wrapped | ||
|  | 
 | ||
|  |     def wrap_collector_function_shared_return(self, return_type_name, shared_obj, id, new_line=True): | ||
|  |         new_line = '\n' if new_line else '' | ||
|  | 
 | ||
|  |         return textwrap.indent(textwrap.dedent('''\
 | ||
|  |             {{ | ||
|  |             boost::shared_ptr<{name}> shared({shared_obj}); | ||
|  |             out[{id}] = wrap_shared_ptr(shared,"{name}"); | ||
|  |             }}{new_line}''').format(name=self._format_type_name(return_type_name, include_namespace=False),
 | ||
|  |                                     shared_obj=shared_obj, | ||
|  |                                     id=id, | ||
|  |                                     new_line=new_line), | ||
|  |                                prefix='  ') | ||
|  | 
 | ||
|  |     def wrap_collector_function_return_types(self, return_type, id): | ||
|  |         return_type_text = '  out[' + str(id) + '] = ' | ||
|  |         pair_value = 'first' if id == 0 else 'second' | ||
|  |         new_line = '\n' if id == 0 else '' | ||
|  | 
 | ||
|  |         if self._is_ptr(return_type): | ||
|  |             shared_obj = 'pairResult.' + pair_value | ||
|  | 
 | ||
|  |             if not return_type.is_ptr: | ||
|  |                 shared_obj = 'boost::make_shared<{name}>({shared_obj})' \ | ||
|  |                     .format( | ||
|  |                     name=self._format_type_name(return_type.typename), | ||
|  |                     formatted_name=self._format_type_name( | ||
|  |                         return_type.typename), | ||
|  |                     shared_obj='pairResult.' + pair_value) | ||
|  | 
 | ||
|  |             if return_type.typename.name in self.ignore_namespace: | ||
|  |                 return_type_text = self.wrap_collector_function_shared_return(return_type.typename, shared_obj, id, | ||
|  |                                                                               True if id == 0 else False) | ||
|  |             else: | ||
|  |                 return_type_text += 'wrap_shared_ptr({},"{}", false);{new_line}' \ | ||
|  |                     .format( | ||
|  |                     shared_obj, | ||
|  |                     self._format_type_name( | ||
|  |                         return_type.typename, separator='.'), | ||
|  |                     new_line=new_line) | ||
|  |         else: | ||
|  |             return_type_text += 'wrap< {} >(pairResult.{});{}'.format( | ||
|  |                 self._format_type_name(return_type.typename, separator='.'), pair_value, new_line) | ||
|  | 
 | ||
|  |         return return_type_text | ||
|  | 
 | ||
|  |     def wrap_collector_function_return(self, method): | ||
|  |         expanded = '' | ||
|  | 
 | ||
|  |         params = self._wrapper_unwrap_arguments(method.args, arg_id=1)[0] | ||
|  | 
 | ||
|  |         return_1 = method.return_type.type1 | ||
|  |         return_count = self._return_count(method.return_type) | ||
|  |         return_1_name = method.return_type.type1.typename.name | ||
|  |         obj_start = '' | ||
|  | 
 | ||
|  |         if isinstance(method, instantiator.InstantiatedMethod): | ||
|  |             # method_name = method.original.name | ||
|  |             method_name = method.to_cpp() | ||
|  |             obj_start = 'obj->' | ||
|  | 
 | ||
|  |             if method.instantiation: | ||
|  |                 # method_name += '<{}>'.format( | ||
|  |                 #     self._format_type_name(method.instantiation)) | ||
|  |                 # method_name = self._format_instance_method(method, '::') | ||
|  |                 method = method.to_cpp() | ||
|  |         elif isinstance(method, parser.GlobalFunction): | ||
|  |             method_name = self._format_global_method(method, '::') | ||
|  |             method_name += method.name | ||
|  |         else: | ||
|  |             if isinstance(method.parent, instantiator.InstantiatedClass): | ||
|  |                 method_name = method.parent.cpp_class() + "::" | ||
|  |             else: | ||
|  |                 method_name = self._format_static_method(method, '::') | ||
|  |             method_name += method.name | ||
|  | 
 | ||
|  |         if "MeasureRange" in method_name: | ||
|  |             self._debug("method: {}, method: {}, inst: {}".format(method_name, method.name, method.parent.cpp_class())) | ||
|  | 
 | ||
|  |         obj = '  ' if return_1_name == 'void' else '' | ||
|  |         obj += '{}{}({})'.format(obj_start, method_name, params) | ||
|  | 
 | ||
|  |         if return_1_name != 'void': | ||
|  |             if return_count == 1: | ||
|  |                 if self._is_ptr(return_1): | ||
|  |                     sep_method_name = partial(self._format_type_name, return_1.typename, include_namespace=True) | ||
|  | 
 | ||
|  |                     if return_1.typename.name in self.ignore_namespace: | ||
|  |                         expanded += self.wrap_collector_function_shared_return(return_1.typename, | ||
|  |                                                                                obj, | ||
|  |                                                                                0, | ||
|  |                                                                                new_line=False) | ||
|  | 
 | ||
|  |                     if return_1.is_ptr: | ||
|  |                         shared_obj = '{obj},"{method_name_sep}"'.format(obj=obj, method_name_sep=sep_method_name('.')) | ||
|  |                     else: | ||
|  |                         self._debug("Non-PTR: {}, {}".format(return_1, type(return_1))) | ||
|  |                         self._debug("Inner type is: {}, {}".format(return_1.typename.name, sep_method_name('.'))) | ||
|  |                         self._debug("Inner type instantiations: {}".format(return_1.typename.instantiations)) | ||
|  |                         method_name_sep_dot = sep_method_name('.') | ||
|  |                         shared_obj = 'boost::make_shared<{method_name_sep_col}>({obj}),"{method_name_sep_dot}"' \ | ||
|  |                             .format( | ||
|  |                             method_name=return_1.typename.name, | ||
|  |                             method_name_sep_col=sep_method_name(), | ||
|  |                             method_name_sep_dot=method_name_sep_dot, | ||
|  |                             obj=obj) | ||
|  | 
 | ||
|  |                     if return_1.typename.name not in self.ignore_namespace: | ||
|  |                         expanded += textwrap.indent('out[0] = wrap_shared_ptr({}, false);'.format(shared_obj), | ||
|  |                                                     prefix='  ') | ||
|  |                 else: | ||
|  |                     expanded += '  out[0] = wrap< {} >({});'.format(return_1.typename.name, obj) | ||
|  |             elif return_count == 2: | ||
|  |                 return_2 = method.return_type.type2 | ||
|  | 
 | ||
|  |                 expanded += '  auto pairResult = {};\n'.format(obj) | ||
|  |                 expanded += self.wrap_collector_function_return_types(return_1, 0) | ||
|  |                 expanded += self.wrap_collector_function_return_types(return_2, 1) | ||
|  |         else: | ||
|  |             expanded += obj + ';' | ||
|  | 
 | ||
|  |         return expanded | ||
|  | 
 | ||
|  |     def wrap_collector_function_upcast_from_void(self, class_name, id, cpp_name): | ||
|  |         return textwrap.dedent('''\
 | ||
|  |             void {class_name}_upcastFromVoid_{id}(int nargout, mxArray *out[], int nargin, const mxArray *in[]) {{ | ||
|  |               mexAtExit(&_deleteAllObjects); | ||
|  |               typedef boost::shared_ptr<{cpp_name}> Shared; | ||
|  |               boost::shared_ptr<void> *asVoid = *reinterpret_cast<boost::shared_ptr<void>**> (mxGetData(in[0])); | ||
|  |               out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); | ||
|  |               Shared *self = new Shared(boost::static_pointer_cast<{cpp_name}>(*asVoid)); | ||
|  |               *reinterpret_cast<Shared**>(mxGetData(out[0])) = self; | ||
|  |             }}\n | ||
|  |         ''').format(class_name=class_name, cpp_name=cpp_name, id=id)
 | ||
|  | 
 | ||
|  |     def generate_collector_function(self, id): | ||
|  |         collector_func = self.wrapper_map.get(id) | ||
|  | 
 | ||
|  |         if collector_func is None: | ||
|  |             return '' | ||
|  | 
 | ||
|  |         method_name = collector_func[3] | ||
|  |         # import sys | ||
|  |         # print("[Collector Gen] id: {}, obj: {}".format(id, method_name), file=sys.stderr) | ||
|  | 
 | ||
|  |         collector_function = 'void {}(int nargout, mxArray *out[], int nargin, const mxArray *in[])\n' \ | ||
|  |             .format(method_name) | ||
|  | 
 | ||
|  |         if isinstance(collector_func[1], instantiator.InstantiatedClass): | ||
|  |             body = '{\n' | ||
|  | 
 | ||
|  |             extra = collector_func[4] | ||
|  | 
 | ||
|  |             class_name = collector_func[0] + collector_func[1].name | ||
|  |             class_name_separated = collector_func[1].cpp_class() | ||
|  |             is_method = isinstance(extra, parser.Method) | ||
|  |             is_static_method = isinstance(extra, parser.StaticMethod) | ||
|  | 
 | ||
|  |             if collector_func[2] == 'collectorInsertAndMakeBase': | ||
|  |                 body += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                     mexAtExit(&_deleteAllObjects); | ||
|  |                     typedef boost::shared_ptr<{class_name_sep}> Shared;\n | ||
|  |                     Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0])); | ||
|  |                     collector_{class_name}.insert(self); | ||
|  |                 ''').format(class_name_sep=class_name_separated, class_name=class_name),
 | ||
|  |                                         prefix='  ') | ||
|  | 
 | ||
|  |                 if collector_func[1].parent_class: | ||
|  |                     body += textwrap.indent(textwrap.dedent('''
 | ||
|  |                         typedef boost::shared_ptr<{}> SharedBase; | ||
|  |                         out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); | ||
|  |                         *reinterpret_cast<SharedBase**>(mxGetData(out[0])) = new SharedBase(*self); | ||
|  |                     ''').format(collector_func[1].parent_class),
 | ||
|  |                                             prefix='  ') | ||
|  |             elif collector_func[2] == 'constructor': | ||
|  |                 base = '' | ||
|  |                 params, body_args = self._wrapper_unwrap_arguments(extra.args, constructor=True) | ||
|  | 
 | ||
|  |                 if collector_func[1].parent_class: | ||
|  |                     base += textwrap.indent(textwrap.dedent('''
 | ||
|  |                         typedef boost::shared_ptr<{}> SharedBase; | ||
|  |                         out[1] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); | ||
|  |                         *reinterpret_cast<SharedBase**>(mxGetData(out[1])) = new SharedBase(*self); | ||
|  |                     ''').format(collector_func[1].parent_class),
 | ||
|  |                                             prefix='  ') | ||
|  | 
 | ||
|  |                 body += textwrap.dedent('''\
 | ||
|  |                       mexAtExit(&_deleteAllObjects); | ||
|  |                       typedef boost::shared_ptr<{class_name_sep}> Shared;\n | ||
|  |                     {body_args}  Shared *self = new Shared(new {class_name_sep}({params})); | ||
|  |                       collector_{class_name}.insert(self); | ||
|  |                       out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); | ||
|  |                       *reinterpret_cast<Shared**> (mxGetData(out[0])) = self; | ||
|  |                     {base}''').format(class_name_sep=class_name_separated,
 | ||
|  |                                       body_args=body_args, | ||
|  |                                       params=params, | ||
|  |                                       class_name=class_name, | ||
|  |                                       base=base) | ||
|  |             elif collector_func[2] == 'deconstructor': | ||
|  |                 body += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                     typedef boost::shared_ptr<{class_name_sep}> Shared; | ||
|  |                     checkArguments("delete_{class_name}",nargout,nargin,1); | ||
|  |                     Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0])); | ||
|  |                     Collector_{class_name}::iterator item; | ||
|  |                     item = collector_{class_name}.find(self); | ||
|  |                     if(item != collector_{class_name}.end()) {{ | ||
|  |                       delete self; | ||
|  |                       collector_{class_name}.erase(item); | ||
|  |                     }} | ||
|  |                 ''').format(class_name_sep=class_name_separated, class_name=class_name),
 | ||
|  |                                         prefix='  ') | ||
|  |             elif extra == 'serialize': | ||
|  |                 body += self.wrap_collector_function_serialize(collector_func[1].name, | ||
|  |                                                                full_name=collector_func[1].cpp_class(), | ||
|  |                                                                namespace=collector_func[0]) | ||
|  |             elif extra == 'deserialize': | ||
|  |                 body += self.wrap_collector_function_deserialize(collector_func[1].name, | ||
|  |                                                                  full_name=collector_func[1].cpp_class(), | ||
|  |                                                                  namespace=collector_func[0]) | ||
|  |             elif is_method or is_static_method: | ||
|  |                 method_name = '' | ||
|  | 
 | ||
|  |                 if is_static_method: | ||
|  |                     method_name = self._format_static_method(extra) + '.' | ||
|  | 
 | ||
|  |                 method_name += extra.name | ||
|  | 
 | ||
|  |                 return_type = extra.return_type | ||
|  |                 return_count = self._return_count(return_type) | ||
|  | 
 | ||
|  |                 return_body = self.wrap_collector_function_return(extra) | ||
|  |                 params, body_args = self._wrapper_unwrap_arguments(extra.args, arg_id=1 if is_method else 0) | ||
|  | 
 | ||
|  |                 shared_obj = '' | ||
|  | 
 | ||
|  |                 if is_method: | ||
|  |                     shared_obj = \ | ||
|  |                         '  auto obj = unwrap_shared_ptr<{class_name_sep}>(in[0], "ptr_{class_name}");\n'.format( | ||
|  |                             class_name_sep=class_name_separated, | ||
|  |                             class_name=class_name) | ||
|  |                 """
 | ||
|  |                 body += textwrap.dedent('''\
 | ||
|  |                       typedef std::shared_ptr<{class_name_sep}> Shared; | ||
|  |                       checkArguments("{method_name}",nargout,nargin{min1},{num_args}); | ||
|  |                     {shared_obj} | ||
|  |                     {body_args} | ||
|  |                     {return_body} | ||
|  |                 ''')
 | ||
|  |                 """
 | ||
|  |                 body += '  checkArguments("{method_name}",nargout,nargin{min1},' \ | ||
|  |                         '{num_args});\n' \ | ||
|  |                         '{shared_obj}' \ | ||
|  |                         '{body_args}' \ | ||
|  |                         '{return_body}\n'.format( | ||
|  |                     class_name_sep=class_name_separated, | ||
|  |                     min1='-1' if is_method else '', | ||
|  |                     shared_obj=shared_obj, | ||
|  |                     method_name=method_name, | ||
|  |                     num_args=len(extra.args.args_list), | ||
|  |                     class_name=class_name, | ||
|  |                     body_args=body_args, | ||
|  |                     return_body=return_body) | ||
|  | 
 | ||
|  |             body += '}\n' | ||
|  | 
 | ||
|  |             if extra not in ['serialize', 'deserialize']: | ||
|  |                 body += '\n' | ||
|  | 
 | ||
|  |             collector_function += body | ||
|  |         else: | ||
|  |             # import sys | ||
|  |             # print("other func: {}".format(collector_func[1]), file=sys.stderr) | ||
|  |             body = textwrap.dedent('''\
 | ||
|  |                 {{ | ||
|  |                   checkArguments("{function_name}",nargout,nargin,{len}); | ||
|  |             ''').format(function_name=collector_func[1].name,
 | ||
|  |                         id=self.global_function_id, | ||
|  |                         len=len(collector_func[1].args.args_list)) | ||
|  | 
 | ||
|  |             body += self._wrapper_unwrap_arguments(collector_func[1].args)[1] | ||
|  |             body += self.wrap_collector_function_return(collector_func[1]) + '\n}\n' | ||
|  | 
 | ||
|  |             collector_function += body | ||
|  | 
 | ||
|  |             self.global_function_id += 1 | ||
|  | 
 | ||
|  |         return collector_function | ||
|  | 
 | ||
|  |     def mex_function(self): | ||
|  |         cases = '' | ||
|  |         next_case = None | ||
|  | 
 | ||
|  |         for wrapper_id in range(self.wrapper_id): | ||
|  |             id_val = self.wrapper_map.get(wrapper_id) | ||
|  |             set_next_case = False | ||
|  | 
 | ||
|  |             if id_val is None: | ||
|  |                 id_val = self.wrapper_map.get(wrapper_id + 1) | ||
|  | 
 | ||
|  |                 if id_val is None: | ||
|  |                     continue | ||
|  | 
 | ||
|  |                 set_next_case = True | ||
|  | 
 | ||
|  |             cases += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                 case {}: | ||
|  |                   {}(nargout, out, nargin-1, in+1); | ||
|  |                   break; | ||
|  |                 ''').format(wrapper_id, next_case if next_case else id_val[3]),
 | ||
|  |                                      prefix='    ') | ||
|  | 
 | ||
|  |             if set_next_case: | ||
|  |                 next_case = '{}_upcastFromVoid_{}'.format(id_val[1].name, wrapper_id + 1) | ||
|  |             else: | ||
|  |                 next_case = None | ||
|  | 
 | ||
|  |         mex_function = textwrap.dedent('''
 | ||
|  |             void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) | ||
|  |             {{ | ||
|  |               mstream mout; | ||
|  |               std::streambuf *outbuf = std::cout.rdbuf(&mout);\n | ||
|  |               _{module_name}_RTTIRegister();\n | ||
|  |               int id = unwrap<int>(in[0]);\n | ||
|  |               try {{ | ||
|  |                 switch(id) {{ | ||
|  |             {cases}    }} | ||
|  |               }} catch(const std::exception& e) {{ | ||
|  |                 mexErrMsgTxt(("Exception from gtsam:\\n" + std::string(e.what()) + "\\n").c_str()); | ||
|  |               }}\n | ||
|  |               std::cout.rdbuf(outbuf); | ||
|  |             }} | ||
|  |         ''').format(module_name=self.module_name, cases=cases)
 | ||
|  | 
 | ||
|  |         return mex_function | ||
|  | 
 | ||
|  |     def generate_wrapper(self, namespace): | ||
|  |         """Generate the c++ wrapper.""" | ||
|  |         # Includes | ||
|  |         wrapper_file = textwrap.dedent('''\
 | ||
|  |             #include <wrap/matlab.h> | ||
|  |             #include <map>\n | ||
|  |             #include <boost/archive/text_iarchive.hpp> | ||
|  |             #include <boost/archive/text_oarchive.hpp> | ||
|  |             #include <boost/serialization/export.hpp>\n | ||
|  |         ''')
 | ||
|  | 
 | ||
|  |         includes_list = sorted(list(self.includes.keys()), key=lambda include: include.header) | ||
|  | 
 | ||
|  |         wrapper_file += reduce(lambda x, y: str(x) + '\n' + str(y), includes_list) + '\n' | ||
|  | 
 | ||
|  |         typedef_instances = '\n' | ||
|  |         typedef_collectors = '' | ||
|  |         boost_class_export_guid = '' | ||
|  |         delete_objs = textwrap.dedent('''\
 | ||
|  |             void _deleteAllObjects() | ||
|  |             { | ||
|  |               mstream mout; | ||
|  |               std::streambuf *outbuf = std::cout.rdbuf(&mout);\n | ||
|  |               bool anyDeleted = false; | ||
|  |         ''')
 | ||
|  |         rtti_reg_start = textwrap.dedent('''\
 | ||
|  |             void _{module_name}_RTTIRegister() {{ | ||
|  |               const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_{module_name}_rttiRegistry_created"); | ||
|  |               if(!alreadyCreated) {{ | ||
|  |                 std::map<std::string, std::string> types; | ||
|  |         ''').format(module_name=self.module_name)
 | ||
|  |         rtti_reg_mid = '' | ||
|  |         rtti_reg_end = textwrap.indent(textwrap.dedent('''
 | ||
|  |                 mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); | ||
|  |                 if(!registry) | ||
|  |                   registry = mxCreateStructMatrix(1, 1, 0, NULL); | ||
|  |                 typedef std::pair<std::string, std::string> StringPair; | ||
|  |                 for(const StringPair& rtti_matlab: types) { | ||
|  |                   int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); | ||
|  |                   if(fieldId < 0) | ||
|  |                     mexErrMsgTxt("gtsam wrap:  Error indexing RTTI types, inheritance will not work correctly"); | ||
|  |                   mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); | ||
|  |                   mxSetFieldByNumber(registry, 0, fieldId, matlabName); | ||
|  |                 } | ||
|  |                 if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) | ||
|  |                   mexErrMsgTxt("gtsam wrap:  Error indexing RTTI types, inheritance will not work correctly"); | ||
|  |                 mxDestroyArray(registry); | ||
|  |         '''),
 | ||
|  |                                        prefix='    ') + '    \n' + textwrap.dedent('''\
 | ||
|  |                 mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); | ||
|  |                 if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) | ||
|  |                   mexErrMsgTxt("gtsam wrap:  Error indexing RTTI types, inheritance will not work correctly"); | ||
|  |                 mxDestroyArray(newAlreadyCreated); | ||
|  |               } | ||
|  |             } | ||
|  |         ''')
 | ||
|  |         ptr_ctor_frag = '' | ||
|  | 
 | ||
|  |         for cls in self.classes: | ||
|  |             if cls.cpp_class().strip() in self.ignore_classes: | ||
|  |                 continue | ||
|  | 
 | ||
|  |             def _has_serialization(cls): | ||
|  |                 for m in cls.methods: | ||
|  |                     if m.name in self.whitelist: | ||
|  |                         return True | ||
|  |                 return False | ||
|  | 
 | ||
|  |             if len(cls.instantiations): | ||
|  |                 cls_insts = '' | ||
|  | 
 | ||
|  |                 for i, inst in enumerate(cls.instantiations): | ||
|  |                     if i != 0: | ||
|  |                         cls_insts += ', ' | ||
|  | 
 | ||
|  |                     cls_insts += self._format_type_name(inst) | ||
|  | 
 | ||
|  |                 typedef_instances += 'typedef {original_class_name} {class_name_sep};\n'.format( | ||
|  |                     namespace_name=namespace.name, original_class_name=cls.cpp_class(), class_name_sep=cls.name) | ||
|  | 
 | ||
|  |                 class_name_sep = cls.name | ||
|  |                 class_name = self._format_class_name(cls) | ||
|  | 
 | ||
|  |                 if len(cls.original.namespaces()) > 1 and _has_serialization(cls): | ||
|  |                     boost_class_export_guid += 'BOOST_CLASS_EXPORT_GUID({}, "{}");\n'.format(class_name_sep, class_name) | ||
|  |             else: | ||
|  |                 class_name_sep = cls.cpp_class() | ||
|  |                 class_name = self._format_class_name(cls) | ||
|  | 
 | ||
|  |                 if len(cls.original.namespaces()) > 1 and _has_serialization(cls): | ||
|  |                     boost_class_export_guid += 'BOOST_CLASS_EXPORT_GUID({}, "{}");\n'.format(class_name_sep, class_name) | ||
|  | 
 | ||
|  |             typedef_collectors += textwrap.dedent('''\
 | ||
|  |                 typedef std::set<boost::shared_ptr<{class_name_sep}>*> Collector_{class_name}; | ||
|  |                 static Collector_{class_name} collector_{class_name}; | ||
|  |             ''').format(class_name_sep=class_name_sep, class_name=class_name)
 | ||
|  |             delete_objs += textwrap.indent(textwrap.dedent('''\
 | ||
|  |                 {{ for(Collector_{class_name}::iterator iter = collector_{class_name}.begin(); | ||
|  |                     iter != collector_{class_name}.end(); ) {{ | ||
|  |                   delete *iter; | ||
|  |                   collector_{class_name}.erase(iter++); | ||
|  |                   anyDeleted = true; | ||
|  |                 }} }} | ||
|  |             ''').format(class_name=class_name),
 | ||
|  |                                            prefix='  ') | ||
|  | 
 | ||
|  |             if cls.is_virtual: | ||
|  |                 rtti_reg_mid += '    types.insert(std::make_pair(typeid({}).name(), "{}"));\n'.format( | ||
|  |                     class_name_sep, class_name) | ||
|  | 
 | ||
|  |         set_next_case = False | ||
|  | 
 | ||
|  |         for id in range(self.wrapper_id): | ||
|  |             id_val = self.wrapper_map.get(id) | ||
|  |             queue_set_next_case = set_next_case | ||
|  | 
 | ||
|  |             set_next_case = False | ||
|  | 
 | ||
|  |             if id_val is None: | ||
|  |                 id_val = self.wrapper_map.get(id + 1) | ||
|  | 
 | ||
|  |                 if id_val is None: | ||
|  |                     continue | ||
|  | 
 | ||
|  |                 set_next_case = True | ||
|  | 
 | ||
|  |             ptr_ctor_frag += self.generate_collector_function(id) | ||
|  | 
 | ||
|  |             if queue_set_next_case: | ||
|  |                 ptr_ctor_frag += self.wrap_collector_function_upcast_from_void(id_val[1].name, id, | ||
|  |                                                                                id_val[1].cpp_class()) | ||
|  | 
 | ||
|  |         wrapper_file += textwrap.dedent('''\
 | ||
|  |             {typedef_instances} | ||
|  |             {boost_class_export_guid} | ||
|  |             {typedefs_collectors} | ||
|  |             {delete_objs}  if(anyDeleted) | ||
|  |                 cout << | ||
|  |                   "WARNING:  Wrap modules with variables in the workspace have been reloaded due to\\n" | ||
|  |                   "calling destructors, call \'clear all\' again if you plan to now recompile a wrap\\n" | ||
|  |                   "module, so that your recompiled module is used instead of the old one." << endl; | ||
|  |               std::cout.rdbuf(outbuf); | ||
|  |             }}\n | ||
|  |             {rtti_register} | ||
|  |             {pointer_constructor_fragment}{mex_function}''') \
 | ||
|  |             .format(typedef_instances=typedef_instances, | ||
|  |                     boost_class_export_guid=boost_class_export_guid, | ||
|  |                     typedefs_collectors=typedef_collectors, | ||
|  |                     delete_objs=delete_objs, | ||
|  |                     rtti_register=rtti_reg_start + rtti_reg_mid + rtti_reg_end, | ||
|  |                     pointer_constructor_fragment=ptr_ctor_frag, | ||
|  |                     mex_function=self.mex_function()) | ||
|  | 
 | ||
|  |         self.content.append((self._wrapper_name() + '.cpp', wrapper_file)) | ||
|  | 
 | ||
|  |     def wrap_class_serialize_method(self, namespace_name, inst_class): | ||
|  |         class_name = inst_class.name | ||
|  |         wrapper_id = self._update_wrapper_id((namespace_name, inst_class, 'string_serialize', 'serialize')) | ||
|  | 
 | ||
|  |         return textwrap.dedent('''\
 | ||
|  |             function varargout = string_serialize(this, varargin) | ||
|  |               % STRING_SERIALIZE usage: string_serialize() : returns string | ||
|  |               % Doxygen can be found at https://gtsam.org/doxygen/ | ||
|  |               if length(varargin) == 0 | ||
|  |                 varargout{{1}} = {wrapper}({wrapper_id}, this, varargin{{:}}); | ||
|  |               else | ||
|  |                 error('Arguments do not match any overload of function {class_name}.string_serialize'); | ||
|  |               end | ||
|  |             end\n | ||
|  |             function sobj = saveobj(obj) | ||
|  |               % SAVEOBJ Saves the object to a matlab-readable format | ||
|  |               sobj = obj.string_serialize(); | ||
|  |             end | ||
|  |         ''').format(wrapper=self._wrapper_name(), wrapper_id=wrapper_id, class_name=namespace_name + '.' + class_name)
 | ||
|  | 
 | ||
|  |     def wrap_collector_function_serialize(self, class_name, full_name='', namespace=''): | ||
|  |         return textwrap.indent(textwrap.dedent('''\
 | ||
|  |             typedef boost::shared_ptr<{full_name}> Shared; | ||
|  |             checkArguments("string_serialize",nargout,nargin-1,0); | ||
|  |             Shared obj = unwrap_shared_ptr<{full_name}>(in[0], "ptr_{namespace}{class_name}"); | ||
|  |             ostringstream out_archive_stream; | ||
|  |             boost::archive::text_oarchive out_archive(out_archive_stream); | ||
|  |             out_archive << *obj; | ||
|  |             out[0] = wrap< string >(out_archive_stream.str()); | ||
|  |         ''').format(class_name=class_name, full_name=full_name, namespace=namespace),
 | ||
|  |                                prefix='  ') | ||
|  | 
 | ||
|  |     def wrap_collector_function_deserialize(self, class_name, full_name='', namespace=''): | ||
|  |         return textwrap.indent(textwrap.dedent('''\
 | ||
|  |             typedef boost::shared_ptr<{full_name}> Shared; | ||
|  |             checkArguments("{namespace}{class_name}.string_deserialize",nargout,nargin,1); | ||
|  |             string serialized = unwrap< string >(in[0]); | ||
|  |             istringstream in_archive_stream(serialized); | ||
|  |             boost::archive::text_iarchive in_archive(in_archive_stream); | ||
|  |             Shared output(new {full_name}()); | ||
|  |             in_archive >> *output; | ||
|  |             out[0] = wrap_shared_ptr(output,"{namespace}.{class_name}", false); | ||
|  |         ''').format(class_name=class_name, full_name=full_name, namespace=namespace),
 | ||
|  |                                prefix='  ') | ||
|  | 
 | ||
|  |     def wrap(self): | ||
|  |         self.wrap_namespace(self.module) | ||
|  |         self.generate_wrapper(self.module) | ||
|  | 
 | ||
|  |         return self.content | ||
|  | 
 | ||
|  | 
 | ||
|  | def generate_content(cc_content, path, verbose=False): | ||
|  |     """Generate files and folders from matlab wrapper content.
 | ||
|  | 
 | ||
|  |     Keyword arguments: | ||
|  |     cc_content -- the content to generate formatted as | ||
|  |         (file_name, file_content) or | ||
|  |         (folder_name, [(file_name, file_content)]) | ||
|  |     path -- the path to the files parent folder within the main folder | ||
|  |     """
 | ||
|  | 
 | ||
|  |     def _debug(message): | ||
|  |         if not verbose: | ||
|  |             return | ||
|  |         import sys | ||
|  |         print(message, file=sys.stderr) | ||
|  | 
 | ||
|  |     for c in cc_content: | ||
|  |         if type(c) == list: | ||
|  |             if len(c) == 0: | ||
|  |                 continue | ||
|  |             _debug("c object: {}".format(c[0][0])) | ||
|  |             path_to_folder = path + '/' + c[0][0] | ||
|  | 
 | ||
|  |             if not os.path.isdir(path_to_folder): | ||
|  |                 try: | ||
|  |                     os.makedirs(path_to_folder, exist_ok=True) | ||
|  |                 except OSError: | ||
|  |                     pass | ||
|  | 
 | ||
|  |             for sub_content in c: | ||
|  |                 import sys | ||
|  |                 _debug("sub object: {}".format(sub_content[1][0][0])) | ||
|  |                 generate_content(sub_content[1], path_to_folder) | ||
|  |         elif type(c[1]) == list: | ||
|  |             path_to_folder = path + '/' + c[0] | ||
|  | 
 | ||
|  |             _debug("[generate_content_global]: {}".format(path_to_folder)) | ||
|  |             if not os.path.isdir(path_to_folder): | ||
|  |                 try: | ||
|  |                     os.makedirs(path_to_folder, exist_ok=True) | ||
|  |                 except OSError: | ||
|  |                     pass | ||
|  |             for sub_content in c[1]: | ||
|  |                 import sys | ||
|  |                 path_to_file = path_to_folder + '/' + sub_content[0] | ||
|  |                 _debug("[generate_global_method]: {}".format(path_to_file)) | ||
|  |                 with open(path_to_file, 'w') as f: | ||
|  |                     f.write(sub_content[1]) | ||
|  |         else: | ||
|  |             path_to_file = path + '/' + c[0] | ||
|  | 
 | ||
|  |             _debug("[generate_content]: {}".format(path_to_file)) | ||
|  |             if not os.path.isdir(path_to_file): | ||
|  |                 try: | ||
|  |                     os.mkdir(path) | ||
|  |                 except OSError: | ||
|  |                     pass | ||
|  | 
 | ||
|  |             with open(path_to_file, 'w') as f: | ||
|  |                 f.write(c[1]) |