From 68794468f22298c3e07058078feca0a194ecb25b Mon Sep 17 00:00:00 2001 From: John Lambert Date: Thu, 12 Aug 2021 08:06:12 -0400 Subject: [PATCH] clean up plot.py with modern type hints --- python/gtsam/utils/plot.py | 115 ++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 45 deletions(-) diff --git a/python/gtsam/utils/plot.py b/python/gtsam/utils/plot.py index 7f48d03a3..9e74cf38e 100644 --- a/python/gtsam/utils/plot.py +++ b/python/gtsam/utils/plot.py @@ -2,22 +2,25 @@ # pylint: disable=no-member, invalid-name +from typing import Iterable, Tuple + import matplotlib.pyplot as plt import numpy as np from matplotlib import patches from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-import import gtsam +from gtsam import Marginals, Point3, Pose2, Values -def set_axes_equal(fignum): +def set_axes_equal(fignum: int) -> None: """ Make axes of 3D plot have equal scale so that spheres appear as spheres, cubes as cubes, etc.. This is one possible solution to Matplotlib's ax.set_aspect('equal') and ax.axis('equal') not working for 3D. Args: - fignum (int): An integer representing the figure number for Matplotlib. + fignum: An integer representing the figure number for Matplotlib. """ fig = plt.figure(fignum) ax = fig.gca(projection='3d') @@ -36,18 +39,20 @@ def set_axes_equal(fignum): ax.set_zlim3d([origin[2] - radius, origin[2] + radius]) -def ellipsoid(rx, ry, rz, n): +def ellipsoid( + rx: float, ry: float, rz: float, n: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Numpy equivalent of Matlab's ellipsoid function. Args: - rx (double): Radius of ellipsoid in X-axis. - ry (double): Radius of ellipsoid in Y-axis. - rz (double): Radius of ellipsoid in Z-axis. - n (int): The granularity of the ellipsoid plotted. + rx: Radius of ellipsoid in X-axis. + ry: Radius of ellipsoid in Y-axis. + rz: Radius of ellipsoid in Z-axis. + n: The granularity of the ellipsoid plotted. Returns: - tuple[numpy.ndarray]: The points in the x, y and z axes to use for the surface plot. + The points in the x, y and z axes to use for the surface plot. """ u = np.linspace(0, 2*np.pi, n+1) v = np.linspace(0, np.pi, n+1) @@ -58,7 +63,9 @@ def ellipsoid(rx, ry, rz, n): return x, y, z -def plot_covariance_ellipse_3d(axes, origin, P, scale=1, n=8, alpha=0.5): +def plot_covariance_ellipse_3d( + axes, origin: Point3, P, scale: float = 1, n: int = 8, alpha: float = 0.5 +) -> None: """ Plots a Gaussian as an uncertainty ellipse @@ -68,12 +75,12 @@ def plot_covariance_ellipse_3d(axes, origin, P, scale=1, n=8, alpha=0.5): Args: axes (matplotlib.axes.Axes): Matplotlib axes. - origin (gtsam.Point3): The origin in the world frame. + origin: The origin in the world frame. P (numpy.ndarray): The marginal covariance matrix of the 3D point which will be represented as an ellipse. - scale (float): Scaling factor of the radii of the covariance ellipse. - n (int): Defines the granularity of the ellipse. Higher values indicate finer ellipses. - alpha (float): Transparency value for the plotted surface in the range [0, 1]. + scale: Scaling factor of the radii of the covariance ellipse. + n: Defines the granularity of the ellipse. Higher values indicate finer ellipses. + alpha: Transparency value for the plotted surface in the range [0, 1]. """ k = 11.82 U, S, _ = np.linalg.svd(P) @@ -96,14 +103,16 @@ def plot_covariance_ellipse_3d(axes, origin, P, scale=1, n=8, alpha=0.5): axes.plot_surface(x, y, z, alpha=alpha, cmap='hot') -def plot_pose2_on_axes(axes, pose, axis_length=0.1, covariance=None): +def plot_pose2_on_axes( + axes, pose: Pose2, axis_length: float = 0.1, covariance: np.ndarray = None +) -> None: """ Plot a 2D pose on given axis `axes` with given `axis_length`. Args: axes (matplotlib.axes.Axes): Matplotlib axes. - pose (gtsam.Pose2): The pose to be plotted. - axis_length (float): The length of the camera axes. + pose: The pose to be plotted. + axis_length: The length of the camera axes. covariance (numpy.ndarray): Marginal covariance matrix to plot the uncertainty of the estimation. """ @@ -136,16 +145,21 @@ def plot_pose2_on_axes(axes, pose, axis_length=0.1, covariance=None): axes.add_patch(e1) -def plot_pose2(fignum, pose, axis_length=0.1, covariance=None, - axis_labels=('X axis', 'Y axis', 'Z axis')): +def plot_pose2( + fignum: int, + pose: Pose2, + axis_length: float = 0.1, + covariance: np.ndarray = None, + axis_labels=("X axis", "Y axis", "Z axis"), +) -> plt.Figure: """ Plot a 2D pose on given figure with given `axis_length`. Args: - fignum (int): Integer representing the figure number to use for plotting. - pose (gtsam.Pose2): The pose to be plotted. - axis_length (float): The length of the camera axes. - covariance (numpy.ndarray): Marginal covariance matrix to plot + fignum: Integer representing the figure number to use for plotting. + pose: The pose to be plotted. + axis_length: The length of the camera axes. + covariance: Marginal covariance matrix to plot the uncertainty of the estimation. axis_labels (iterable[string]): List of axis labels to set. """ @@ -176,17 +190,17 @@ def plot_point3_on_axes(axes, point, linespec, P=None): plot_covariance_ellipse_3d(axes, point, P) -def plot_point3(fignum, point, linespec, P=None, - axis_labels=('X axis', 'Y axis', 'Z axis')): +def plot_point3(fignum: int, point: Point3, linespec: str, P: np.ndarray = None, + axis_labels: Iterable[str] = ('X axis', 'Y axis', 'Z axis')) -> plt.Figure: """ Plot a 3D point on given figure with given `linespec`. Args: - fignum (int): Integer representing the figure number to use for plotting. - point (gtsam.Point3): The point to be plotted. - linespec (string): String representing formatting options for Matplotlib. - P (numpy.ndarray): Marginal covariance matrix to plot the uncertainty of the estimation. - axis_labels (iterable[string]): List of axis labels to set. + fignum: Integer representing the figure number to use for plotting. + point: The point to be plotted. + linespec: String representing formatting options for Matplotlib. + P: Marginal covariance matrix to plot the uncertainty of the estimation. + axis_labels: List of axis labels to set. Returns: fig: The matplotlib figure. @@ -308,18 +322,24 @@ def plot_pose3(fignum, pose, axis_length=0.1, P=None, return fig -def plot_trajectory(fignum, values, scale=1, marginals=None, - title="Plot Trajectory", axis_labels=('X axis', 'Y axis', 'Z axis')): +def plot_trajectory( + fignum: int, + values: Values, + scale: float = 1, + marginals: Marginals = None, + title: str = "Plot Trajectory", + axis_labels: Iterable[str] = ("X axis", "Y axis", "Z axis"), +) -> None: """ Plot a complete 2D/3D trajectory using poses in `values`. Args: - fignum (int): Integer representing the figure number to use for plotting. - values (gtsam.Values): Values containing some Pose2 and/or Pose3 values. - scale (float): Value to scale the poses by. - marginals (gtsam.Marginals): Marginalized probability values of the estimation. + fignum: Integer representing the figure number to use for plotting. + values: Values containing some Pose2 and/or Pose3 values. + scale: Value to scale the poses by. + marginals: Marginalized probability values of the estimation. Used to plot uncertainty bounds. - title (string): The title of the plot. + title: The title of the plot. axis_labels (iterable[string]): List of axis labels to set. """ fig = plt.figure(fignum) @@ -357,20 +377,25 @@ def plot_trajectory(fignum, values, scale=1, marginals=None, fig.canvas.set_window_title(title.lower()) -def plot_incremental_trajectory(fignum, values, start=0, - scale=1, marginals=None, - time_interval=0.0): +def plot_incremental_trajectory( + fignum: int, + values: Values, + start: int = 0, + scale: float = 1, + marginals: Marginals = None, + time_interval: float = 0.0 +) -> None: """ Incrementally plot a complete 3D trajectory using poses in `values`. Args: - fignum (int): Integer representing the figure number to use for plotting. - values (gtsam.Values): Values dict containing the poses. - start (int): Starting index to start plotting from. - scale (float): Value to scale the poses by. - marginals (gtsam.Marginals): Marginalized probability values of the estimation. + fignum: Integer representing the figure number to use for plotting. + values: Values dict containing the poses. + start: Starting index to start plotting from. + scale: Value to scale the poses by. + marginals: Marginalized probability values of the estimation. Used to plot uncertainty bounds. - time_interval (float): Time in seconds to pause between each rendering. + time_interval: Time in seconds to pause between each rendering. Used to create animation effect. """ fig = plt.figure(fignum)