gtsam/gtwrap/pybind_wrapper.py

321 lines
14 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
All Rights Reserved
See LICENSE for the license information
Code generator for wrapping a C++ module with Pybind11
Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar and Frank Dellaert
"""
import re
import textwrap
Squashed 'wrap/' changes from dfa624e77..09f8bbf71 09f8bbf71 Merge pull request #25 from borglab/fix/function-name 0dbfb6c13 fix function name to be the correct one f69f8b01f Merge pull request #24 from borglab/fix/pip 6519a6627 use pip install to overcome superuser issues b11ecf4e8 Merge pull request #23 from borglab/fix/remove-pip-args 813030108 remove pip-args since we are using setup.py 498d233e0 Merge pull request #22 from borglab/fix/package-install 846212ac3 set correct flags for installing gtwrap package 62161cd20 Merge pull request #21 from borglab/feature/script-vars 93be1d9f8 set script variables and move pybind11 loading so gtwrap can be used under gtsam 8770e3c7e Merge pull request #20 from borglab/fix/pybind-include 8c3c83618 proper placement of pybind11 include a9ad4f504 Merge pull request #19 from borglab/feature/package 99d8a12c7 added more documentation 4cbec1579 change to macro so we don't have to deal with function scopes b83e405b8 updates to completely install the package 38a64b3de new scripts which will be installed to bin directory bf9646235 Merge pull request #18 from borglab/fix/cmake-min c7c280099 Consistent cmake minimum required 42df58f62 Merge pull request #17 from borglab/fix/cleanup e580b282d version bump 4ccd66fa5 More finegrained handling of Python version 6476fd710 Merge pull request #16 from borglab/feature/better-find-python 8ac1296a0 use setup.py to install dependencies e9ac473be install dependencies and support versions of CMake<3.12 cf272dbd2 Merge pull request #15 from borglab/feature/utils ffc9cc4f7 new utils to reduce boilerplate 20e8e8b7a Merge pull request #11 from borglab/feature/package 04b844bd6 use new version of FindPython and be consistent 3f9d7a32a Merge pull request #13 from borglab/add_license c791075a6 Add LICENSE 517b67c46 correct working directory for setup.py 1b22b47ae move matlab.h to root directory 37b407214 Proper source directory path for use in other projects 61696dd5d configure PybindWrap within the 
cmake directory 1b91fc9af add config file so we can use find_package a1e6f4f53 small typo da9f351be updated README and housekeeping 64b8f78d5 files needed to allow for packaging bddda7f54 package structure git-subtree-dir: wrap git-subtree-split: 09f8bbf7172ba8b1bd3d2484795743f16e1a5893
2021-01-05 02:11:36 +08:00
import gtwrap.interface_parser as parser
import gtwrap.template_instantiator as instantiator
class PybindWrapper(object):
    """Generate Pybind11 binding code from a parsed interface module.

    Each ``wrap_*`` method returns a fragment of C++ source text; ``wrap``
    assembles the fragments by formatting ``module_template``.
    """

    def __init__(self,
                 module,
                 module_name,
                 top_module_namespaces='',
                 use_boost=False,
                 ignore_classes=None,
                 module_template=""):
        """Store the configuration used by the ``wrap_*`` methods.

        Args:
            module: Parsed interface module; ``wrap`` passes it to
                ``wrap_namespace``.
            module_name: Value substituted for ``{module_name}`` in the template.
            top_module_namespaces: Namespace path treated as the root of the
                generated module. Namespaces diverging from this prefix are
                not wrapped; deeper namespaces become submodules.
            use_boost: Emit ``boost::shared_ptr`` holders and the boost include
                instead of ``std::shared_ptr``.
            ignore_classes: Fully qualified C++ class names to skip.
                ``None`` (the default) means a new, empty list.
            module_template: ``str.format`` template for the generated file.
        """
        self.module = module
        self.module_name = module_name
        self.top_module_namespaces = top_module_namespaces
        self.use_boost = use_boost
        # Avoid a shared mutable default: every instance gets its own list.
        self.ignore_classes = [] if ignore_classes is None else ignore_classes
        # Classes that bind serialize()/deserialize(); wrap() emits a
        # BOOST_CLASS_EXPORT entry for each of them.
        self._serializing_classes = list()
        self.module_template = module_template
        # Python method names that get a trailing underscore appended.
        self.python_keywords = ['print', 'lambda']

    def _shared_ptr_type(self):
        """Return the namespace ('boost' or 'std') of the shared_ptr holder."""
        return 'boost' if self.use_boost else 'std'

    @staticmethod
    def _format_include(include):
        """Render an include element as one line, with <> replaced by quotes."""
        return "{}\n".format(include).replace('<', '"').replace('>', '"')

    def _py_args_names(self, args_list):
        """Return the ``py::arg(...)`` annotations for ``args_list``.

        The result starts with ", " so it can follow the callable inside a
        ``.def(...)`` call; it is '' when there are no arguments.
        """
        names = args_list.args_names()
        if not names:
            return ''
        return ", " + ", ".join('py::arg("{}")'.format(name) for name in names)

    def _method_args_signature_with_names(self, args_list):
        """Return the C++ lambda parameter list, formatted "type name,type name"."""
        cpp_types = args_list.to_cpp(self.use_boost)
        names = args_list.args_names()
        return ','.join('{} {}'.format(ctype, name) for ctype, name in zip(cpp_types, names))

    def wrap_ctors(self, my_class):
        """Return one ``.def(py::init<...>(...))`` line per constructor of ``my_class``."""
        indent = '\n' + ' ' * 8
        return "".join(
            indent + '.def(py::init<{args_cpp_types}>(){py_args_names})'.format(
                args_cpp_types=", ".join(ctor.args.to_cpp(self.use_boost)),
                py_args_names=self._py_args_names(ctor.args),
            ) for ctor in my_class.ctors)

    def _wrap_serialization(self, cpp_class):
        """Register ``cpp_class`` for export and return its (de)serialize bindings."""
        if cpp_class not in self._serializing_classes:
            self._serializing_classes.append(cpp_class)
        # dedent strips the source-code indentation of the template below.
        return textwrap.dedent('''
            .def("serialize",
            []({class_inst} self){{
            return gtsam::serialize(*self);
            }}
            )
            .def("deserialize",
            []({class_inst} self, string serialized){{
            gtsam::deserialize(serialized, *self);
            }}, py::arg("serialized"))
            '''.format(class_inst=cpp_class + '*'))

    def _wrap_repr(self, method, cpp_class, prefix, suffix):
        """Return a ``__repr__`` binding that returns the captured output of print()."""
        type_list = method.args.to_cpp(self.use_boost)
        # When print's first parameter is a string, call it with an empty one.
        takes_string = len(type_list) > 0 and type_list[0].strip() == 'string'
        return ('{prefix}.def("__repr__",\n'
                '[](const {cpp_class} &a) {{\n'
                'gtsam::RedirectCout redirect;\n'
                'a.print({print_arg});\n'
                'return redirect.str();\n'
                '}}){suffix}').format(
                    prefix=prefix,
                    cpp_class=cpp_class,
                    print_arg='""' if takes_string else '',
                    suffix=suffix,
                )

    def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""):
        """Generate the binding for one method or global function.

        Args:
            method: Parsed method/function object.
            cpp_class: Qualified C++ class (or namespace, for global functions)
                used as the call target of non-instance methods.
            prefix: Text placed before ``.def`` (indentation or module variable).
            suffix: Text appended after the ``.def(...)`` call.
            method_suffix: Appended to the Python-visible method name.

        Returns:
            The generated C++ text. Methods whose C++ name is "serialize" or
            "serializable" produce serialize/deserialize bindings instead, and
            a method named "print" additionally produces a ``__repr__``.
        """
        py_method = method.name + method_suffix
        cpp_method = method.to_cpp()

        if cpp_method in ["serialize", "serializable"]:
            return self._wrap_serialization(cpp_class)

        is_method = isinstance(method, instantiator.InstantiatedMethod)
        is_static = isinstance(method, parser.StaticMethod)
        return_void = method.return_type.is_void()
        args_names = method.args.args_names()

        caller = "self->" if is_method else cpp_class + "::"
        function_call = ('{opt_return} {caller}{function_name}'
                         '({args_names});'.format(
                             opt_return='' if return_void else 'return',
                             caller=caller,
                             function_name=cpp_method,
                             args_names=', '.join(args_names),
                         ))

        if py_method in self.python_keywords:
            py_method += "_"

        ret = ('{prefix}.{cdef}("{py_method}",'
               '[]({opt_self}{opt_comma}{args_signature_with_names}){{'
               '{function_call}'
               '}}'
               '{py_args_names}){suffix}'.format(
                   prefix=prefix,
                   cdef="def_static" if is_static else "def",
                   py_method=py_method,
                   opt_self="{}* self".format(cpp_class) if is_method else "",
                   # Separate `self` from the remaining parameters, if any.
                   opt_comma=',' if is_method and args_names else '',
                   args_signature_with_names=self._method_args_signature_with_names(method.args),
                   function_call=function_call,
                   py_args_names=self._py_args_names(method.args),
                   suffix=suffix,
               ))

        if method.name == 'print':
            ret += self._wrap_repr(method, cpp_class, prefix, suffix)
        return ret

    def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''):
        """Return the bindings for all ``methods``.

        Used for class methods and, from ``wrap_namespace``, for global
        functions.
        """
        res = ""
        for method in methods:
            # gtsam::Values::insert overloads whose first argument is size_t
            # also get an alias named insert_<second argument name>, to avoid
            # overload confusion for non-wrapped value types.
            if method.name == 'insert' and cpp_class == 'gtsam::Values':
                name_list = method.args.args_names()
                type_list = method.args.to_cpp(self.use_boost)
                if type_list and type_list[0].strip() == 'size_t' and len(name_list) > 1:
                    res += self._wrap_method(method=method,
                                             cpp_class=cpp_class,
                                             prefix=prefix,
                                             suffix=suffix,
                                             method_suffix='_' + name_list[1].strip())
            res += self._wrap_method(
                method=method,
                cpp_class=cpp_class,
                prefix=prefix,
                suffix=suffix,
            )
        return res

    def wrap_properties(self, properties, cpp_class, prefix='\n' + ' ' * 8):
        """Return ``.def_readonly``/``.def_readwrite`` lines for ``properties``.

        Properties whose type is const are exposed read-only.
        """
        return "".join(
            '{prefix}.def_{access}("{name}", &{cpp_class}::{name})'.format(
                prefix=prefix,
                access="readonly" if prop.ctype.is_const else "readwrite",
                name=prop.name,
                cpp_class=cpp_class,
            ) for prop in properties)

    def _wrap_class(self, class_to_wrap):
        """Return the ``py::class_`` definition for an instantiated or STL class.

        Returns '' when the class is listed in ``ignore_classes``.
        """
        module_var = self._gen_module_var(class_to_wrap.namespaces())
        cpp_class = class_to_wrap.cpp_class()
        if cpp_class in self.ignore_classes:
            return ""

        parent = class_to_wrap.parent_class
        # Emit "Parent, " only when a parent exists; str(None) would otherwise
        # put the literal text "None" into the generated C++.
        class_parent = '{}, '.format(parent) if parent else ''

        return ('\n py::class_<{cpp_class}, {class_parent}'
                '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")'
                '{wrapped_ctors}'
                '{wrapped_methods}'
                '{wrapped_static_methods}'
                '{wrapped_properties};\n'.format(
                    shared_ptr_type=self._shared_ptr_type(),
                    cpp_class=cpp_class,
                    class_name=class_to_wrap.name,
                    class_parent=class_parent,
                    module_var=module_var,
                    wrapped_ctors=self.wrap_ctors(class_to_wrap),
                    wrapped_methods=self.wrap_methods(class_to_wrap.methods, cpp_class),
                    wrapped_static_methods=self.wrap_methods(class_to_wrap.static_methods,
                                                             cpp_class),
                    wrapped_properties=self.wrap_properties(class_to_wrap.properties, cpp_class),
                ))

    def wrap_instantiated_class(self, instantiated_class):
        """Return the binding code for a template-instantiated class."""
        return self._wrap_class(instantiated_class)

    def wrap_stl_class(self, stl_class):
        """Return the binding code for an STL class declaration."""
        return self._wrap_class(stl_class)

    def _partial_match(self, namespaces1, namespaces2):
        """Return True if the two namespace paths agree on their common prefix."""
        return all(ns1 == ns2 for ns1, ns2 in zip(namespaces1, namespaces2))

    def _gen_module_var(self, namespaces):
        """Return the C++ variable name of the (sub)module for ``namespaces``.

        The leading ``top_module_namespaces`` entries are dropped, so the top
        module itself maps to "m_".
        """
        sub_module_namespaces = namespaces[len(self.top_module_namespaces):]
        return "m_{}".format('_'.join(sub_module_namespaces))

    def _add_namespaces(self, name, namespaces):
        """Qualify ``name`` with ``namespaces`` joined by "::".

        A leading empty entry (the global namespace) is skipped.
        """
        if not namespaces:
            return name
        start = 1 if not namespaces[0] else 0
        return '::'.join(namespaces[start:] + [name])

    def wrap_namespace(self, namespace):
        """Wrap ``namespace`` and, recursively, the namespaces it contains.

        Returns:
            A ``(wrapped, includes)`` tuple of generated binding code and
            include lines. Both are '' when the namespace path diverges from
            ``top_module_namespaces``.
        """
        wrapped = ""
        includes = ""

        namespaces = namespace.full_namespaces()
        if not self._partial_match(namespaces, self.top_module_namespaces):
            return "", ""

        if len(namespaces) < len(self.top_module_namespaces):
            # Above the top module: nothing is bound here, but includes are
            # collected and nested namespaces are descended into.
            for element in namespace.content:
                if isinstance(element, parser.Include):
                    includes += self._format_include(element)
                elif isinstance(element, parser.Namespace):
                    wrapped_namespace, includes_namespace = self.wrap_namespace(element)
                    wrapped += wrapped_namespace
                    includes += includes_namespace
            return wrapped, includes

        module_var = self._gen_module_var(namespaces)

        if len(namespaces) > len(self.top_module_namespaces):
            # Below the top module: declare a submodule of the parent module.
            wrapped += (' ' * 4 + 'pybind11::module {module_var} = '
                        '{parent_module_var}.def_submodule("{namespace}", "'
                        '{namespace} submodule");\n'.format(
                            module_var=module_var,
                            namespace=namespace.name,
                            parent_module_var=self._gen_module_var(namespaces[:-1]),
                        ))

        for element in namespace.content:
            if isinstance(element, parser.Include):
                includes += self._format_include(element)
            elif isinstance(element, parser.Namespace):
                wrapped_namespace, includes_namespace = self.wrap_namespace(element)
                wrapped += wrapped_namespace
                includes += includes_namespace
            elif isinstance(element, instantiator.InstantiatedClass):
                wrapped += self.wrap_instantiated_class(element)

        # Global functions are bound directly on this (sub)module.
        all_funcs = [func for func in namespace.content if isinstance(func, parser.GlobalFunction)]
        wrapped += self.wrap_methods(
            all_funcs,
            self._add_namespaces('', namespaces)[:-2],  # drop the trailing "::"
            prefix='\n' + ' ' * 4 + module_var,
            suffix=';',
        )
        return wrapped, includes

    def wrap(self):
        """Generate the complete binding source for ``self.module``.

        Returns ``module_template`` formatted with the keys ``include_boost``,
        ``module_name``, ``includes``, ``hoder_type`` (key name as spelled in
        the template), ``wrapped_namespace`` and ``boost_class_export``.
        """
        wrapped_namespace, includes = self.wrap_namespace(self.module)

        # Export classes for serialization.
        boost_class_export = ""
        for cpp_class in self._serializing_classes:
            new_name = cpp_class
            # BOOST_CLASS_EXPORT is a macro, so a class name containing commas
            # is first typedef'd to a comma-free alias.
            if ',' in cpp_class:
                new_name = re.sub(r"[,:<> ]", "", cpp_class)
                boost_class_export += "typedef {cpp_class} {new_name};\n".format(
                    cpp_class=cpp_class,
                    new_name=new_name,
                )
            boost_class_export += "BOOST_CLASS_EXPORT({new_name})\n".format(new_name=new_name)

        holder_type = ""
        if self.use_boost:
            holder_type = ("PYBIND11_DECLARE_HOLDER_TYPE(TYPE_PLACEHOLDER_DONOTUSE, "
                           "{shared_ptr_type}::shared_ptr<TYPE_PLACEHOLDER_DONOTUSE>);".format(
                               shared_ptr_type=self._shared_ptr_type()))

        return self.module_template.format(
            include_boost="#include <boost/shared_ptr.hpp>" if self.use_boost else "",
            module_name=self.module_name,
            includes=includes,
            hoder_type=holder_type,
            wrapped_namespace=wrapped_namespace,
            boost_class_export=boost_class_export,
        )