Update `logging_optimizer.gtsam_optimize` to use NonlinearOptimizerParams::iterationHook

release/4.3a0
Gerry Chen 2022-04-19 16:03:38 -04:00
parent ddca736c7b
commit 71aa20ff33
2 changed files with 11 additions and 10 deletions
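
For context: `NonlinearOptimizerParams::iterationHook` is a per-iteration callback that GTSAM invokes after every optimizer iteration with the iteration number and the error before and after that iteration; the lambda added at the bottom of this diff shows the Python-side signature. A minimal sketch of the mechanism on a hypothetical one-factor problem, assuming a GTSAM build whose Python wrapper exposes `iterationHook` as this commit does (the graph, key, and initial values below are illustrative, not part of this commit):

```python
import gtsam

# Toy problem: a single prior factor pulling a Rot3 toward the identity.
KEY = 0
graph = gtsam.NonlinearFactorGraph()
graph.add(gtsam.PriorFactorRot3(KEY, gtsam.Rot3(), gtsam.noiseModel.Unit.Create(3)))
initial = gtsam.Values()
initial.insert(KEY, gtsam.Rot3.Rz(0.1))  # start slightly away from the target

params = gtsam.LevenbergMarquardtParams()
# iterationHook receives (iteration, error_before, error_after) after each iteration.
params.iterationHook = lambda i, before, after: print(i, before, after)

optimizer = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)
result = optimizer.optimize()  # the hook fires once per iteration
```

Each hook invocation corresponds to one internal iteration, which is exactly what the hand-rolled `iterate()` loop below emulated before this change.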

python/gtsam/tests/test_logging_optimizer.py

@@ -63,12 +63,14 @@ class TestOptimizeComet(GtsamTestCase):
         def hook(_, error):
             print(error)

-        # Only thing we require from optimizer is an iterate method
+        # Wrapper function sets the hook and calls optimizer.optimize() for us.
         gtsam_optimize(self.optimizer, self.params, hook)
         # Check that optimizing yields the identity.
         actual = self.optimizer.values()
         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+        self.assertEqual(self.capturedOutput.getvalue(),
+                         "0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n")

     def test_lm_simple_printing(self):
         """Make sure we are properly terminating LM"""
@@ -79,6 +81,8 @@ class TestOptimizeComet(GtsamTestCase):
         actual = self.lmoptimizer.values()
         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
+        self.assertEqual(self.capturedOutput.getvalue(),
+                         "0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n")

     @unittest.skip("Not a test we want run every time, as needs comet.ml account")
     def test_comet(self):
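
The two `assertEqual` checks above compare against stdout captured during optimization; the `setUp` that does the capturing is outside this hunk, but the usual pattern is redirecting `sys.stdout` into an `io.StringIO`. A hedged sketch of that pattern (the names here are assumptions, not taken from this diff):

```python
import io
import sys

capturedOutput = io.StringIO()
sys.stdout = capturedOutput      # everything print()ed now lands in the buffer
print(0.02)                      # stand-in for the hook's print(error)
sys.stdout = sys.__stdout__      # restore the real stdout
assert capturedOutput.getvalue() == "0.02\n"
```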

python/gtsam/utils/logging_optimizer.py

@@ -21,7 +21,8 @@ def optimize(optimizer, check_convergence, hook):
     current_error = optimizer.error()
     hook(optimizer, current_error)

-    # Iterative loop
+    # Iterative loop. Cannot use `params.iterationHook` because we don't have
+    # access to params (backwards compatibility issue).
     while True:
         # Do next iteration
         optimizer.iterate()
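
The free function `optimize` keeps its hand-rolled `iterate()` loop because its callers hand it only an optimizer and a convergence predicate; there is no params object to attach a hook to. A sketch of that legacy calling convention, assuming an `optimizer` built elsewhere (the tolerance below is an illustrative assumption, not GTSAM's default):

```python
# Caller-supplied convergence test: no params object required.
def check_convergence(optimizer, current_error, new_error):
    return abs(current_error - new_error) < 1e-9  # illustrative tolerance

def hook(optimizer, error):
    print(error)

optimize(optimizer, check_convergence, hook)  # loops until the predicate returns True
```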
@@ -35,7 +36,7 @@ def optimize(optimizer, check_convergence, hook):
 def gtsam_optimize(optimizer,
                    params,
                    hook):
-    """ Given an optimizer and params, iterate until convergence.
+    """ Given an optimizer and its params, iterate until convergence.
     After each iteration, hook(optimizer, error) is called.
     After the function returns, use values and errors to get the result.
     Arguments:
@@ -43,10 +44,6 @@ def gtsam_optimize(optimizer,
         params {NonlinearOptimizerParams} -- Nonlinear optimizer parameters
         hook -- hook function to record the error
     """
-    def check_convergence(optimizer, current_error, new_error):
-        return (optimizer.iterations() >= params.getMaxIterations()) or (
-            gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(),
-                                   current_error, new_error)) or (
-            isinstance(optimizer, gtsam.LevenbergMarquardtOptimizer) and optimizer.lambda_() > params.getlambdaUpperBound())
-    optimize(optimizer, check_convergence, hook)
-    return optimizer.values()
+    hook(optimizer, optimizer.error())  # call once at start (backwards compatibility)
+    params.iterationHook = lambda iteration, error_before, error_after: hook(optimizer, error_after)
+    return optimizer.optimize()
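
After this change, `gtsam_optimize` is a thin wrapper: it calls the hook once with the initial error (matching the old loop's behavior), forwards each per-iteration error to the hook via `iterationHook`, and delegates termination entirely to `optimizer.optimize()`. A usage sketch mirroring the tests above, assuming `optimizer` and `params` are set up elsewhere (e.g. as in the first sketch):

```python
errors = []

def hook(optimizer, error):
    errors.append(error)  # record the error trace instead of printing it

result = gtsam_optimize(optimizer, params, hook)
# errors[0] is the initial error; errors[1:] are the post-iteration errors,
# the same sequence the tests above assert against as printed output.
```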