clean up plot.py with modern type hints

release/4.3a0
John Lambert 2021-08-12 08:06:12 -04:00 committed by GitHub
parent 678d1c7270
commit 68794468f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 70 additions and 45 deletions

View File

@ -2,22 +2,25 @@
# pylint: disable=no-member, invalid-name # pylint: disable=no-member, invalid-name
from typing import Iterable, Tuple
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
from matplotlib import patches from matplotlib import patches
from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-import from mpl_toolkits.mplot3d import Axes3D # pylint: disable=unused-import
import gtsam import gtsam
from gtsam import Marginals, Point3, Pose2, Values
def set_axes_equal(fignum): def set_axes_equal(fignum: int) -> None:
""" """
Make axes of 3D plot have equal scale so that spheres appear as spheres, Make axes of 3D plot have equal scale so that spheres appear as spheres,
cubes as cubes, etc.. This is one possible solution to Matplotlib's cubes as cubes, etc.. This is one possible solution to Matplotlib's
ax.set_aspect('equal') and ax.axis('equal') not working for 3D. ax.set_aspect('equal') and ax.axis('equal') not working for 3D.
Args: Args:
fignum (int): An integer representing the figure number for Matplotlib. fignum: An integer representing the figure number for Matplotlib.
""" """
fig = plt.figure(fignum) fig = plt.figure(fignum)
ax = fig.gca(projection='3d') ax = fig.gca(projection='3d')
@ -36,18 +39,20 @@ def set_axes_equal(fignum):
ax.set_zlim3d([origin[2] - radius, origin[2] + radius]) ax.set_zlim3d([origin[2] - radius, origin[2] + radius])
def ellipsoid(rx, ry, rz, n): def ellipsoid(
rx: float, ry: float, rz: float, n: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
""" """
Numpy equivalent of Matlab's ellipsoid function. Numpy equivalent of Matlab's ellipsoid function.
Args: Args:
rx (double): Radius of ellipsoid in X-axis. rx: Radius of ellipsoid in X-axis.
ry (double): Radius of ellipsoid in Y-axis. ry: Radius of ellipsoid in Y-axis.
rz (double): Radius of ellipsoid in Z-axis. rz: Radius of ellipsoid in Z-axis.
n (int): The granularity of the ellipsoid plotted. n: The granularity of the ellipsoid plotted.
Returns: Returns:
tuple[numpy.ndarray]: The points in the x, y and z axes to use for the surface plot. The points in the x, y and z axes to use for the surface plot.
""" """
u = np.linspace(0, 2*np.pi, n+1) u = np.linspace(0, 2*np.pi, n+1)
v = np.linspace(0, np.pi, n+1) v = np.linspace(0, np.pi, n+1)
@ -58,7 +63,9 @@ def ellipsoid(rx, ry, rz, n):
return x, y, z return x, y, z
def plot_covariance_ellipse_3d(axes, origin, P, scale=1, n=8, alpha=0.5): def plot_covariance_ellipse_3d(
axes, origin: Point3, P, scale: float = 1, n: int = 8, alpha: float = 0.5
) -> None:
""" """
Plots a Gaussian as an uncertainty ellipse Plots a Gaussian as an uncertainty ellipse
@ -68,12 +75,12 @@ def plot_covariance_ellipse_3d(axes, origin, P, scale=1, n=8, alpha=0.5):
Args: Args:
axes (matplotlib.axes.Axes): Matplotlib axes. axes (matplotlib.axes.Axes): Matplotlib axes.
origin (gtsam.Point3): The origin in the world frame. origin: The origin in the world frame.
P (numpy.ndarray): The marginal covariance matrix of the 3D point P (numpy.ndarray): The marginal covariance matrix of the 3D point
which will be represented as an ellipse. which will be represented as an ellipse.
scale (float): Scaling factor of the radii of the covariance ellipse. scale: Scaling factor of the radii of the covariance ellipse.
n (int): Defines the granularity of the ellipse. Higher values indicate finer ellipses. n: Defines the granularity of the ellipse. Higher values indicate finer ellipses.
alpha (float): Transparency value for the plotted surface in the range [0, 1]. alpha: Transparency value for the plotted surface in the range [0, 1].
""" """
k = 11.82 k = 11.82
U, S, _ = np.linalg.svd(P) U, S, _ = np.linalg.svd(P)
@ -96,14 +103,16 @@ def plot_covariance_ellipse_3d(axes, origin, P, scale=1, n=8, alpha=0.5):
axes.plot_surface(x, y, z, alpha=alpha, cmap='hot') axes.plot_surface(x, y, z, alpha=alpha, cmap='hot')
def plot_pose2_on_axes(axes, pose, axis_length=0.1, covariance=None): def plot_pose2_on_axes(
axes, pose: Pose2, axis_length: float = 0.1, covariance: np.ndarray = None
) -> None:
""" """
Plot a 2D pose on given axis `axes` with given `axis_length`. Plot a 2D pose on given axis `axes` with given `axis_length`.
Args: Args:
axes (matplotlib.axes.Axes): Matplotlib axes. axes (matplotlib.axes.Axes): Matplotlib axes.
pose (gtsam.Pose2): The pose to be plotted. pose: The pose to be plotted.
axis_length (float): The length of the camera axes. axis_length: The length of the camera axes.
covariance (numpy.ndarray): Marginal covariance matrix to plot covariance (numpy.ndarray): Marginal covariance matrix to plot
the uncertainty of the estimation. the uncertainty of the estimation.
""" """
@ -136,16 +145,21 @@ def plot_pose2_on_axes(axes, pose, axis_length=0.1, covariance=None):
axes.add_patch(e1) axes.add_patch(e1)
def plot_pose2(fignum, pose, axis_length=0.1, covariance=None, def plot_pose2(
axis_labels=('X axis', 'Y axis', 'Z axis')): fignum: int,
pose: Pose2,
axis_length: float = 0.1,
covariance: np.ndarray = None,
axis_labels=("X axis", "Y axis", "Z axis"),
) -> plt.Figure:
""" """
Plot a 2D pose on given figure with given `axis_length`. Plot a 2D pose on given figure with given `axis_length`.
Args: Args:
fignum (int): Integer representing the figure number to use for plotting. fignum: Integer representing the figure number to use for plotting.
pose (gtsam.Pose2): The pose to be plotted. pose: The pose to be plotted.
axis_length (float): The length of the camera axes. axis_length: The length of the camera axes.
covariance (numpy.ndarray): Marginal covariance matrix to plot covariance: Marginal covariance matrix to plot
the uncertainty of the estimation. the uncertainty of the estimation.
axis_labels (iterable[string]): List of axis labels to set. axis_labels (iterable[string]): List of axis labels to set.
""" """
@ -176,17 +190,17 @@ def plot_point3_on_axes(axes, point, linespec, P=None):
plot_covariance_ellipse_3d(axes, point, P) plot_covariance_ellipse_3d(axes, point, P)
def plot_point3(fignum, point, linespec, P=None, def plot_point3(fignum: int, point: Point3, linespec: str, P: np.ndarray = None,
axis_labels=('X axis', 'Y axis', 'Z axis')): axis_labels: Iterable[str] = ('X axis', 'Y axis', 'Z axis')) -> plt.Figure:
""" """
Plot a 3D point on given figure with given `linespec`. Plot a 3D point on given figure with given `linespec`.
Args: Args:
fignum (int): Integer representing the figure number to use for plotting. fignum: Integer representing the figure number to use for plotting.
point (gtsam.Point3): The point to be plotted. point: The point to be plotted.
linespec (string): String representing formatting options for Matplotlib. linespec: String representing formatting options for Matplotlib.
P (numpy.ndarray): Marginal covariance matrix to plot the uncertainty of the estimation. P: Marginal covariance matrix to plot the uncertainty of the estimation.
axis_labels (iterable[string]): List of axis labels to set. axis_labels: List of axis labels to set.
Returns: Returns:
fig: The matplotlib figure. fig: The matplotlib figure.
@ -308,18 +322,24 @@ def plot_pose3(fignum, pose, axis_length=0.1, P=None,
return fig return fig
def plot_trajectory(fignum, values, scale=1, marginals=None, def plot_trajectory(
title="Plot Trajectory", axis_labels=('X axis', 'Y axis', 'Z axis')): fignum: int,
values: Values,
scale: float = 1,
marginals: Marginals = None,
title: str = "Plot Trajectory",
axis_labels: Iterable[str] = ("X axis", "Y axis", "Z axis"),
) -> None:
""" """
Plot a complete 2D/3D trajectory using poses in `values`. Plot a complete 2D/3D trajectory using poses in `values`.
Args: Args:
fignum (int): Integer representing the figure number to use for plotting. fignum: Integer representing the figure number to use for plotting.
values (gtsam.Values): Values containing some Pose2 and/or Pose3 values. values: Values containing some Pose2 and/or Pose3 values.
scale (float): Value to scale the poses by. scale: Value to scale the poses by.
marginals (gtsam.Marginals): Marginalized probability values of the estimation. marginals: Marginalized probability values of the estimation.
Used to plot uncertainty bounds. Used to plot uncertainty bounds.
title (string): The title of the plot. title: The title of the plot.
axis_labels (iterable[string]): List of axis labels to set. axis_labels (iterable[string]): List of axis labels to set.
""" """
fig = plt.figure(fignum) fig = plt.figure(fignum)
@ -357,20 +377,25 @@ def plot_trajectory(fignum, values, scale=1, marginals=None,
fig.canvas.set_window_title(title.lower()) fig.canvas.set_window_title(title.lower())
def plot_incremental_trajectory(fignum, values, start=0, def plot_incremental_trajectory(
scale=1, marginals=None, fignum: int,
time_interval=0.0): values: Values,
start: int = 0,
scale: float = 1,
marginals: Marginals = None,
time_interval: float = 0.0
) -> None:
""" """
Incrementally plot a complete 3D trajectory using poses in `values`. Incrementally plot a complete 3D trajectory using poses in `values`.
Args: Args:
fignum (int): Integer representing the figure number to use for plotting. fignum: Integer representing the figure number to use for plotting.
values (gtsam.Values): Values dict containing the poses. values: Values dict containing the poses.
start (int): Starting index to start plotting from. start: Starting index to start plotting from.
scale (float): Value to scale the poses by. scale: Value to scale the poses by.
marginals (gtsam.Marginals): Marginalized probability values of the estimation. marginals: Marginalized probability values of the estimation.
Used to plot uncertainty bounds. Used to plot uncertainty bounds.
time_interval (float): Time in seconds to pause between each rendering. time_interval: Time in seconds to pause between each rendering.
Used to create animation effect. Used to create animation effect.
""" """
fig = plt.figure(fignum) fig = plt.figure(fignum)