Update `logging_optimizer.gtsam_optimize` to use NonlinearOptimizerParams::iterationHook

release/4.3a0
Gerry Chen 2022-04-19 16:03:38 -04:00
parent ddca736c7b
commit 71aa20ff33
GPG Key ID: E9845092D3A57286
2 changed files with 11 additions and 10 deletions

@@ -63,12 +63,14 @@ class TestOptimizeComet(GtsamTestCase):
         def hook(_, error):
             print(error)
 
-        # Only thing we require from optimizer is an iterate method
+        # Wrapper function sets the hook and calls optimizer.optimize() for us.
         gtsam_optimize(self.optimizer, self.params, hook)
         # Check that optimizing yields the identity.
         actual = self.optimizer.values()
         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+        self.assertEqual(self.capturedOutput.getvalue(),
+                         "0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n")
 
     def test_lm_simple_printing(self):
         """Make sure we are properly terminating LM"""
@@ -79,6 +81,8 @@ class TestOptimizeComet(GtsamTestCase):
         actual = self.lmoptimizer.values()
         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+        self.assertEqual(self.capturedOutput.getvalue(),
+                         "0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n")
 
     @unittest.skip("Not a test we want run every time, as needs comet.ml account")
     def test_comet(self):


@@ -21,7 +21,8 @@ def optimize(optimizer, check_convergence, hook):
     current_error = optimizer.error()
     hook(optimizer, current_error)
 
-    # Iterative loop
+    # Iterative loop.  Cannot use `params.iterationHook` because we don't have access to params
+    # (backwards compatibility issue).
     while True:
         # Do next iteration
         optimizer.iterate()
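The hunk above shows only the top of that loop. For context, the remainder of the legacy helper presumably continues roughly as sketched below, with convergence left entirely to the caller-supplied `check_convergence(optimizer, current_error, new_error)` callable (the full body is an assumption; only the lines in the hunk are from the diff):

```python
def optimize(optimizer, check_convergence, hook):
    """Legacy helper (sketch): iterate manually, calling hook(optimizer, error) each step."""
    current_error = optimizer.error()
    hook(optimizer, current_error)

    # Iterative loop.  Cannot use `params.iterationHook` because we don't have access to params
    # (backwards compatibility issue).
    while True:
        # Do next iteration
        optimizer.iterate()
        # Report the new error and stop once the caller says we have converged.
        new_error = optimizer.error()
        hook(optimizer, new_error)
        if check_convergence(optimizer, current_error, new_error):
            return
        current_error = new_error
```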
@@ -35,7 +36,7 @@ def optimize(optimizer, check_convergence, hook):
 def gtsam_optimize(optimizer,
                    params,
                    hook):
-    """ Given an optimizer and params, iterate until convergence.
+    """ Given an optimizer and its params, iterate until convergence.
     After each iteration, hook(optimizer) is called.
     After the function, use values and errors to get the result.
     Arguments:
@@ -43,10 +44,6 @@ def gtsam_optimize(optimizer,
         params {NonlinearOptimizarParams} -- Nonlinear optimizer parameters
         hook -- hook function to record the error
     """
-    def check_convergence(optimizer, current_error, new_error):
-        return (optimizer.iterations() >= params.getMaxIterations()) or (
-            gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(),
-                                   current_error, new_error)) or (
-            isinstance(optimizer, gtsam.LevenbergMarquardtOptimizer) and optimizer.lambda_() > params.getlambdaUpperBound())
-    optimize(optimizer, check_convergence, hook)
-    return optimizer.values()
+    hook(optimizer, optimizer.error())  # call once at start (backwards compatibility)
+    params.iterationHook = lambda iteration, error_before, error_after: hook(optimizer, error_after)
+    return optimizer.optimize()
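The net effect: `gtsam_optimize` now calls the hook once up front, wires it into `params.iterationHook` (which receives the iteration number plus the error before and after each step), and defers the whole loop to `optimizer.optimize()`. A usage sketch, assuming a GTSAM Python build that wraps `NonlinearOptimizerParams::iterationHook` and ships the `gtsam.utils.logging_optimizer` module; the toy `Rot3` prior problem is illustrative, not the test's exact setup:

```python
import gtsam
from gtsam.utils.logging_optimizer import gtsam_optimize

KEY = 0

# Tiny problem: pull a rotation toward the identity.
graph = gtsam.NonlinearFactorGraph()
graph.add(gtsam.PriorFactorRot3(KEY, gtsam.Rot3(), gtsam.noiseModel.Unit.Create(3)))

initial = gtsam.Values()
initial.insert(KEY, gtsam.Rot3.Rz(0.1))  # start slightly away from the optimum

params = gtsam.LevenbergMarquardtParams()
optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)


def hook(opt, error):
    # Called once before optimization and then after every iteration via iterationHook.
    print(error)


result = gtsam_optimize(optimizer, params, hook)
print(result.atRot3(KEY))
```

Because the lambda forwards `error_after` and the extra call before `optimizer.optimize()` reports the starting error, the printed sequence begins at the initial error and then tracks each iteration, which is exactly what the updated tests assert on.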