| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | #!/usr/bin/env python3 | 
					
						
							|  |  |  | """
 | 
					
						
							|  |  |  | GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, | 
					
						
							|  |  |  | Atlanta, Georgia 30332-0415 | 
					
						
							|  |  |  | All Rights Reserved | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | See LICENSE for the license information | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Code generator for wrapping a C++ module with Pybind11 | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  | Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | """
 | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  | # pylint: disable=too-many-arguments, too-many-instance-attributes, no-self-use, no-else-return, too-many-arguments, unused-format-string-argument, line-too-long, consider-using-f-string | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | import re | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2021-12-25 22:12:24 +08:00
										 |  |  | from typing import List | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-01-05 02:11:36 +08:00
										 |  |  | import gtwrap.interface_parser as parser | 
					
						
							|  |  |  | import gtwrap.template_instantiator as instantiator | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  | class PybindWrapper: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Class to generate binding code for Pybind11 specifically. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  module_name, | 
					
						
							|  |  |  |                  top_module_namespaces='', | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |                  use_boost_serialization=False, | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |                  ignore_classes=(), | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                  module_template=""): | 
					
						
							|  |  |  |         self.module_name = module_name | 
					
						
							|  |  |  |         self.top_module_namespaces = top_module_namespaces | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |         self.use_boost_serialization = use_boost_serialization | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         self.ignore_classes = ignore_classes | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         self._serializing_classes = [] | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         self.module_template = module_template | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         self.python_keywords = [ | 
					
						
							|  |  |  |             'lambda', 'False', 'def', 'if', 'raise', 'None', 'del', 'import', | 
					
						
							|  |  |  |             'return', 'True', 'elif', 'in', 'try', 'and', 'else', 'is', | 
					
						
							|  |  |  |             'while', 'as', 'except', 'lambda', 'with', 'assert', 'finally', | 
					
						
							|  |  |  |             'nonlocal', 'yield', 'break', 'for', 'not', 'class', 'from', 'or', | 
					
						
							|  |  |  |             'continue', 'global', 'pass' | 
					
						
							|  |  |  |         ] | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 04:04:09 +08:00
										 |  |  |         self.dunder_methods = ('len', 'contains', 'iter') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |         # amount of indentation to add before each function/method declaration. | 
					
						
							|  |  |  |         self.method_indent = '\n' + (' ' * 8) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-25 22:12:24 +08:00
										 |  |  |         # Special methods which are leveraged by ipython/jupyter notebooks | 
					
						
							|  |  |  |         self._ipython_special_methods = [ | 
					
						
							|  |  |  |             "svg", "png", "jpeg", "html", "javascript", "markdown", "latex" | 
					
						
							|  |  |  |         ] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |     def _py_args_names(self, args): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Set the argument names in Pybind11 format.""" | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         names = args.names() | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         if names: | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |             py_args = [] | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |             for arg in args.list(): | 
					
						
							|  |  |  |                 if arg.default is not None: | 
					
						
							|  |  |  |                     default = ' = {arg.default}'.format(arg=arg) | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |                 else: | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |                     default = '' | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |                 argument = 'py::arg("{name}"){default}'.format( | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |                     name=arg.name, default='{0}'.format(default)) | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |                 py_args.append(argument) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |             return ", " + ", ".join(py_args) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return '' | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |     def _method_args_signature(self, args): | 
					
						
							|  |  |  |         """Generate the argument types and names as per the method signature.""" | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |         cpp_types = args.to_cpp() | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         names = args.names() | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |         types_names = [ | 
					
						
							|  |  |  |             "{} {}".format(ctype, name) | 
					
						
							|  |  |  |             for ctype, name in zip(cpp_types, names) | 
					
						
							|  |  |  |         ] | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |         return ', '.join(types_names) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def wrap_ctors(self, my_class): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap the constructors.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         res = "" | 
					
						
							|  |  |  |         for ctor in my_class.ctors: | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |             res += (self.method_indent + '.def(py::init<{args_cpp_types}>()' | 
					
						
							|  |  |  |                     '{py_args_names})'.format( | 
					
						
							|  |  |  |                         args_cpp_types=", ".join(ctor.args.to_cpp()), | 
					
						
							|  |  |  |                         py_args_names=self._py_args_names(ctor.args), | 
					
						
							|  |  |  |                     )) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-25 22:12:24 +08:00
										 |  |  |     def _wrap_serialization(self, cpp_class): | 
					
						
							|  |  |  |         """Helper method to add serialize, deserialize and pickle methods to the wrapped class.""" | 
					
						
							|  |  |  |         if not cpp_class in self._serializing_classes: | 
					
						
							|  |  |  |             self._serializing_classes.append(cpp_class) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         serialize_method = self.method_indent + \ | 
					
						
							|  |  |  |             ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".format(class_inst=cpp_class + '*') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         deserialize_method = self.method_indent + \ | 
					
						
							|  |  |  |                     '.def("deserialize", []({class_inst} self, string serialized)' \ | 
					
						
							|  |  |  |                     '{{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized"))' \ | 
					
						
							|  |  |  |                     .format(class_inst=cpp_class + '*') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Since this class supports serialization, we also add the pickle method. | 
					
						
							|  |  |  |         pickle_method = self.method_indent + \ | 
					
						
							|  |  |  |             ".def(py::pickle({indent}    [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent}    [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast<std::string>(), obj); return obj; }}))" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return serialize_method + deserialize_method + \ | 
					
						
							|  |  |  |             pickle_method.format(cpp_class=cpp_class, indent=self.method_indent) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _wrap_print(self, ret: str, method: parser.Method, cpp_class: str, | 
					
						
							|  |  |  |                     args_names: List[str], args_signature_with_names: str, | 
					
						
							|  |  |  |                     py_args_names: str, prefix: str, suffix: str): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Update the print method to print to the output stream and append a __repr__ method. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             ret (str): The result of the parser. | 
					
						
							|  |  |  |             method (parser.Method): The method to be wrapped. | 
					
						
							|  |  |  |             cpp_class (str): The C++ name of the class to which the method belongs. | 
					
						
							|  |  |  |             args_names (List[str]): List of argument variable names passed to the method. | 
					
						
							|  |  |  |             args_signature_with_names (str): C++ arguments containing their names and type signatures. | 
					
						
							|  |  |  |             py_args_names (str): The pybind11 formatted version of the argument list. | 
					
						
							|  |  |  |             prefix (str): Prefix to add to the wrapped method when writing to the cpp file. | 
					
						
							|  |  |  |             suffix (str): Suffix to add to the wrapped method when writing to the cpp file. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             str: The wrapped print method. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # Redirect stdout - see pybind docs for why this is a good idea: | 
					
						
							|  |  |  |         # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream | 
					
						
							|  |  |  |         ret = ret.replace('self->print', | 
					
						
							|  |  |  |                           'py::scoped_ostream_redirect output; self->print') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Make __repr__() call .print() internally | 
					
						
							|  |  |  |         ret += '''{prefix}.def("__repr__",
 | 
					
						
							|  |  |  |                     [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ | 
					
						
							|  |  |  |                         gtsam::RedirectCout redirect; | 
					
						
							|  |  |  |                         self.{method_name}({method_args}); | 
					
						
							|  |  |  |                         return redirect.str(); | 
					
						
							|  |  |  |                     }}{py_args_names}){suffix}'''.format(
 | 
					
						
							|  |  |  |             prefix=prefix, | 
					
						
							|  |  |  |             cpp_class=cpp_class, | 
					
						
							|  |  |  |             opt_comma=', ' if args_names else '', | 
					
						
							|  |  |  |             args_signature_with_names=args_signature_with_names, | 
					
						
							|  |  |  |             method_name=method.name, | 
					
						
							|  |  |  |             method_args=", ".join(args_names) if args_names else '', | 
					
						
							|  |  |  |             py_args_names=py_args_names, | 
					
						
							|  |  |  |             suffix=suffix) | 
					
						
							|  |  |  |         return ret | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 04:04:09 +08:00
										 |  |  |     def _wrap_dunder(self, | 
					
						
							|  |  |  |                      method, | 
					
						
							|  |  |  |                      cpp_class, | 
					
						
							|  |  |  |                      prefix, | 
					
						
							|  |  |  |                      suffix, | 
					
						
							|  |  |  |                      method_suffix=""): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Wrap a Python double-underscore (dunder) method. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         E.g. __len__() gets wrapped as `.def("__len__", [](gtsam::KeySet* self) {return self->size();})` | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Supported methods are: | 
					
						
							|  |  |  |         - __contains__(T x) | 
					
						
							|  |  |  |         - __len__() | 
					
						
							|  |  |  |         - __iter__() | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         py_method = method.name + method_suffix | 
					
						
							|  |  |  |         args_names = method.args.names() | 
					
						
							|  |  |  |         py_args_names = self._py_args_names(method.args) | 
					
						
							|  |  |  |         args_signature_with_names = self._method_args_signature(method.args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if method.name == 'len': | 
					
						
							|  |  |  |             function_call = "return std::distance(self->begin(), self->end());" | 
					
						
							|  |  |  |         elif method.name == 'contains': | 
					
						
							|  |  |  |             function_call = f"return std::find(self->begin(), self->end(), {method.args.args_list[0].name}) != self->end();" | 
					
						
							|  |  |  |         elif method.name == 'iter': | 
					
						
							|  |  |  |             function_call = "return py::make_iterator(self->begin(), self->end());" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ret = ('{prefix}.def("__{py_method}__",' | 
					
						
							|  |  |  |                '[]({self}{opt_comma}{args_signature_with_names}){{' | 
					
						
							|  |  |  |                '{function_call}' | 
					
						
							|  |  |  |                '}}' | 
					
						
							|  |  |  |                '{py_args_names}){suffix}'.format( | 
					
						
							|  |  |  |                    prefix=prefix, | 
					
						
							|  |  |  |                    py_method=py_method, | 
					
						
							|  |  |  |                    self=f"{cpp_class}* self", | 
					
						
							|  |  |  |                    opt_comma=', ' if args_names else '', | 
					
						
							|  |  |  |                    args_signature_with_names=args_signature_with_names, | 
					
						
							|  |  |  |                    function_call=function_call, | 
					
						
							|  |  |  |                    py_args_names=py_args_names, | 
					
						
							|  |  |  |                    suffix=suffix, | 
					
						
							|  |  |  |                )) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return ret | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |     def _wrap_method(self, | 
					
						
							|  |  |  |                      method, | 
					
						
							|  |  |  |                      cpp_class, | 
					
						
							|  |  |  |                      prefix, | 
					
						
							|  |  |  |                      suffix, | 
					
						
							|  |  |  |                      method_suffix=""): | 
					
						
							| 
									
										
										
										
											2021-12-18 04:55:42 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Wrap the `method` for the class specified by `cpp_class`. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             method: The method to wrap. | 
					
						
							|  |  |  |             cpp_class: The C++ name of the class to which the method belongs. | 
					
						
							|  |  |  |             prefix: Prefix to add to the wrapped method when writing to the cpp file. | 
					
						
							|  |  |  |             suffix: Suffix to add to the wrapped method when writing to the cpp file. | 
					
						
							|  |  |  |             method_suffix: A string to append to the wrapped method name. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         py_method = method.name + method_suffix | 
					
						
							|  |  |  |         cpp_method = method.to_cpp() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-25 22:12:24 +08:00
										 |  |  |         args_names = method.args.names() | 
					
						
							|  |  |  |         py_args_names = self._py_args_names(method.args) | 
					
						
							|  |  |  |         args_signature_with_names = self._method_args_signature(method.args) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-18 04:55:42 +08:00
										 |  |  |         # Special handling for the serialize/serializable method | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         if cpp_method in ["serialize", "serializable"]: | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |             if self.use_boost_serialization: | 
					
						
							|  |  |  |                 return self._wrap_serialization(cpp_class) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 return "" | 
					
						
							| 
									
										
										
										
											2021-12-25 22:12:24 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Special handling of ipython specific methods | 
					
						
							|  |  |  |         # https://ipython.readthedocs.io/en/stable/config/integrating.html | 
					
						
							|  |  |  |         if cpp_method in self._ipython_special_methods: | 
					
						
							|  |  |  |             idx = self._ipython_special_methods.index(cpp_method) | 
					
						
							|  |  |  |             py_method = f"_repr_{self._ipython_special_methods[idx]}_" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         # Add underscore to disambiguate if the method name matches a python keyword | 
					
						
							|  |  |  |         if py_method in self.python_keywords: | 
					
						
							|  |  |  |             py_method = py_method + "_" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-26 00:46:06 +08:00
										 |  |  |         is_method = isinstance( | 
					
						
							|  |  |  |             method, (parser.Method, instantiator.InstantiatedMethod)) | 
					
						
							|  |  |  |         is_static = isinstance( | 
					
						
							|  |  |  |             method, | 
					
						
							|  |  |  |             (parser.StaticMethod, instantiator.InstantiatedStaticMethod)) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         return_void = method.return_type.is_void() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         caller = cpp_class + "::" if not is_method else "self->" | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         function_call = ('{opt_return} {caller}{method_name}' | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                          '({args_names});'.format( | 
					
						
							|  |  |  |                              opt_return='return' if not return_void else '', | 
					
						
							|  |  |  |                              caller=caller, | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |                              method_name=cpp_method, | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                              args_names=', '.join(args_names), | 
					
						
							|  |  |  |                          )) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ret = ('{prefix}.{cdef}("{py_method}",' | 
					
						
							|  |  |  |                '[]({opt_self}{opt_comma}{args_signature_with_names}){{' | 
					
						
							|  |  |  |                '{function_call}' | 
					
						
							|  |  |  |                '}}' | 
					
						
							|  |  |  |                '{py_args_names}){suffix}'.format( | 
					
						
							|  |  |  |                    prefix=prefix, | 
					
						
							|  |  |  |                    cdef="def_static" if is_static else "def", | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |                    py_method=py_method, | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |                    opt_self="{cpp_class}* self".format( | 
					
						
							|  |  |  |                        cpp_class=cpp_class) if is_method else "", | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |                    opt_comma=', ' if is_method and args_names else '', | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                    args_signature_with_names=args_signature_with_names, | 
					
						
							|  |  |  |                    function_call=function_call, | 
					
						
							|  |  |  |                    py_args_names=py_args_names, | 
					
						
							|  |  |  |                    suffix=suffix, | 
					
						
							|  |  |  |                )) | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  |         # Create __repr__ override | 
					
						
							|  |  |  |         # We allow all arguments to .print() and let the compiler handle type mismatches. | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         if method.name == 'print': | 
					
						
							| 
									
										
										
										
											2021-12-25 22:12:24 +08:00
										 |  |  |             ret = self._wrap_print(ret, method, cpp_class, args_names, | 
					
						
							|  |  |  |                                    args_signature_with_names, py_args_names, | 
					
						
							|  |  |  |                                    prefix, suffix) | 
					
						
							| 
									
										
										
										
											2021-04-17 09:07:09 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         return ret | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 04:04:09 +08:00
										 |  |  |     def wrap_dunder_methods(self, | 
					
						
							|  |  |  |                             methods, | 
					
						
							|  |  |  |                             cpp_class, | 
					
						
							|  |  |  |                             prefix='\n' + ' ' * 8, | 
					
						
							|  |  |  |                             suffix=''): | 
					
						
							|  |  |  |         res = "" | 
					
						
							|  |  |  |         for method in methods: | 
					
						
							|  |  |  |             res += self._wrap_dunder(method=method, | 
					
						
							|  |  |  |                                      cpp_class=cpp_class, | 
					
						
							|  |  |  |                                      prefix=prefix, | 
					
						
							|  |  |  |                                      suffix=suffix) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |     def wrap_methods(self, | 
					
						
							|  |  |  |                      methods, | 
					
						
							|  |  |  |                      cpp_class, | 
					
						
							|  |  |  |                      prefix='\n' + ' ' * 8, | 
					
						
							|  |  |  |                      suffix=''): | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Wrap all the methods in the `cpp_class`. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         res = "" | 
					
						
							|  |  |  |         for method in methods: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |             # To avoid type confusion for insert | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |             if method.name == 'insert' and cpp_class == 'gtsam::Values': | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |                 name_list = method.args.names() | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |                 type_list = method.args.to_cpp() | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |                 # inserting non-wrapped value types | 
					
						
							|  |  |  |                 if type_list[0].strip() == 'size_t': | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                     method_suffix = '_' + name_list[1].strip() | 
					
						
							|  |  |  |                     res += self._wrap_method(method=method, | 
					
						
							|  |  |  |                                              cpp_class=cpp_class, | 
					
						
							|  |  |  |                                              prefix=prefix, | 
					
						
							|  |  |  |                                              suffix=suffix, | 
					
						
							|  |  |  |                                              method_suffix=method_suffix) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             res += self._wrap_method( | 
					
						
							|  |  |  |                 method=method, | 
					
						
							|  |  |  |                 cpp_class=cpp_class, | 
					
						
							|  |  |  |                 prefix=prefix, | 
					
						
							|  |  |  |                 suffix=suffix, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |     def wrap_variable(self, | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |                       namespace, | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |                       module_var, | 
					
						
							|  |  |  |                       variable, | 
					
						
							|  |  |  |                       prefix='\n' + ' ' * 8): | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Wrap a variable that's not part of a class (i.e. global) | 
					
						
							| 
									
										
										
										
											2021-04-20 04:09:06 +08:00
										 |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |         variable_value = "" | 
					
						
							|  |  |  |         if variable.default is None: | 
					
						
							|  |  |  |             variable_value = variable.name | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             variable_value = variable.default | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return '{prefix}{module_var}.attr("{variable_name}") = {namespace}{variable_value};'.format( | 
					
						
							| 
									
										
										
										
											2021-04-20 04:09:06 +08:00
										 |  |  |             prefix=prefix, | 
					
						
							|  |  |  |             module_var=module_var, | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |             variable_name=variable.name, | 
					
						
							|  |  |  |             namespace=namespace, | 
					
						
							|  |  |  |             variable_value=variable_value) | 
					
						
							| 
									
										
										
										
											2021-04-20 04:09:06 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |     def wrap_properties(self, properties, cpp_class, prefix='\n' + ' ' * 8): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap all the properties in the `cpp_class`.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         res = "" | 
					
						
							|  |  |  |         for prop in properties: | 
					
						
							|  |  |  |             res += ('{prefix}.def_{property}("{property_name}", ' | 
					
						
							|  |  |  |                     '&{cpp_class}::{property_name})'.format( | 
					
						
							|  |  |  |                         prefix=prefix, | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |                         property="readonly" | 
					
						
							|  |  |  |                         if prop.ctype.is_const else "readwrite", | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                         cpp_class=cpp_class, | 
					
						
							|  |  |  |                         property_name=prop.name, | 
					
						
							|  |  |  |                     )) | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-02 11:20:12 +08:00
										 |  |  |     def wrap_operators(self, operators, cpp_class, prefix='\n' + ' ' * 8): | 
					
						
							|  |  |  |         """Wrap all the overloaded operators in the `cpp_class`.""" | 
					
						
							|  |  |  |         res = "" | 
					
						
							|  |  |  |         template = "{prefix}.def({{0}})".format(prefix=prefix) | 
					
						
							|  |  |  |         for op in operators: | 
					
						
							|  |  |  |             if op.operator == "[]":  # __getitem__ | 
					
						
							|  |  |  |                 res += "{prefix}.def(\"__getitem__\", &{cpp_class}::operator[])".format( | 
					
						
							|  |  |  |                     prefix=prefix, cpp_class=cpp_class) | 
					
						
							|  |  |  |             elif op.operator == "()":  # __call__ | 
					
						
							|  |  |  |                 res += "{prefix}.def(\"__call__\", &{cpp_class}::operator())".format( | 
					
						
							|  |  |  |                     prefix=prefix, cpp_class=cpp_class) | 
					
						
							|  |  |  |             elif op.is_unary: | 
					
						
							|  |  |  |                 res += template.format("{0}py::self".format(op.operator)) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 res += template.format("py::self {0} py::self".format( | 
					
						
							|  |  |  |                     op.operator)) | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |     def wrap_enum(self, enum, class_name='', module=None, prefix=' ' * 4): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Wrap an enum. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             enum: The parsed enum to wrap. | 
					
						
							|  |  |  |             class_name: The class under which the enum is defined. | 
					
						
							|  |  |  |             prefix: The amount of indentation. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if module is None: | 
					
						
							|  |  |  |             module = self._gen_module_var(enum.namespaces()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         cpp_class = enum.cpp_typename().to_cpp() | 
					
						
							|  |  |  |         if class_name: | 
					
						
							|  |  |  |             # If class_name is provided, add that as the namespace | 
					
						
							|  |  |  |             cpp_class = class_name + "::" + cpp_class | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         res = '{prefix}py::enum_<{cpp_class}>({module}, "{enum.name}", py::arithmetic())'.format( | 
					
						
							|  |  |  |             prefix=prefix, module=module, enum=enum, cpp_class=cpp_class) | 
					
						
							|  |  |  |         for enumerator in enum.enumerators: | 
					
						
							|  |  |  |             res += '\n{prefix}    .value("{enumerator.name}", {cpp_class}::{enumerator.name})'.format( | 
					
						
							|  |  |  |                 prefix=prefix, enumerator=enumerator, cpp_class=cpp_class) | 
					
						
							|  |  |  |         res += ";\n\n" | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def wrap_enums(self, enums, instantiated_class, prefix=' ' * 4): | 
					
						
							|  |  |  |         """Wrap multiple enums defined in a class.""" | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         cpp_class = instantiated_class.to_cpp() | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |         module_var = instantiated_class.name.lower() | 
					
						
							|  |  |  |         res = '' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for enum in enums: | 
					
						
							|  |  |  |             res += "\n" + self.wrap_enum( | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |                 enum, class_name=cpp_class, module=module_var, prefix=prefix) | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |     def wrap_instantiated_class( | 
					
						
							|  |  |  |             self, instantiated_class: instantiator.InstantiatedClass): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap the class.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         module_var = self._gen_module_var(instantiated_class.namespaces()) | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         cpp_class = instantiated_class.to_cpp() | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         if cpp_class in self.ignore_classes: | 
					
						
							|  |  |  |             return "" | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |         if instantiated_class.parent_class: | 
					
						
							|  |  |  |             class_parent = "{instantiated_class.parent_class}, ".format( | 
					
						
							|  |  |  |                 instantiated_class=instantiated_class) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             class_parent = '' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if instantiated_class.enums: | 
					
						
							|  |  |  |             # If class has enums, define an instance and set module_var to the instance | 
					
						
							|  |  |  |             instance_name = instantiated_class.name.lower() | 
					
						
							|  |  |  |             class_declaration = ( | 
					
						
							|  |  |  |                 '\n    py::class_<{cpp_class}, {class_parent}' | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |                 'std::shared_ptr<{cpp_class}>> ' | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |                 '{instance_name}({module_var}, "{class_name}");' | 
					
						
							|  |  |  |                 '\n    {instance_name}').format( | 
					
						
							|  |  |  |                     cpp_class=cpp_class, | 
					
						
							|  |  |  |                     class_name=instantiated_class.name, | 
					
						
							|  |  |  |                     class_parent=class_parent, | 
					
						
							|  |  |  |                     instance_name=instance_name, | 
					
						
							|  |  |  |                     module_var=module_var) | 
					
						
							|  |  |  |             module_var = instance_name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             class_declaration = ( | 
					
						
							|  |  |  |                 '\n    py::class_<{cpp_class}, {class_parent}' | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |                 'std::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' | 
					
						
							|  |  |  |             ).format(cpp_class=cpp_class, | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |                      class_name=instantiated_class.name, | 
					
						
							|  |  |  |                      class_parent=class_parent, | 
					
						
							|  |  |  |                      module_var=module_var) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return ('{class_declaration}' | 
					
						
							|  |  |  |                 '{wrapped_ctors}' | 
					
						
							|  |  |  |                 '{wrapped_methods}' | 
					
						
							|  |  |  |                 '{wrapped_static_methods}' | 
					
						
							| 
									
										
										
										
											2024-01-17 04:04:09 +08:00
										 |  |  |                 '{wrapped_dunder_methods}' | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |                 '{wrapped_properties}' | 
					
						
							|  |  |  |                 '{wrapped_operators};\n'.format( | 
					
						
							|  |  |  |                     class_declaration=class_declaration, | 
					
						
							|  |  |  |                     wrapped_ctors=self.wrap_ctors(instantiated_class), | 
					
						
							|  |  |  |                     wrapped_methods=self.wrap_methods( | 
					
						
							|  |  |  |                         instantiated_class.methods, cpp_class), | 
					
						
							|  |  |  |                     wrapped_static_methods=self.wrap_methods( | 
					
						
							|  |  |  |                         instantiated_class.static_methods, cpp_class), | 
					
						
							| 
									
										
										
										
											2024-01-17 04:04:09 +08:00
										 |  |  |                     wrapped_dunder_methods=self.wrap_dunder_methods( | 
					
						
							|  |  |  |                         instantiated_class.dunder_methods, cpp_class), | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |                     wrapped_properties=self.wrap_properties( | 
					
						
							|  |  |  |                         instantiated_class.properties, cpp_class), | 
					
						
							|  |  |  |                     wrapped_operators=self.wrap_operators( | 
					
						
							|  |  |  |                         instantiated_class.operators, cpp_class))) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |     def wrap_instantiated_declaration( | 
					
						
							|  |  |  |             self, instantiated_decl: instantiator.InstantiatedDeclaration): | 
					
						
							| 
									
										
										
										
											2022-03-21 10:01:12 +08:00
										 |  |  |         """Wrap the forward declaration.""" | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         module_var = self._gen_module_var(instantiated_decl.namespaces()) | 
					
						
							|  |  |  |         cpp_class = instantiated_decl.to_cpp() | 
					
						
							|  |  |  |         if cpp_class in self.ignore_classes: | 
					
						
							|  |  |  |             return "" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |         res = ('\n    py::class_<{cpp_class}, ' | 
					
						
							|  |  |  |                'std::shared_ptr<{cpp_class}>>({module_var}, "{class_name}");' | 
					
						
							|  |  |  |                ).format(cpp_class=cpp_class, | 
					
						
							|  |  |  |                         class_name=instantiated_decl.name, | 
					
						
							|  |  |  |                         module_var=module_var) | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |     def wrap_stl_class(self, stl_class): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap STL containers.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         module_var = self._gen_module_var(stl_class.namespaces()) | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         cpp_class = stl_class.to_cpp() | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         if cpp_class in self.ignore_classes: | 
					
						
							|  |  |  |             return "" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |         return ('\n    py::class_<{cpp_class}, {class_parent}' | 
					
						
							|  |  |  |                 'std::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' | 
					
						
							|  |  |  |                 '{wrapped_ctors}' | 
					
						
							|  |  |  |                 '{wrapped_methods}' | 
					
						
							|  |  |  |                 '{wrapped_static_methods}' | 
					
						
							|  |  |  |                 '{wrapped_properties};\n'.format( | 
					
						
							|  |  |  |                     cpp_class=cpp_class, | 
					
						
							|  |  |  |                     class_name=stl_class.name, | 
					
						
							|  |  |  |                     class_parent=str(stl_class.parent_class) + | 
					
						
							|  |  |  |                     (', ' if stl_class.parent_class else ''), | 
					
						
							|  |  |  |                     module_var=module_var, | 
					
						
							|  |  |  |                     wrapped_ctors=self.wrap_ctors(stl_class), | 
					
						
							|  |  |  |                     wrapped_methods=self.wrap_methods(stl_class.methods, | 
					
						
							|  |  |  |                                                       cpp_class), | 
					
						
							|  |  |  |                     wrapped_static_methods=self.wrap_methods( | 
					
						
							|  |  |  |                         stl_class.static_methods, cpp_class), | 
					
						
							|  |  |  |                     wrapped_properties=self.wrap_properties( | 
					
						
							|  |  |  |                         stl_class.properties, cpp_class), | 
					
						
							|  |  |  |                 )) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |     def wrap_functions(self, | 
					
						
							|  |  |  |                        functions, | 
					
						
							|  |  |  |                        namespace, | 
					
						
							|  |  |  |                        prefix='\n' + ' ' * 8, | 
					
						
							|  |  |  |                        suffix=''): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Wrap all the global functions. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         res = "" | 
					
						
							|  |  |  |         for function in functions: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             function_name = function.name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Add underscore to disambiguate if the function name matches a python keyword | 
					
						
							|  |  |  |             python_keywords = self.python_keywords + ['print'] | 
					
						
							|  |  |  |             if function_name in python_keywords: | 
					
						
							|  |  |  |                 function_name = function_name + "_" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             cpp_method = function.to_cpp() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             is_static = isinstance(function, parser.StaticMethod) | 
					
						
							|  |  |  |             return_void = function.return_type.is_void() | 
					
						
							|  |  |  |             args_names = function.args.names() | 
					
						
							|  |  |  |             py_args_names = self._py_args_names(function.args) | 
					
						
							|  |  |  |             args_signature = self._method_args_signature(function.args) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             caller = namespace + "::" | 
					
						
							|  |  |  |             function_call = ('{opt_return} {caller}{function_name}' | 
					
						
							|  |  |  |                              '({args_names});'.format( | 
					
						
							|  |  |  |                                  opt_return='return' | 
					
						
							|  |  |  |                                  if not return_void else '', | 
					
						
							|  |  |  |                                  caller=caller, | 
					
						
							|  |  |  |                                  function_name=cpp_method, | 
					
						
							|  |  |  |                                  args_names=', '.join(args_names), | 
					
						
							|  |  |  |                              )) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             ret = ('{prefix}.{cdef}("{function_name}",' | 
					
						
							|  |  |  |                    '[]({args_signature}){{' | 
					
						
							|  |  |  |                    '{function_call}' | 
					
						
							|  |  |  |                    '}}' | 
					
						
							|  |  |  |                    '{py_args_names}){suffix}'.format( | 
					
						
							|  |  |  |                        prefix=prefix, | 
					
						
							|  |  |  |                        cdef="def_static" if is_static else "def", | 
					
						
							|  |  |  |                        function_name=function_name, | 
					
						
							|  |  |  |                        args_signature=args_signature, | 
					
						
							|  |  |  |                        function_call=function_call, | 
					
						
							|  |  |  |                        py_args_names=py_args_names, | 
					
						
							|  |  |  |                        suffix=suffix)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             res += ret | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |     def _partial_match(self, namespaces1, namespaces2): | 
					
						
							|  |  |  |         for i in range(min(len(namespaces1), len(namespaces2))): | 
					
						
							|  |  |  |             if namespaces1[i] != namespaces2[i]: | 
					
						
							|  |  |  |                 return False | 
					
						
							|  |  |  |         return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _gen_module_var(self, namespaces): | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |         """Get the Pybind11 module name from the namespaces.""" | 
					
						
							|  |  |  |         # We skip the first value in namespaces since it is empty | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         sub_module_namespaces = namespaces[len(self.top_module_namespaces):] | 
					
						
							|  |  |  |         return "m_{}".format('_'.join(sub_module_namespaces)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _add_namespaces(self, name, namespaces): | 
					
						
							|  |  |  |         if namespaces: | 
					
						
							|  |  |  |             # Ignore the first empty global namespace. | 
					
						
							|  |  |  |             idx = 1 if not namespaces[0] else 0 | 
					
						
							|  |  |  |             return '::'.join(namespaces[idx:] + [name]) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def wrap_namespace(self, namespace): | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |         """Wrap the complete `namespace`.""" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         wrapped = "" | 
					
						
							|  |  |  |         includes = "" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         namespaces = namespace.full_namespaces() | 
					
						
							|  |  |  |         if not self._partial_match(namespaces, self.top_module_namespaces): | 
					
						
							|  |  |  |             return "", "" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if len(namespaces) < len(self.top_module_namespaces): | 
					
						
							|  |  |  |             for element in namespace.content: | 
					
						
							|  |  |  |                 if isinstance(element, parser.Include): | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |                     include = "{}\n".format(element) | 
					
						
							|  |  |  |                     # replace the angle brackets with quotes | 
					
						
							|  |  |  |                     include = include.replace('<', '"').replace('>', '"') | 
					
						
							|  |  |  |                     includes += include | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                 if isinstance(element, parser.Namespace): | 
					
						
							|  |  |  |                     ( | 
					
						
							|  |  |  |                         wrapped_namespace, | 
					
						
							|  |  |  |                         includes_namespace, | 
					
						
							|  |  |  |                     ) = self.wrap_namespace(  # noqa | 
					
						
							|  |  |  |                         element) | 
					
						
							|  |  |  |                     wrapped += wrapped_namespace | 
					
						
							|  |  |  |                     includes += includes_namespace | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             module_var = self._gen_module_var(namespaces) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if len(namespaces) > len(self.top_module_namespaces): | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |                 wrapped += ( | 
					
						
							|  |  |  |                     ' ' * 4 + 'pybind11::module {module_var} = ' | 
					
						
							|  |  |  |                     '{parent_module_var}.def_submodule("{namespace}", "' | 
					
						
							|  |  |  |                     '{namespace} submodule");\n'.format( | 
					
						
							|  |  |  |                         module_var=module_var, | 
					
						
							|  |  |  |                         namespace=namespace.name, | 
					
						
							|  |  |  |                         parent_module_var=self._gen_module_var( | 
					
						
							|  |  |  |                             namespaces[:-1]), | 
					
						
							|  |  |  |                     )) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |             # Wrap an include statement, namespace, class or enum | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |             for element in namespace.content: | 
					
						
							|  |  |  |                 if isinstance(element, parser.Include): | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |                     include = "{}\n".format(element) | 
					
						
							|  |  |  |                     # replace the angle brackets with quotes | 
					
						
							|  |  |  |                     include = include.replace('<', '"').replace('>', '"') | 
					
						
							|  |  |  |                     includes += include | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                 elif isinstance(element, parser.Namespace): | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |                     wrapped_namespace, includes_namespace = self.wrap_namespace( | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                         element) | 
					
						
							|  |  |  |                     wrapped += wrapped_namespace | 
					
						
							|  |  |  |                     includes += includes_namespace | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                 elif isinstance(element, instantiator.InstantiatedClass): | 
					
						
							|  |  |  |                     wrapped += self.wrap_instantiated_class(element) | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |                     wrapped += self.wrap_enums(element.enums, element) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |                 elif isinstance(element, instantiator.InstantiatedDeclaration): | 
					
						
							|  |  |  |                     wrapped += self.wrap_instantiated_declaration(element) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-04-20 04:09:06 +08:00
										 |  |  |                 elif isinstance(element, parser.Variable): | 
					
						
							| 
									
										
										
										
											2021-04-21 12:01:05 +08:00
										 |  |  |                     variable_namespace = self._add_namespaces('', namespaces) | 
					
						
							|  |  |  |                     wrapped += self.wrap_variable(namespace=variable_namespace, | 
					
						
							| 
									
										
										
										
											2021-04-21 05:05:32 +08:00
										 |  |  |                                                   module_var=module_var, | 
					
						
							|  |  |  |                                                   variable=element, | 
					
						
							|  |  |  |                                                   prefix='\n' + ' ' * 4) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 elif isinstance(element, parser.Enum): | 
					
						
							|  |  |  |                     wrapped += self.wrap_enum(element) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Global functions. | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |             all_funcs = [ | 
					
						
							|  |  |  |                 func for func in namespace.content | 
					
						
							| 
									
										
										
										
											2021-03-24 12:36:02 +08:00
										 |  |  |                 if isinstance(func, (parser.GlobalFunction, | 
					
						
							|  |  |  |                                      instantiator.InstantiatedGlobalFunction)) | 
					
						
							| 
									
										
										
										
											2021-03-14 01:01:28 +08:00
										 |  |  |             ] | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |             wrapped += self.wrap_functions( | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |                 all_funcs, | 
					
						
							|  |  |  |                 self._add_namespaces('', namespaces)[:-2], | 
					
						
							|  |  |  |                 prefix='\n' + ' ' * 4 + module_var, | 
					
						
							|  |  |  |                 suffix=';', | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         return wrapped, includes | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |     def wrap_file(self, content, module_name=None, submodules=None): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Wrap the code in the interface file. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             content: The contents of the interface file. | 
					
						
							|  |  |  |             module_name: The name of the module. | 
					
						
							|  |  |  |             submodules: List of other interface file names that should be linked to. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2021-06-17 01:51:43 +08:00
										 |  |  |         # Parse the contents of the interface file | 
					
						
							|  |  |  |         module = parser.Module.parseString(content) | 
					
						
							|  |  |  |         # Instantiate all templates | 
					
						
							|  |  |  |         module = instantiator.instantiate_namespace(module) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         wrapped_namespace, includes = self.wrap_namespace(module) | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |         if self.use_boost_serialization: | 
					
						
							|  |  |  |             includes += "#include <boost/serialization/export.hpp>" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Export classes for serialization. | 
					
						
							|  |  |  |             boost_class_export = "" | 
					
						
							|  |  |  |             for cpp_class in self._serializing_classes: | 
					
						
							|  |  |  |                 new_name = cpp_class | 
					
						
							|  |  |  |                 # The boost's macro doesn't like commas, so we have to typedef. | 
					
						
							|  |  |  |                 if ',' in cpp_class: | 
					
						
							|  |  |  |                     new_name = re.sub("[,:<> ]", "", cpp_class) | 
					
						
							|  |  |  |                     boost_class_export += "typedef {cpp_class} {new_name};\n".format(  # noqa | 
					
						
							|  |  |  |                         cpp_class=cpp_class, new_name=new_name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 boost_class_export += "BOOST_CLASS_EXPORT({new_name})\n".format( | 
					
						
							|  |  |  |                     new_name=new_name, ) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             boost_class_export = "" | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         # Reset the serializing classes list | 
					
						
							|  |  |  |         self._serializing_classes = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         submodules_init = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if submodules is not None: | 
					
						
							|  |  |  |             module_def = "PYBIND11_MODULE({0}, m_)".format(module_name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             for idx, submodule in enumerate(submodules): | 
					
						
							|  |  |  |                 submodules[idx] = "void {0}(py::module_ &);".format(submodule) | 
					
						
							|  |  |  |                 submodules_init.append("{0}(m_);".format(submodule)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             module_def = "void {0}(py::module_ &m_)".format(module_name) | 
					
						
							|  |  |  |             submodules = [] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         return self.module_template.format( | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |             module_def=module_def, | 
					
						
							|  |  |  |             module_name=module_name, | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |             includes=includes, | 
					
						
							|  |  |  |             wrapped_namespace=wrapped_namespace, | 
					
						
							|  |  |  |             boost_class_export=boost_class_export, | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |             submodules="\n".join(submodules), | 
					
						
							|  |  |  |             submodules_init="\n".join(submodules_init), | 
					
						
							| 
									
										
										
										
											2020-08-18 02:44:43 +08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-12-23 03:14:47 +08:00
										 |  |  |     def wrap_submodule(self, source): | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2021-12-23 03:14:47 +08:00
										 |  |  |         Wrap a list of submodule files, i.e. a set of interface files which are | 
					
						
							|  |  |  |         in support of a larger wrapping project. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         E.g. This is used in GTSAM where we have a main gtsam.i, but various smaller .i files | 
					
						
							|  |  |  |         which are the submodules. | 
					
						
							|  |  |  |         The benefit of this scheme is that it reduces compute and memory usage during compilation. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             source: Interface file which forms the submodule. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         filename = Path(source).name | 
					
						
							|  |  |  |         module_name = Path(source).stem | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Read in the complete interface (.i) file | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |         with open(source, "r", encoding="UTF-8") as f: | 
					
						
							| 
									
										
										
										
											2021-12-23 03:14:47 +08:00
										 |  |  |             content = f.read() | 
					
						
							|  |  |  |         # Wrap the read-in content | 
					
						
							|  |  |  |         cc_content = self.wrap_file(content, module_name=module_name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Generate the C++ code which Pybind11 will use. | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |         with open(filename.replace(".i", ".cpp"), "w", encoding="UTF-8") as f: | 
					
						
							| 
									
										
										
										
											2021-12-23 03:14:47 +08:00
										 |  |  |             f.write(cc_content) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def wrap(self, sources, main_module_name): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Wrap all the main interface file. | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             sources: List of all interface files. | 
					
						
							| 
									
										
										
										
											2021-12-23 03:14:47 +08:00
										 |  |  |                 The first file should be the main module. | 
					
						
							|  |  |  |             main_module_name: The name for the main module. | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         main_module = sources[0] | 
					
						
							| 
									
										
										
										
											2021-12-23 03:14:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Get all the submodule names. | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |         submodules = [] | 
					
						
							|  |  |  |         for source in sources[1:]: | 
					
						
							|  |  |  |             module_name = Path(source).stem | 
					
						
							|  |  |  |             submodules.append(module_name) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |         with open(main_module, "r", encoding="UTF-8") as f: | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |             content = f.read() | 
					
						
							|  |  |  |         cc_content = self.wrap_file(content, | 
					
						
							|  |  |  |                                     module_name=self.module_name, | 
					
						
							|  |  |  |                                     submodules=submodules) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Generate the C++ code which Pybind11 will use. | 
					
						
							| 
									
										
										
										
											2023-02-17 16:13:23 +08:00
										 |  |  |         with open(main_module_name, "w", encoding="UTF-8") as f: | 
					
						
							| 
									
										
										
										
											2021-07-11 23:10:35 +08:00
										 |  |  |             f.write(cc_content) |