From afc9333a17167aed7a351d763e8b6fd06eb3bb3b Mon Sep 17 00:00:00 2001 From: Ayush Baid Date: Mon, 8 Mar 2021 19:53:26 -0500 Subject: [PATCH] Squashed 'wrap/' changes from d37b8a972..10e1efd6f 10e1efd6f Merge pull request #32 from ayushbaid/feature/pickle 55d5d7fbe ignoring pickle in matlab wrap 8d70c7fe2 adding newlines dee8aaee3 adding markers for pickling 46fc45d82 separating out the marker for pickle git-subtree-dir: wrap git-subtree-split: 10e1efd6f25f76b774868b3da33cb3954008b233 --- gtwrap/matlab_wrapper.py | 12 ++++++++++++ gtwrap/pybind_wrapper.py | 7 ++++++- tests/expected-python/geometry_pybind.cpp | 2 ++ tests/geometry.h | 6 ++++++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/gtwrap/matlab_wrapper.py b/gtwrap/matlab_wrapper.py index 669bf474f..fe4ee7e19 100755 --- a/gtwrap/matlab_wrapper.py +++ b/gtwrap/matlab_wrapper.py @@ -49,6 +49,8 @@ class MatlabWrapper(object): } """Methods that should not be wrapped directly""" whitelist = ['serializable', 'serialize'] + """Methods that should be ignored""" + ignore_methods = ['pickle'] """Datatypes that do not need to be checked in methods""" not_check_type = [] """Data types that are primitive types""" @@ -563,6 +565,8 @@ class MatlabWrapper(object): for method in methods: if method.name in self.whitelist: continue + if method.name in self.ignore_methods: + continue comment += '%{name}({args})'.format(name=method.name, args=self._wrap_args(method.args)) @@ -612,6 +616,9 @@ class MatlabWrapper(object): methods = self._group_methods(methods) for method in methods: + if method in self.ignore_methods: + continue + if globals: self._debug("[wrap_methods] wrapping: {}..{}={}".format(method[0].parent.name, method[0].name, type(method[0].parent.name))) @@ -861,6 +868,8 @@ class MatlabWrapper(object): method_name = method[0].name if method_name in self.whitelist and method_name != 'serialize': continue + if method_name in self.ignore_methods: + continue if method_name == 'serialize': serialize[0] = True @@ -932,6 +941,9 @@ class MatlabWrapper(object): format_name = list(static_method[0].name) format_name[0] = format_name[0].upper() + if static_method[0].name in self.ignore_methods: + continue + method_text += textwrap.indent(textwrap.dedent('''\ function varargout = {name}(varargin) '''.format(name=''.join(format_name))), diff --git a/gtwrap/pybind_wrapper.py b/gtwrap/pybind_wrapper.py index ec5480785..a045afcbd 100755 --- a/gtwrap/pybind_wrapper.py +++ b/gtwrap/pybind_wrapper.py @@ -75,6 +75,11 @@ class PybindWrapper(object): []({class_inst} self, string serialized){{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized")) + '''.format(class_inst=cpp_class + '*')) + if cpp_method == "pickle": + if not cpp_class in self._serializing_classes: + raise ValueError("Cannot pickle a class which is not serializable") + return textwrap.dedent(''' .def(py::pickle( [](const {cpp_class} &a){{ // __getstate__ /* Returns a string that encodes the state of the object */ @@ -85,7 +90,7 @@ class PybindWrapper(object): gtsam::deserialize(t[0].cast(), obj); return obj; }})) - '''.format(class_inst=cpp_class + '*', cpp_class=cpp_class)) + '''.format(cpp_class=cpp_class)) is_method = isinstance(method, instantiator.InstantiatedMethod) is_static = isinstance(method, parser.StaticMethod) diff --git a/tests/expected-python/geometry_pybind.cpp b/tests/expected-python/geometry_pybind.cpp index 53d33eabf..be6482d89 100644 --- a/tests/expected-python/geometry_pybind.cpp +++ b/tests/expected-python/geometry_pybind.cpp @@ -47,6 +47,7 @@ PYBIND11_MODULE(geometry_py, m_) { [](gtsam::Point2* self, string serialized){ gtsam::deserialize(serialized, *self); }, py::arg("serialized")) + .def(py::pickle( [](const gtsam::Point2 &a){ // __getstate__ /* Returns a string that encodes the state of the object */ @@ -71,6 +72,7 @@ PYBIND11_MODULE(geometry_py, m_) { [](gtsam::Point3* self, string serialized){ gtsam::deserialize(serialized, *self); }, py::arg("serialized")) + .def(py::pickle( [](const gtsam::Point3 &a){ // __getstate__ /* Returns a string that encodes the state of the object */ diff --git a/tests/geometry.h b/tests/geometry.h index 40d878c9f..ec5d3b277 100644 --- a/tests/geometry.h +++ b/tests/geometry.h @@ -22,6 +22,9 @@ class Point2 { VectorNotEigen vectorConfusion(); void serializable() const; // Sets flag and creates export, but does not make serialization functions + + // enable pickling in python + void pickle() const; }; #include @@ -35,6 +38,9 @@ class Point3 { // enabling serialization functionality void serialize() const; // Just triggers a flag internally and removes actual function + + // enable pickling in python + void pickle() const; }; }