Update `logging_optimizer.gtsam_optimize` to use NonlinearOptimizerParams::iterationHook
parent
ddca736c7b
commit
71aa20ff33
|
@ -63,12 +63,14 @@ class TestOptimizeComet(GtsamTestCase):
|
|||
def hook(_, error):
|
||||
print(error)
|
||||
|
||||
# Only thing we require from optimizer is an iterate method
|
||||
# Wrapper function sets the hook and calls optimizer.optimize() for us.
|
||||
gtsam_optimize(self.optimizer, self.params, hook)
|
||||
|
||||
# Check that optimizing yields the identity.
|
||||
actual = self.optimizer.values()
|
||||
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
|
||||
self.assertEqual(self.capturedOutput.getvalue(),
|
||||
"0.020000000000000004\n0.010000000000000005\n0.010000000000000004\n")
|
||||
|
||||
def test_lm_simple_printing(self):
|
||||
"""Make sure we are properly terminating LM"""
|
||||
|
@ -79,6 +81,8 @@ class TestOptimizeComet(GtsamTestCase):
|
|||
|
||||
actual = self.lmoptimizer.values()
|
||||
self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
|
||||
self.assertEqual(self.capturedOutput.getvalue(),
|
||||
"0.020000000000000004\n0.010000000000249996\n0.009999999999999998\n")
|
||||
|
||||
@unittest.skip("Not a test we want run every time, as needs comet.ml account")
|
||||
def test_comet(self):
|
||||
|
|
|
@ -21,7 +21,8 @@ def optimize(optimizer, check_convergence, hook):
|
|||
current_error = optimizer.error()
|
||||
hook(optimizer, current_error)
|
||||
|
||||
# Iterative loop
|
||||
# Iterative loop. Cannot use `params.iterationHook` because we don't have access to params
|
||||
# (backwards compatibility issue).
|
||||
while True:
|
||||
# Do next iteration
|
||||
optimizer.iterate()
|
||||
|
@ -35,7 +36,7 @@ def optimize(optimizer, check_convergence, hook):
|
|||
def gtsam_optimize(optimizer,
|
||||
params,
|
||||
hook):
|
||||
""" Given an optimizer and params, iterate until convergence.
|
||||
""" Given an optimizer and its params, iterate until convergence.
|
||||
After each iteration, hook(optimizer) is called.
|
||||
After the function, use values and errors to get the result.
|
||||
Arguments:
|
||||
|
@ -43,10 +44,6 @@ def gtsam_optimize(optimizer,
|
|||
params {NonlinearOptimizarParams} -- Nonlinear optimizer parameters
|
||||
hook -- hook function to record the error
|
||||
"""
|
||||
def check_convergence(optimizer, current_error, new_error):
|
||||
return (optimizer.iterations() >= params.getMaxIterations()) or (
|
||||
gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(),
|
||||
current_error, new_error)) or (
|
||||
isinstance(optimizer, gtsam.LevenbergMarquardtOptimizer) and optimizer.lambda_() > params.getlambdaUpperBound())
|
||||
optimize(optimizer, check_convergence, hook)
|
||||
return optimizer.values()
|
||||
hook(optimizer, optimizer.error()) # call once at start (backwards compatibility)
|
||||
params.iterationHook = lambda iteration, error_before, error_after: hook(optimizer, error_after)
|
||||
return optimizer.optimize()
|
||||
|
|
Loading…
Reference in New Issue