Merge pull request #345 from borglab/feature/logging_optimizer
Add logging (hooked) optimizerrelease/4.3a0
						commit
						18e80b83aa
					
				|  | @ -0,0 +1,79 @@ | |||
| """ | ||||
| Unit tests for optimization that logs to comet.ml. | ||||
| Author: Jing Wu and Frank Dellaert | ||||
| """ | ||||
| # pylint: disable=invalid-name | ||||
| 
 | ||||
| import unittest | ||||
| from datetime import datetime | ||||
| 
 | ||||
| import gtsam | ||||
| import numpy as np | ||||
| from gtsam import Rot3 | ||||
| from gtsam.utils.test_case import GtsamTestCase | ||||
| 
 | ||||
| from gtsam.utils.logging_optimizer import gtsam_optimize | ||||
| 
 | ||||
| KEY = 0 | ||||
| MODEL = gtsam.noiseModel_Unit.Create(3) | ||||
| 
 | ||||
| 
 | ||||
| class TestOptimizeComet(GtsamTestCase): | ||||
|     """Check correct logging to comet.ml.""" | ||||
| 
 | ||||
|     def setUp(self): | ||||
|         """Set up a small Karcher mean optimization example.""" | ||||
|         # Grabbed from KarcherMeanFactor unit tests. | ||||
|         R = Rot3.Expmap(np.array([0.1, 0, 0])) | ||||
|         rotations = {R, R.inverse()}  # mean is the identity | ||||
|         self.expected = Rot3() | ||||
| 
 | ||||
|         graph = gtsam.NonlinearFactorGraph() | ||||
|         for R in rotations: | ||||
|             graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL)) | ||||
|         initial = gtsam.Values() | ||||
|         initial.insert(KEY, R) | ||||
|         self.params = gtsam.GaussNewtonParams() | ||||
|         self.optimizer = gtsam.GaussNewtonOptimizer( | ||||
|             graph, initial, self.params) | ||||
| 
 | ||||
|     def test_simple_printing(self): | ||||
|         """Test with a simple hook.""" | ||||
| 
 | ||||
|         # Provide a hook that just prints | ||||
|         def hook(_, error: float): | ||||
|             print(error) | ||||
| 
 | ||||
|         # Only thing we require from optimizer is an iterate method | ||||
|         gtsam_optimize(self.optimizer, self.params, hook) | ||||
| 
 | ||||
|         # Check that optimizing yields the identity. | ||||
|         actual = self.optimizer.values() | ||||
|         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6) | ||||
| 
 | ||||
|     @unittest.skip("Not a test we want run every time, as needs comet.ml account") | ||||
|     def test_comet(self): | ||||
|         """Test with a comet hook.""" | ||||
|         from comet_ml import Experiment | ||||
|         comet = Experiment(project_name="Testing", | ||||
|                            auto_output_logging="native") | ||||
|         comet.log_dataset_info(name="Karcher", path="shonan") | ||||
|         comet.add_tag("GaussNewton") | ||||
|         comet.log_parameter("method", "GaussNewton") | ||||
|         time = datetime.now() | ||||
|         comet.set_name("GaussNewton-" + str(time.month) + "/" + str(time.day) + " " | ||||
|                        + str(time.hour)+":"+str(time.minute)+":"+str(time.second)) | ||||
| 
 | ||||
|         # I want to do some comet thing here | ||||
|         def hook(optimizer, error: float): | ||||
|             comet.log_metric("Karcher error", | ||||
|                              error, optimizer.iterations()) | ||||
| 
 | ||||
|         gtsam_optimize(self.optimizer, self.params, hook) | ||||
|         comet.end() | ||||
| 
 | ||||
|         actual = self.optimizer.values() | ||||
|         self.gtsamAssertEquals(actual.atRot3(KEY), self.expected) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
|  | @ -0,0 +1,54 @@ | |||
| """ | ||||
| Optimization with logging via a hook. | ||||
| Author: Jing Wu and Frank Dellaert | ||||
| """ | ||||
| # pylint: disable=invalid-name | ||||
| 
 | ||||
| from typing import TypeVar | ||||
| 
 | ||||
| from gtsam import NonlinearOptimizer, NonlinearOptimizerParams | ||||
| import gtsam | ||||
| 
 | ||||
| T = TypeVar('T') | ||||
| 
 | ||||
| 
 | ||||
| def optimize(optimizer: T, check_convergence, hook): | ||||
|     """ Given an optimizer and a convergence check, iterate until convergence. | ||||
|         After each iteration, hook(optimizer, error) is called. | ||||
|         After the function, use values and errors to get the result. | ||||
|         Arguments: | ||||
|             optimizer (T): needs an iterate and an error function. | ||||
|             check_convergence: T * float * float -> bool | ||||
|             hook -- hook function to record the error | ||||
|     """ | ||||
|     # the optimizer is created with default values which incur the error below | ||||
|     current_error = optimizer.error() | ||||
|     hook(optimizer, current_error) | ||||
| 
 | ||||
|     # Iterative loop | ||||
|     while True: | ||||
|         # Do next iteration | ||||
|         optimizer.iterate() | ||||
|         new_error = optimizer.error() | ||||
|         hook(optimizer, new_error) | ||||
|         if check_convergence(optimizer, current_error, new_error): | ||||
|             return | ||||
|         current_error = new_error | ||||
| 
 | ||||
| 
 | ||||
| def gtsam_optimize(optimizer: NonlinearOptimizer, | ||||
|                    params: NonlinearOptimizerParams, | ||||
|                    hook): | ||||
|     """ Given an optimizer and params, iterate until convergence. | ||||
|         After each iteration, hook(optimizer) is called. | ||||
|         After the function, use values and errors to get the result. | ||||
|         Arguments: | ||||
|                 optimizer {NonlinearOptimizer} -- Nonlinear optimizer | ||||
|                 params {NonlinearOptimizarParams} -- Nonlinear optimizer parameters | ||||
|                 hook -- hook function to record the error | ||||
|     """ | ||||
|     def check_convergence(optimizer, current_error, new_error): | ||||
|         return (optimizer.iterations() >= params.getMaxIterations()) or ( | ||||
|             gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(), | ||||
|                                    current_error, new_error)) | ||||
|     optimize(optimizer, check_convergence, hook) | ||||
		Loading…
	
		Reference in New Issue