From 5796fe348820c5d6a2e35110b2223e88eb77488c Mon Sep 17 00:00:00 2001 From: Gerry Chen Date: Wed, 20 Apr 2022 16:21:59 -0400 Subject: [PATCH] Create convenience wrapper function in logging_optimizer --- python/gtsam/tests/test_logging_optimizer.py | 54 ++++++++++---------- python/gtsam/utils/logging_optimizer.py | 45 ++++++++++++++++ 2 files changed, 73 insertions(+), 26 deletions(-) diff --git a/python/gtsam/tests/test_logging_optimizer.py b/python/gtsam/tests/test_logging_optimizer.py index 4ec782635..b4f32b14f 100644 --- a/python/gtsam/tests/test_logging_optimizer.py +++ b/python/gtsam/tests/test_logging_optimizer.py @@ -18,7 +18,7 @@ import numpy as np from gtsam import Rot3 from gtsam.utils.test_case import GtsamTestCase -from gtsam.utils.logging_optimizer import gtsam_optimize +from gtsam.utils.logging_optimizer import gtsam_optimize, optimize_using KEY = 0 MODEL = gtsam.noiseModel.Unit.Create(3) @@ -34,19 +34,18 @@ class TestOptimizeComet(GtsamTestCase): rotations = {R, R.inverse()} # mean is the identity self.expected = Rot3() - graph = gtsam.NonlinearFactorGraph() - for R in rotations: - graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL)) - initial = gtsam.Values() - initial.insert(KEY, R) - self.params = gtsam.GaussNewtonParams() - self.optimizer = gtsam.GaussNewtonOptimizer( - graph, initial, self.params) + def check(actual): + # Check that optimizing yields the identity + self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) + # Check that logging output prints out 3 lines (exact intermediate values differ by OS) + self.assertEqual(self.capturedOutput.getvalue().count('\n'), 3) + self.check = check - self.lmparams = gtsam.LevenbergMarquardtParams() - self.lmoptimizer = gtsam.LevenbergMarquardtOptimizer( - graph, initial, self.lmparams - ) + self.graph = gtsam.NonlinearFactorGraph() + for R in rotations: + self.graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL)) + self.initial = gtsam.Values() + self.initial.insert(KEY, R) # setup output capture self.capturedOutput = StringIO() @@ -64,25 +63,28 @@ class TestOptimizeComet(GtsamTestCase): print(error) # Wrapper function sets the hook and calls optimizer.optimize() for us. - gtsam_optimize(self.optimizer, self.params, hook) - - # Check that optimizing yields the identity. - actual = self.optimizer.values() - self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) - self.assertEqual(self.capturedOutput.getvalue(), - "0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n") + params = gtsam.GaussNewtonParams() + actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial) + self.check(actual) + actual = optimize_using(gtsam.GaussNewtonOptimizer, hook)(self.graph, self.initial, params) + self.check(actual) + actual = gtsam_optimize(gtsam.GaussNewtonOptimizer(self.graph, self.initial, params), + params, hook) + self.check(actual) def test_lm_simple_printing(self): """Make sure we are properly terminating LM""" def hook(_, error): print(error) - gtsam_optimize(self.lmoptimizer, self.lmparams, hook) - - actual = self.lmoptimizer.values() - self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) - self.assertEqual(self.capturedOutput.getvalue(), - "0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n") + params = gtsam.LevenbergMarquardtParams() + actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial) + self.check(actual) + actual = optimize_using(gtsam.LevenbergMarquardtOptimizer, hook)(self.graph, self.initial, + params) + self.check(actual) + actual = gtsam_optimize(gtsam.LevenbergMarquardtOptimizer(self.graph, self.initial, params), + params, hook) @unittest.skip("Not a test we want run every time, as needs comet.ml account") def test_comet(self): diff --git a/python/gtsam/utils/logging_optimizer.py b/python/gtsam/utils/logging_optimizer.py index bf727f997..f89208bc5 100644 --- a/python/gtsam/utils/logging_optimizer.py +++ b/python/gtsam/utils/logging_optimizer.py @@ -6,6 +6,50 @@ Author: Jing Wu and Frank Dellaert from gtsam import NonlinearOptimizer, NonlinearOptimizerParams import gtsam +from typing import Any, Callable + +OPTIMIZER_PARAMS_MAP = { + gtsam.GaussNewtonOptimizer: gtsam.GaussNewtonParams, + gtsam.LevenbergMarquardtOptimizer: gtsam.LevenbergMarquardtParams, + gtsam.DoglegOptimizer: gtsam.DoglegParams, + gtsam.GncGaussNewtonOptimizer: gtsam.GaussNewtonParams, + gtsam.GncLMOptimizer: gtsam.LevenbergMarquardtParams +} + + +def optimize_using(OptimizerClass, hook) -> Callable[[Any], gtsam.Values]: + """ Wraps the constructor and "optimize()" call for an Optimizer together and adds an iteration + hook. + Example usage: + solution = optimize_using(gtsam.GaussNewtonOptimizer, hook)(graph, init, params) + + Args: + OptimizerClass (T): A NonlinearOptimizer class (e.g. GaussNewtonOptimizer, + LevenbergMarquadrtOptimizer) + hook ([T, double] -> None): Function to callback after each iteration. Args are (optimizer, + error) and return should be None. + Returns: + (Callable[*, gtsam.Values]): Call the returned function with the usual NonlinearOptimizer + arguments (will be forwarded to constructor) and it will return a Values object + representing the solution. See example usage above. + """ + + def wrapped_optimize(*args): + for arg in args: + if isinstance(arg, gtsam.NonlinearOptimizerParams): + arg.iterationHook = lambda iteration, error_before, error_after: hook( + optimizer, error_after) + break + else: + params = OPTIMIZER_PARAMS_MAP[OptimizerClass]() + params.iterationHook = lambda iteration, error_before, error_after: hook( + optimizer, error_after) + args = (*args, params) + optimizer = OptimizerClass(*args) + hook(optimizer, optimizer.error()) + return optimizer.optimize() + + return wrapped_optimize def optimize(optimizer, check_convergence, hook): @@ -37,6 +81,7 @@ def gtsam_optimize(optimizer, params, hook): """ Given an optimizer and params, iterate until convergence. + Recommend using optimize_using instead. After each iteration, hook(optimizer) is called. After the function, use values and errors to get the result. Arguments: