Merge pull request #345 from borglab/feature/logging_optimizer
Add logging (hooked) optimizerrelease/4.3a0
						commit
						18e80b83aa
					
				| 
						 | 
				
			
			@ -0,0 +1,79 @@
 | 
			
		|||
"""
 | 
			
		||||
Unit tests for optimization that logs to comet.ml.
 | 
			
		||||
Author: Jing Wu and Frank Dellaert
 | 
			
		||||
"""
 | 
			
		||||
# pylint: disable=invalid-name
 | 
			
		||||
 | 
			
		||||
import unittest
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
 | 
			
		||||
import gtsam
 | 
			
		||||
import numpy as np
 | 
			
		||||
from gtsam import Rot3
 | 
			
		||||
from gtsam.utils.test_case import GtsamTestCase
 | 
			
		||||
 | 
			
		||||
from gtsam.utils.logging_optimizer import gtsam_optimize
 | 
			
		||||
 | 
			
		||||
KEY = 0
 | 
			
		||||
MODEL = gtsam.noiseModel_Unit.Create(3)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestOptimizeComet(GtsamTestCase):
 | 
			
		||||
    """Check correct logging to comet.ml."""
 | 
			
		||||
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        """Set up a small Karcher mean optimization example."""
 | 
			
		||||
        # Grabbed from KarcherMeanFactor unit tests.
 | 
			
		||||
        R = Rot3.Expmap(np.array([0.1, 0, 0]))
 | 
			
		||||
        rotations = {R, R.inverse()}  # mean is the identity
 | 
			
		||||
        self.expected = Rot3()
 | 
			
		||||
 | 
			
		||||
        graph = gtsam.NonlinearFactorGraph()
 | 
			
		||||
        for R in rotations:
 | 
			
		||||
            graph.add(gtsam.PriorFactorRot3(KEY, R, MODEL))
 | 
			
		||||
        initial = gtsam.Values()
 | 
			
		||||
        initial.insert(KEY, R)
 | 
			
		||||
        self.params = gtsam.GaussNewtonParams()
 | 
			
		||||
        self.optimizer = gtsam.GaussNewtonOptimizer(
 | 
			
		||||
            graph, initial, self.params)
 | 
			
		||||
 | 
			
		||||
    def test_simple_printing(self):
 | 
			
		||||
        """Test with a simple hook."""
 | 
			
		||||
 | 
			
		||||
        # Provide a hook that just prints
 | 
			
		||||
        def hook(_, error: float):
 | 
			
		||||
            print(error)
 | 
			
		||||
 | 
			
		||||
        # Only thing we require from optimizer is an iterate method
 | 
			
		||||
        gtsam_optimize(self.optimizer, self.params, hook)
 | 
			
		||||
 | 
			
		||||
        # Check that optimizing yields the identity.
 | 
			
		||||
        actual = self.optimizer.values()
 | 
			
		||||
        self.gtsamAssertEquals(actual.atRot3(KEY), self.expected, tol=1e-6)
 | 
			
		||||
 | 
			
		||||
    @unittest.skip("Not a test we want run every time, as needs comet.ml account")
 | 
			
		||||
    def test_comet(self):
 | 
			
		||||
        """Test with a comet hook."""
 | 
			
		||||
        from comet_ml import Experiment
 | 
			
		||||
        comet = Experiment(project_name="Testing",
 | 
			
		||||
                           auto_output_logging="native")
 | 
			
		||||
        comet.log_dataset_info(name="Karcher", path="shonan")
 | 
			
		||||
        comet.add_tag("GaussNewton")
 | 
			
		||||
        comet.log_parameter("method", "GaussNewton")
 | 
			
		||||
        time = datetime.now()
 | 
			
		||||
        comet.set_name("GaussNewton-" + str(time.month) + "/" + str(time.day) + " "
 | 
			
		||||
                       + str(time.hour)+":"+str(time.minute)+":"+str(time.second))
 | 
			
		||||
 | 
			
		||||
        # I want to do some comet thing here
 | 
			
		||||
        def hook(optimizer, error: float):
 | 
			
		||||
            comet.log_metric("Karcher error",
 | 
			
		||||
                             error, optimizer.iterations())
 | 
			
		||||
 | 
			
		||||
        gtsam_optimize(self.optimizer, self.params, hook)
 | 
			
		||||
        comet.end()
 | 
			
		||||
 | 
			
		||||
        actual = self.optimizer.values()
 | 
			
		||||
        self.gtsamAssertEquals(actual.atRot3(KEY), self.expected)
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    unittest.main()
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,54 @@
 | 
			
		|||
"""
 | 
			
		||||
Optimization with logging via a hook.
 | 
			
		||||
Author: Jing Wu and Frank Dellaert
 | 
			
		||||
"""
 | 
			
		||||
# pylint: disable=invalid-name
 | 
			
		||||
 | 
			
		||||
from typing import TypeVar
 | 
			
		||||
 | 
			
		||||
from gtsam import NonlinearOptimizer, NonlinearOptimizerParams
 | 
			
		||||
import gtsam
 | 
			
		||||
 | 
			
		||||
T = TypeVar('T')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def optimize(optimizer: T, check_convergence, hook):
 | 
			
		||||
    """ Given an optimizer and a convergence check, iterate until convergence.
 | 
			
		||||
        After each iteration, hook(optimizer, error) is called.
 | 
			
		||||
        After the function, use values and errors to get the result.
 | 
			
		||||
        Arguments:
 | 
			
		||||
            optimizer (T): needs an iterate and an error function.
 | 
			
		||||
            check_convergence: T * float * float -> bool
 | 
			
		||||
            hook -- hook function to record the error
 | 
			
		||||
    """
 | 
			
		||||
    # the optimizer is created with default values which incur the error below
 | 
			
		||||
    current_error = optimizer.error()
 | 
			
		||||
    hook(optimizer, current_error)
 | 
			
		||||
 | 
			
		||||
    # Iterative loop
 | 
			
		||||
    while True:
 | 
			
		||||
        # Do next iteration
 | 
			
		||||
        optimizer.iterate()
 | 
			
		||||
        new_error = optimizer.error()
 | 
			
		||||
        hook(optimizer, new_error)
 | 
			
		||||
        if check_convergence(optimizer, current_error, new_error):
 | 
			
		||||
            return
 | 
			
		||||
        current_error = new_error
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def gtsam_optimize(optimizer: NonlinearOptimizer,
 | 
			
		||||
                   params: NonlinearOptimizerParams,
 | 
			
		||||
                   hook):
 | 
			
		||||
    """ Given an optimizer and params, iterate until convergence.
 | 
			
		||||
        After each iteration, hook(optimizer) is called.
 | 
			
		||||
        After the function, use values and errors to get the result.
 | 
			
		||||
        Arguments:
 | 
			
		||||
                optimizer {NonlinearOptimizer} -- Nonlinear optimizer
 | 
			
		||||
                params {NonlinearOptimizarParams} -- Nonlinear optimizer parameters
 | 
			
		||||
                hook -- hook function to record the error
 | 
			
		||||
    """
 | 
			
		||||
    def check_convergence(optimizer, current_error, new_error):
 | 
			
		||||
        return (optimizer.iterations() >= params.getMaxIterations()) or (
 | 
			
		||||
            gtsam.checkConvergence(params.getRelativeErrorTol(), params.getAbsoluteErrorTol(), params.getErrorTol(),
 | 
			
		||||
                                   current_error, new_error))
 | 
			
		||||
    optimize(optimizer, check_convergence, hook)
 | 
			
		||||
		Loading…
	
		Reference in New Issue