diff --git a/python/gtsam/tests/test_logging_optimizer.py b/python/gtsam/tests/test_logging_optimizer.py index 47eb32e7b..4ec782635 100644 --- a/python/gtsam/tests/test_logging_optimizer.py +++ b/python/gtsam/tests/test_logging_optimizer.py @@ -63,12 +63,14 @@ class TestOptimizeComet(GtsamTestCase): def hook(_, error): print(error) - # Only thing we require from optimizer is an iterate method + # Wrapper function sets the hook and calls optimizer.optimize() for us. gtsam_optimize(self.optimizer, self.params, hook) # Check that optimizing yields the identity. actual = self.optimizer.values() self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) + self.assertEqual(self.capturedOutput.getvalue(), + "0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n") def test_lm_simple_printing(self): """Make sure we are properly terminating LM""" @@ -79,6 +81,8 @@ class TestOptimizeComet(GtsamTestCase): actual = self.lmoptimizer.values() self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) + self.assertEqual(self.capturedOutput.getvalue(), + "0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n") @unittest.skip("Not a test we want run every time, as needs comet.ml account") def test_comet(self): diff --git a/python/gtsam/utils/logging_optimizer.py b/python/gtsam/utils/logging_optimizer.py index 3d9175951..1e55ce990 100644 --- a/python/gtsam/utils/logging_optimizer.py +++ b/python/gtsam/utils/logging_optimizer.py @@ -21,7 +21,8 @@ def optimize(optimizer, check_convergence, hook): current_error = optimizer.error() hook(optimizer, current_error) - # Iterative loop + # Iterative loop. Cannot use `params.iterationHook` because we don't have access to params + # (backwards compatibility issue). while True: # Do next iteration optimizer.iterate() @@ -35,7 +36,7 @@ def optimize(optimizer, check_convergence, hook): def gtsam_optimize(optimizer, params, hook): - """ Given an optimizer and params, iterate until convergence. + """ Given an optimizer and its params, iterate until convergence. After each iteration, hook(optimizer) is called. After the function, use values and errors to get the result. Arguments: @@ -43,10 +44,6 @@ def gtsam_optimize(optimizer, params {NonlinearOptimizarParams} -- Nonlinear optimizer parameters hook -- hook function to record the error """ - def check_convergence(optimizer, current_error, new_error): - return (optimizer.iterations() >= params.getMaxIterations()) or ( - gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(), - current_error, new_error)) or ( - isinstance(optimizer, gtsam.LevenbergMarquardtOptimizer) and optimizer.lambda_() > params.getlambdaUpperBound()) - optimize(optimizer, check_convergence, hook) - return optimizer.values() + hook(optimizer, optimizer.error()) # call once at start (backwards compatibility) + params.iterationHook = lambda iteration, error_before, error_after: hook(optimizer, error_after) + return optimizer.optimize()