| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | #!/usr/bin/env python3 | 
					
						
							|  |  |  | """
 | 
					
						
							|  |  |  | GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, | 
					
						
							|  |  |  | Atlanta, Georgia 30332-0415 | 
					
						
							|  |  |  | All Rights Reserved | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | See LICENSE for the license information | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Code generator for wrapping a C++ module with Pybind11 | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  | Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | """
 | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  | # pylint: disable=too-many-arguments, too-many-instance-attributes, no-self-use, no-else-return, too-many-arguments, unused-format-string-argument, line-too-long | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | import re | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-05 02:11:36 +08:00
										 |  |  | import gtwrap.interface_parser as parser | 
					
						
							|  |  |  | import gtwrap.template_instantiator as instantiator | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  | class PybindWrapper: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Class to generate binding code for Pybind11 specifically. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  module, | 
					
						
							|  |  |  |                  module_name, | 
					
						
							|  |  |  |                  top_module_namespaces='', | 
					
						
							|  |  |  |                  use_boost=False, | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |                  ignore_classes=(), | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                  module_template=""): | 
					
						
							|  |  |  |         self.module = module | 
					
						
							|  |  |  |         self.module_name = module_name | 
					
						
							|  |  |  |         self.top_module_namespaces = top_module_namespaces | 
					
						
							|  |  |  |         self.use_boost = use_boost | 
					
						
							|  |  |  |         self.ignore_classes = ignore_classes | 
					
						
							|  |  |  |         self._serializing_classes = list() | 
					
						
							|  |  |  |         self.module_template = module_template | 
					
						
							|  |  |  |         self.python_keywords = ['print', 'lambda'] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |         # amount of indentation to add before each function/method declaration. | 
					
						
							|  |  |  |         self.method_indent = '\n' + (' ' * 8) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |     def _py_args_names(self, args_list): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Set the argument names in Pybind11 format.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         names = args_list.args_names() | 
					
						
							|  |  |  |         if names: | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |             py_args = [] | 
					
						
							|  |  |  |             for arg in args_list.args_list: | 
					
						
							|  |  |  |                 if arg.default and isinstance(arg.default, str): | 
					
						
							|  |  |  |                     arg.default = "\"{arg.default}\"".format(arg=arg) | 
					
						
							|  |  |  |                 argument = 'py::arg("{name}"){default}'.format( | 
					
						
							|  |  |  |                     name=arg.name, | 
					
						
							|  |  |  |                     default=' = {0}'.format(arg.default) if arg.default else '') | 
					
						
							|  |  |  |                 py_args.append(argument) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |             return ", " + ", ".join(py_args) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return '' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _method_args_signature_with_names(self, args_list): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Define the method signature types with the argument names.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         cpp_types = args_list.to_cpp(self.use_boost) | 
					
						
							|  |  |  |         names = args_list.args_names() | 
					
						
							|  |  |  |         types_names = ["{} {}".format(ctype, name) for ctype, name in zip(cpp_types, names)] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |         return ', '.join(types_names) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def wrap_ctors(self, my_class): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap the constructors.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         res = "" | 
					
						
							|  |  |  |         for ctor in my_class.ctors: | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |             res += (self.method_indent + '.def(py::init<{args_cpp_types}>()' | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                     '{py_args_names})'.format( | 
					
						
							|  |  |  |                         args_cpp_types=", ".join(ctor.args.to_cpp(self.use_boost)), | 
					
						
							|  |  |  |                         py_args_names=self._py_args_names(ctor.args), | 
					
						
							|  |  |  |                     )) | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): | 
					
						
							|  |  |  |         py_method = method.name + method_suffix | 
					
						
							|  |  |  |         cpp_method = method.to_cpp() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if cpp_method in ["serialize", "serializable"]: | 
					
						
							|  |  |  |             if not cpp_class in self._serializing_classes: | 
					
						
							|  |  |  |                 self._serializing_classes.append(cpp_class) | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |             serialize_method = self.method_indent + \ | 
					
						
							|  |  |  |                 ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".format(class_inst=cpp_class + '*') | 
					
						
							|  |  |  |             deserialize_method = self.method_indent + \ | 
					
						
							|  |  |  |                      ".def(\"deserialize\", []({class_inst} self, string serialized){{ gtsam::deserialize(serialized, *self); }}, py::arg(\"serialized\"))" \ | 
					
						
							|  |  |  |                        .format(class_inst=cpp_class + '*') | 
					
						
							|  |  |  |             return serialize_method + deserialize_method | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-10 22:18:36 +08:00
										 |  |  |         if cpp_method == "pickle": | 
					
						
							|  |  |  |             if not cpp_class in self._serializing_classes: | 
					
						
							|  |  |  |                 raise ValueError("Cannot pickle a class which is not serializable") | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |             pickle_method = self.method_indent + \ | 
					
						
							|  |  |  |                 ".def(py::pickle({indent}    [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent}    [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast<std::string>(), obj); return obj; }}))" | 
					
						
							|  |  |  |             return pickle_method.format(cpp_class=cpp_class, indent=self.method_indent) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         is_method = isinstance(method, instantiator.InstantiatedMethod) | 
					
						
							|  |  |  |         is_static = isinstance(method, parser.StaticMethod) | 
					
						
							|  |  |  |         return_void = method.return_type.is_void() | 
					
						
							|  |  |  |         args_names = method.args.args_names() | 
					
						
							|  |  |  |         py_args_names = self._py_args_names(method.args) | 
					
						
							|  |  |  |         args_signature_with_names = self._method_args_signature_with_names(method.args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         caller = cpp_class + "::" if not is_method else "self->" | 
					
						
							|  |  |  |         function_call = ('{opt_return} {caller}{function_name}' | 
					
						
							|  |  |  |                          '({args_names});'.format( | 
					
						
							|  |  |  |                              opt_return='return' if not return_void else '', | 
					
						
							|  |  |  |                              caller=caller, | 
					
						
							|  |  |  |                              function_name=cpp_method, | 
					
						
							|  |  |  |                              args_names=', '.join(args_names), | 
					
						
							|  |  |  |                          )) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ret = ('{prefix}.{cdef}("{py_method}",' | 
					
						
							|  |  |  |                '[]({opt_self}{opt_comma}{args_signature_with_names}){{' | 
					
						
							|  |  |  |                '{function_call}' | 
					
						
							|  |  |  |                '}}' | 
					
						
							|  |  |  |                '{py_args_names}){suffix}'.format( | 
					
						
							|  |  |  |                    prefix=prefix, | 
					
						
							|  |  |  |                    cdef="def_static" if is_static else "def", | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |                    py_method=py_method if not py_method in self.python_keywords | 
					
						
							|  |  |  |                    else py_method + "_", | 
					
						
							|  |  |  |                    opt_self="{cpp_class}* self".format( | 
					
						
							|  |  |  |                        cpp_class=cpp_class) if is_method else "", | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |                    opt_comma=', ' if is_method and args_names else '', | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                    args_signature_with_names=args_signature_with_names, | 
					
						
							|  |  |  |                    function_call=function_call, | 
					
						
							|  |  |  |                    py_args_names=py_args_names, | 
					
						
							|  |  |  |                    suffix=suffix, | 
					
						
							|  |  |  |                )) | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |         # Create __repr__ override | 
					
						
							|  |  |  |         # We allow all arguments to .print() and let the compiler handle type mismatches. | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         if method.name == 'print': | 
					
						
							| 
									
										
										
										
											2021-04-13 07:08:29 +08:00
										 |  |  |             # Redirect stdout - see pybind docs for why this is a good idea: | 
					
						
							|  |  |  |             # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |             ret = ret.replace('self->print', 'py::scoped_ostream_redirect output; self->print') | 
					
						
							| 
									
										
										
										
											2021-04-13 07:08:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |             # Make __repr__() call print() internally | 
					
						
							|  |  |  |             ret += '''{prefix}.def("__repr__",
 | 
					
						
							|  |  |  |                     [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                         gtsam::RedirectCout redirect; | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |                         self.{method_name}({method_args}); | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                         return redirect.str(); | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |                     }}{py_args_names}){suffix}'''.format(
 | 
					
						
							|  |  |  |                 prefix=prefix, | 
					
						
							|  |  |  |                 cpp_class=cpp_class, | 
					
						
							|  |  |  |                 opt_comma=', ' if args_names else '', | 
					
						
							|  |  |  |                 args_signature_with_names=args_signature_with_names, | 
					
						
							|  |  |  |                 method_name=method.name, | 
					
						
							|  |  |  |                 method_args=", ".join(args_names) if args_names else '', | 
					
						
							|  |  |  |                 py_args_names=py_args_names, | 
					
						
							|  |  |  |                 suffix=suffix) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         return ret | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''): | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Wrap all the methods in the `cpp_class`. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         This function is also used to wrap global functions. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         res = "" | 
					
						
							|  |  |  |         for method in methods: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # To avoid type confusion for insert, currently unused | 
					
						
							|  |  |  |             if method.name == 'insert' and cpp_class == 'gtsam::Values': | 
					
						
							|  |  |  |                 name_list = method.args.args_names() | 
					
						
							|  |  |  |                 type_list = method.args.to_cpp(self.use_boost) | 
					
						
							|  |  |  |                 if type_list[0].strip() == 'size_t':  # inserting non-wrapped value types | 
					
						
							|  |  |  |                     method_suffix = '_' + name_list[1].strip() | 
					
						
							|  |  |  |                     res += self._wrap_method(method=method, | 
					
						
							|  |  |  |                                              cpp_class=cpp_class, | 
					
						
							|  |  |  |                                              prefix=prefix, | 
					
						
							|  |  |  |                                              suffix=suffix, | 
					
						
							|  |  |  |                                              method_suffix=method_suffix) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             res += self._wrap_method( | 
					
						
							|  |  |  |                 method=method, | 
					
						
							|  |  |  |                 cpp_class=cpp_class, | 
					
						
							|  |  |  |                 prefix=prefix, | 
					
						
							|  |  |  |                 suffix=suffix, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-20 04:09:06 +08:00
										 |  |  |     def wrap_variable(self, module, module_var, variable, prefix='\n' + ' ' * 8): | 
					
						
							|  |  |  |         """Wrap a variable that's not part of a class (i.e. global)
 | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         return '{prefix}{module_var}.attr("{variable_name}") = {module}{variable_name};'.format( | 
					
						
							|  |  |  |             prefix=prefix, | 
					
						
							|  |  |  |             module=module, | 
					
						
							|  |  |  |             module_var=module_var, | 
					
						
							|  |  |  |             variable_name=variable.name | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |     def wrap_properties(self, properties, cpp_class, prefix='\n' + ' ' * 8): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap all the properties in the `cpp_class`.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         res = "" | 
					
						
							|  |  |  |         for prop in properties: | 
					
						
							|  |  |  |             res += ('{prefix}.def_{property}("{property_name}", ' | 
					
						
							|  |  |  |                     '&{cpp_class}::{property_name})'.format( | 
					
						
							|  |  |  |                         prefix=prefix, | 
					
						
							|  |  |  |                         property="readonly" if prop.ctype.is_const else "readwrite", | 
					
						
							|  |  |  |                         cpp_class=cpp_class, | 
					
						
							|  |  |  |                         property_name=prop.name, | 
					
						
							|  |  |  |                     )) | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-02 11:20:12 +08:00
										 |  |  |     def wrap_operators(self, operators, cpp_class, prefix='\n' + ' ' * 8): | 
					
						
							|  |  |  |         """Wrap all the overloaded operators in the `cpp_class`.""" | 
					
						
							|  |  |  |         res = "" | 
					
						
							|  |  |  |         template = "{prefix}.def({{0}})".format(prefix=prefix) | 
					
						
							|  |  |  |         for op in operators: | 
					
						
							|  |  |  |             if op.operator == "[]":  # __getitem__ | 
					
						
							|  |  |  |                 res += "{prefix}.def(\"__getitem__\", &{cpp_class}::operator[])".format( | 
					
						
							|  |  |  |                     prefix=prefix, cpp_class=cpp_class) | 
					
						
							|  |  |  |             elif op.operator == "()":  # __call__ | 
					
						
							|  |  |  |                 res += "{prefix}.def(\"__call__\", &{cpp_class}::operator())".format( | 
					
						
							|  |  |  |                     prefix=prefix, cpp_class=cpp_class) | 
					
						
							|  |  |  |             elif op.is_unary: | 
					
						
							|  |  |  |                 res += template.format("{0}py::self".format(op.operator)) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 res += template.format("py::self {0} py::self".format( | 
					
						
							|  |  |  |                     op.operator)) | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |     def wrap_instantiated_class(self, instantiated_class): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap the class.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         module_var = self._gen_module_var(instantiated_class.namespaces()) | 
					
						
							|  |  |  |         cpp_class = instantiated_class.cpp_class() | 
					
						
							|  |  |  |         if cpp_class in self.ignore_classes: | 
					
						
							|  |  |  |             return "" | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         return ( | 
					
						
							|  |  |  |             '\n    py::class_<{cpp_class}, {class_parent}' | 
					
						
							|  |  |  |             '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' | 
					
						
							|  |  |  |             '{wrapped_ctors}' | 
					
						
							|  |  |  |             '{wrapped_methods}' | 
					
						
							|  |  |  |             '{wrapped_static_methods}' | 
					
						
							| 
									
										
										
										
											2021-04-02 11:20:12 +08:00
										 |  |  |             '{wrapped_properties}' | 
					
						
							|  |  |  |             '{wrapped_operators};\n'.format( | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |                 shared_ptr_type=('boost' if self.use_boost else 'std'), | 
					
						
							|  |  |  |                 cpp_class=cpp_class, | 
					
						
							|  |  |  |                 class_name=instantiated_class.name, | 
					
						
							| 
									
										
										
										
											2021-04-02 11:20:12 +08:00
										 |  |  |                 class_parent="{instantiated_class.parent_class}, ".format( | 
					
						
							|  |  |  |                     instantiated_class=instantiated_class) | 
					
						
							|  |  |  |                 if instantiated_class.parent_class else '', | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |                 module_var=module_var, | 
					
						
							|  |  |  |                 wrapped_ctors=self.wrap_ctors(instantiated_class), | 
					
						
							|  |  |  |                 wrapped_methods=self.wrap_methods(instantiated_class.methods, | 
					
						
							|  |  |  |                                                   cpp_class), | 
					
						
							|  |  |  |                 wrapped_static_methods=self.wrap_methods( | 
					
						
							|  |  |  |                     instantiated_class.static_methods, cpp_class), | 
					
						
							|  |  |  |                 wrapped_properties=self.wrap_properties( | 
					
						
							|  |  |  |                     instantiated_class.properties, cpp_class), | 
					
						
							| 
									
										
										
										
											2021-04-02 11:20:12 +08:00
										 |  |  |                 wrapped_operators=self.wrap_operators( | 
					
						
							|  |  |  |                     instantiated_class.operators, cpp_class))) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def wrap_stl_class(self, stl_class): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap STL containers.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         module_var = self._gen_module_var(stl_class.namespaces()) | 
					
						
							|  |  |  |         cpp_class = stl_class.cpp_class() | 
					
						
							|  |  |  |         if cpp_class in self.ignore_classes: | 
					
						
							|  |  |  |             return "" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         return ( | 
					
						
							|  |  |  |             '\n    py::class_<{cpp_class}, {class_parent}' | 
					
						
							|  |  |  |             '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' | 
					
						
							|  |  |  |             '{wrapped_ctors}' | 
					
						
							|  |  |  |             '{wrapped_methods}' | 
					
						
							|  |  |  |             '{wrapped_static_methods}' | 
					
						
							|  |  |  |             '{wrapped_properties};\n'.format( | 
					
						
							|  |  |  |                 shared_ptr_type=('boost' if self.use_boost else 'std'), | 
					
						
							|  |  |  |                 cpp_class=cpp_class, | 
					
						
							|  |  |  |                 class_name=stl_class.name, | 
					
						
							|  |  |  |                 class_parent=str(stl_class.parent_class) + | 
					
						
							|  |  |  |                 (', ' if stl_class.parent_class else ''), | 
					
						
							|  |  |  |                 module_var=module_var, | 
					
						
							|  |  |  |                 wrapped_ctors=self.wrap_ctors(stl_class), | 
					
						
							|  |  |  |                 wrapped_methods=self.wrap_methods(stl_class.methods, | 
					
						
							|  |  |  |                                                   cpp_class), | 
					
						
							|  |  |  |                 wrapped_static_methods=self.wrap_methods( | 
					
						
							|  |  |  |                     stl_class.static_methods, cpp_class), | 
					
						
							|  |  |  |                 wrapped_properties=self.wrap_properties( | 
					
						
							|  |  |  |                     stl_class.properties, cpp_class), | 
					
						
							|  |  |  |             )) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def _partial_match(self, namespaces1, namespaces2): | 
					
						
							|  |  |  |         for i in range(min(len(namespaces1), len(namespaces2))): | 
					
						
							|  |  |  |             if namespaces1[i] != namespaces2[i]: | 
					
						
							|  |  |  |                 return False | 
					
						
							|  |  |  |         return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _gen_module_var(self, namespaces): | 
					
						
							|  |  |  |         sub_module_namespaces = namespaces[len(self.top_module_namespaces):] | 
					
						
							|  |  |  |         return "m_{}".format('_'.join(sub_module_namespaces)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _add_namespaces(self, name, namespaces): | 
					
						
							|  |  |  |         if namespaces: | 
					
						
							|  |  |  |             # Ignore the first empty global namespace. | 
					
						
							|  |  |  |             idx = 1 if not namespaces[0] else 0 | 
					
						
							|  |  |  |             return '::'.join(namespaces[idx:] + [name]) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def wrap_namespace(self, namespace): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap the complete `namespace`.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         wrapped = "" | 
					
						
							|  |  |  |         includes = "" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         namespaces = namespace.full_namespaces() | 
					
						
							|  |  |  |         if not self._partial_match(namespaces, self.top_module_namespaces): | 
					
						
							|  |  |  |             return "", "" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if len(namespaces) < len(self.top_module_namespaces): | 
					
						
							|  |  |  |             for element in namespace.content: | 
					
						
							|  |  |  |                 if isinstance(element, parser.Include): | 
					
						
							|  |  |  |                     includes += ("{}\n".format(element).replace('<', '"').replace('>', '"')) | 
					
						
							|  |  |  |                 if isinstance(element, parser.Namespace): | 
					
						
							|  |  |  |                     ( | 
					
						
							|  |  |  |                         wrapped_namespace, | 
					
						
							|  |  |  |                         includes_namespace, | 
					
						
							|  |  |  |                     ) = self.wrap_namespace(  # noqa | 
					
						
							|  |  |  |                         element) | 
					
						
							|  |  |  |                     wrapped += wrapped_namespace | 
					
						
							|  |  |  |                     includes += includes_namespace | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             module_var = self._gen_module_var(namespaces) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if len(namespaces) > len(self.top_module_namespaces): | 
					
						
							|  |  |  |                 wrapped += (' ' * 4 + 'pybind11::module {module_var} = ' | 
					
						
							|  |  |  |                             '{parent_module_var}.def_submodule("{namespace}", "' | 
					
						
							|  |  |  |                             '{namespace} submodule");\n'.format( | 
					
						
							|  |  |  |                                 module_var=module_var, | 
					
						
							|  |  |  |                                 namespace=namespace.name, | 
					
						
							|  |  |  |                                 parent_module_var=self._gen_module_var(namespaces[:-1]), | 
					
						
							|  |  |  |                             )) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for element in namespace.content: | 
					
						
							|  |  |  |                 if isinstance(element, parser.Include): | 
					
						
							|  |  |  |                     includes += ("{}\n".format(element).replace('<', '"').replace('>', '"')) | 
					
						
							|  |  |  |                 elif isinstance(element, parser.Namespace): | 
					
						
							|  |  |  |                     ( | 
					
						
							|  |  |  |                         wrapped_namespace, | 
					
						
							|  |  |  |                         includes_namespace, | 
					
						
							|  |  |  |                     ) = self.wrap_namespace(  # noqa | 
					
						
							|  |  |  |                         element) | 
					
						
							|  |  |  |                     wrapped += wrapped_namespace | 
					
						
							|  |  |  |                     includes += includes_namespace | 
					
						
							|  |  |  |                 elif isinstance(element, instantiator.InstantiatedClass): | 
					
						
							|  |  |  |                     wrapped += self.wrap_instantiated_class(element) | 
					
						
							| 
									
										
										
										
											2021-04-20 04:09:06 +08:00
										 |  |  |                 elif isinstance(element, parser.Variable): | 
					
						
							|  |  |  |                     wrapped += self.wrap_variable( | 
					
						
							|  |  |  |                         module=self._add_namespaces('', namespaces), | 
					
						
							|  |  |  |                         module_var=module_var, | 
					
						
							|  |  |  |                         variable=element, | 
					
						
							|  |  |  |                         prefix='\n' + ' ' * 4 | 
					
						
							|  |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Global functions. | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |             all_funcs = [ | 
					
						
							|  |  |  |                 func for func in namespace.content | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |                 if isinstance(func, (parser.GlobalFunction, | 
					
						
							|  |  |  |                                      instantiator.InstantiatedGlobalFunction)) | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |             ] | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |             wrapped += self.wrap_methods( | 
					
						
							|  |  |  |                 all_funcs, | 
					
						
							|  |  |  |                 self._add_namespaces('', namespaces)[:-2], | 
					
						
							|  |  |  |                 prefix='\n' + ' ' * 4 + module_var, | 
					
						
							|  |  |  |                 suffix=';', | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |         return wrapped, includes | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def wrap(self): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap the code in the interface file.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         wrapped_namespace, includes = self.wrap_namespace(self.module) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Export classes for serialization. | 
					
						
							|  |  |  |         boost_class_export = "" | 
					
						
							|  |  |  |         for cpp_class in self._serializing_classes: | 
					
						
							|  |  |  |             new_name = cpp_class | 
					
						
							|  |  |  |             # The boost's macro doesn't like commas, so we have to typedef. | 
					
						
							|  |  |  |             if ',' in cpp_class: | 
					
						
							|  |  |  |                 new_name = re.sub("[,:<> ]", "", cpp_class) | 
					
						
							|  |  |  |                 boost_class_export += "typedef {cpp_class} {new_name};\n".format(  # noqa | 
					
						
							|  |  |  |                     cpp_class=cpp_class, | 
					
						
							|  |  |  |                     new_name=new_name, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             boost_class_export += "BOOST_CLASS_EXPORT({new_name})\n".format(new_name=new_name, ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         holder_type = "PYBIND11_DECLARE_HOLDER_TYPE(TYPE_PLACEHOLDER_DONOTUSE, " \ | 
					
						
							|  |  |  |                       "{shared_ptr_type}::shared_ptr<TYPE_PLACEHOLDER_DONOTUSE>);" | 
					
						
							|  |  |  |         include_boost = "#include <boost/shared_ptr.hpp>" if self.use_boost else "" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         return self.module_template.format( | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |             include_boost=include_boost, | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |             module_name=self.module_name, | 
					
						
							|  |  |  |             includes=includes, | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |             holder_type=holder_type.format(shared_ptr_type=('boost' if self.use_boost else 'std')) | 
					
						
							|  |  |  |             if self.use_boost else "", | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |             wrapped_namespace=wrapped_namespace, | 
					
						
							|  |  |  |             boost_class_export=boost_class_export, | 
					
						
							|  |  |  |         ) |