Skip to content

Commit

Permalink
Use correct matplotlib public interfaces for type hints. (#595)
Browse files Browse the repository at this point in the history
matplotlib/matplotlib#26812

Also fixes access to colors. Closes: #602
  • Loading branch information
MichaelGrupp authored Nov 14, 2023
1 parent 13d5588 commit aa115ab
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 36 deletions.
5 changes: 4 additions & 1 deletion evo/main_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import argparse
import datetime
import itertools
import logging
import os

Expand Down Expand Up @@ -274,6 +275,7 @@ def run(args):
from evo.tools import plot
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns

plot_collection = plot.PlotCollection("evo_traj - trajectory plot")
fig_xyz, axarr_xyz = plt.subplots(3, sharex="col",
Expand Down Expand Up @@ -331,10 +333,11 @@ def run(args):
if SETTINGS.plot_multi_cmap.lower() != "none":
cmap = getattr(cm, SETTINGS.plot_multi_cmap)
cmap_colors = iter(cmap(np.linspace(0, 1, len(trajectories))))
color_palette = itertools.cycle(sns.color_palette())

for name, traj in trajectories.items():
if cmap_colors is None:
color = next(ax_traj._get_lines.prop_cycler)['color']
color = next(color_palette)
else:
color = next(cmap_colors)

Expand Down
77 changes: 42 additions & 35 deletions evo/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import collections
import collections.abc
import itertools
import logging
import pickle
import typing
Expand All @@ -35,6 +36,8 @@
import mpl_toolkits.mplot3d.art3d as art3d
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.backend_bases import FigureCanvasBase
from matplotlib.collections import LineCollection
from matplotlib.transforms import Affine2D, Bbox
Expand Down Expand Up @@ -111,7 +114,7 @@ def __init__(self, title: str = "",
def __str__(self) -> str:
return self.title + " (" + str(len(self.figures)) + " figure(s))"

def add_figure(self, name: str, fig: plt.Figure) -> None:
def add_figure(self, name: str, fig: Figure) -> None:
fig.tight_layout()
self.figures[name] = fig

Expand Down Expand Up @@ -235,7 +238,7 @@ def export(self, file_path: str, confirm_overwrite: bool = True) -> None:
logger.info("Plot saved to " + dest)


def set_aspect_equal(ax: plt.Axes) -> None:
def set_aspect_equal(ax: Axes) -> None:
"""
kudos to https://stackoverflow.com/a/35126679
:param ax: matplotlib 3D axes object
Expand Down Expand Up @@ -271,9 +274,9 @@ def formatter(x, _):
return formatter


def prepare_axis(fig: plt.Figure, plot_mode: PlotMode = PlotMode.xy,
def prepare_axis(fig: Figure, plot_mode: PlotMode = PlotMode.xy,
subplot_arg: int = 111,
length_unit: Unit = Unit.meters) -> plt.Axes:
length_unit: Unit = Unit.meters) -> Axes:
"""
prepares an axis according to the plot mode (for trajectory plotting)
:param fig: matplotlib figure object
Expand Down Expand Up @@ -304,7 +307,7 @@ def prepare_axis(fig: plt.Figure, plot_mode: PlotMode = PlotMode.xy,
ylabel = f"$z$ ({length_unit.value})"
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
if plot_mode == PlotMode.xyz:
if plot_mode == PlotMode.xyz and isinstance(ax, Axes3D):
ax.set_zlabel(f'$z$ ({length_unit.value})')
if SETTINGS.plot_invert_xaxis:
plt.gca().invert_xaxis()
Expand All @@ -317,7 +320,7 @@ def prepare_axis(fig: plt.Figure, plot_mode: PlotMode = PlotMode.xy,
formatter = _get_length_formatter(length_unit)
ax.xaxis.set_major_formatter(formatter)
ax.yaxis.set_major_formatter(formatter)
if plot_mode == PlotMode.xyz:
if plot_mode == PlotMode.xyz and isinstance(ax, Axes3D):
ax.zaxis.set_major_formatter(formatter)

return ax
Expand Down Expand Up @@ -347,10 +350,10 @@ def plot_mode_to_idx(
return x_idx, y_idx, z_idx


def add_start_end_markers(ax: plt.Axes, plot_mode: PlotMode,
def add_start_end_markers(ax: Axes, plot_mode: PlotMode,
traj: trajectory.PosePath3D, start_symbol: str = "o",
start_color: str = "black", end_symbol: str = "x",
end_color: str = "black", alpha: float = 1.0,
start_color="black", end_symbol: str = "x",
end_color="black", alpha: float = 1.0,
traj_name: typing.Optional[str] = None):
if traj.num_poses == 0:
return
Expand All @@ -364,15 +367,16 @@ def add_start_end_markers(ax: plt.Axes, plot_mode: PlotMode,
end_coords.append(end[z_idx])
start_label = f"Start of {traj_name}" if traj_name else None
end_label = f"End of {traj_name}" if traj_name else None
# TODO: mypy doesn't deal well with * unpack here for some reason.
ax.scatter(*start_coords, marker=start_symbol, color=start_color,
alpha=alpha, label=start_label)
alpha=alpha, label=start_label) # type: ignore[misc]
ax.scatter(*end_coords, marker=end_symbol, color=end_color, alpha=alpha,
label=end_label)
label=end_label) # type: ignore[misc]


def traj(ax: plt.Axes, plot_mode: PlotMode, traj: trajectory.PosePath3D,
style: str = '-', color: str = 'black', label: str = "",
alpha: float = 1.0, plot_start_end_markers: bool = False) -> None:
def traj(ax: Axes, plot_mode: PlotMode, traj: trajectory.PosePath3D,
style: str = '-', color='black', label: str = "", alpha: float = 1.0,
plot_start_end_markers: bool = False) -> None:
"""
plot a path/trajectory based on xyz coordinates into an axis
:param ax: the matplotlib axis
Expand Down Expand Up @@ -403,7 +407,7 @@ def traj(ax: plt.Axes, plot_mode: PlotMode, traj: trajectory.PosePath3D,


def colored_line_collection(
xyz: np.ndarray, colors: ListOrArray, plot_mode: PlotMode = PlotMode.xy,
xyz: np.ndarray, colors, plot_mode: PlotMode = PlotMode.xy,
linestyles: str = "solid", step: int = 1, alpha: float = 1.
) -> typing.Union[LineCollection, art3d.LineCollection]:
if step > 1 and len(xyz) / step != len(colors):
Expand All @@ -429,9 +433,9 @@ def colored_line_collection(
return line_collection


def traj_colormap(ax: plt.Axes, traj: trajectory.PosePath3D,
array: ListOrArray, plot_mode: PlotMode, min_map: float,
max_map: float, title: str = "",
def traj_colormap(ax: Axes, traj: trajectory.PosePath3D, array: ListOrArray,
plot_mode: PlotMode, min_map: float, max_map: float,
title: str = "",
fig: typing.Optional[mpl.figure.Figure] = None,
plot_start_end_markers: bool = False) -> None:
"""
Expand All @@ -454,11 +458,12 @@ def traj_colormap(ax: plt.Axes, traj: trajectory.PosePath3D,
norm=norm,
cmap=SETTINGS.plot_trajectory_cmap) # cm.*_r is reversed cmap
mapper.set_array(array)
colors = [mapper.to_rgba(a) for a in array]
# TODO: why does mypy complain about 'a' here, float is fine?
colors = [mapper.to_rgba(a) for a in array] # type: ignore[arg-type]
line_collection = colored_line_collection(pos, colors, plot_mode)
ax.add_collection(line_collection)
ax.autoscale_view(True, True, True)
if plot_mode == PlotMode.xyz:
if plot_mode == PlotMode.xyz and isinstance(ax, Axes3D):
ax.set_zlim(np.amin(traj.positions_xyz[:, 2]),
np.amax(traj.positions_xyz[:, 2]))
if SETTINGS.plot_xyz_realistic:
Expand All @@ -481,10 +486,9 @@ def traj_colormap(ax: plt.Axes, traj: trajectory.PosePath3D,
end_color=colors[-1])


def draw_coordinate_axes(ax: plt.Figure, traj: trajectory.PosePath3D,
def draw_coordinate_axes(ax: Axes, traj: trajectory.PosePath3D,
plot_mode: PlotMode, marker_scale: float = 0.1,
x_color: str = "r", y_color: str = "g",
z_color: str = "b") -> None:
x_color="r", y_color="g", z_color="b") -> None:
"""
Draws a coordinate frame axis for each pose of a trajectory.
:param ax: plot axis
Expand Down Expand Up @@ -521,10 +525,10 @@ def draw_coordinate_axes(ax: plt.Figure, traj: trajectory.PosePath3D,
ax.add_collection(markers)


def draw_correspondence_edges(ax: plt.Axes, traj_1: trajectory.PosePath3D,
def draw_correspondence_edges(ax: Axes, traj_1: trajectory.PosePath3D,
traj_2: trajectory.PosePath3D,
plot_mode: PlotMode, style: str = '-',
color: str = "black", alpha: float = 1.) -> None:
color="black", alpha: float = 1.) -> None:
"""
Draw edges between corresponding poses of two trajectories.
Trajectories must be synced, i.e. having the same number of poses.
Expand All @@ -550,7 +554,7 @@ def draw_correspondence_edges(ax: plt.Axes, traj_1: trajectory.PosePath3D,


def traj_xyz(axarr: np.ndarray, traj: trajectory.PosePath3D, style: str = '-',
color: str = 'black', label: str = "", alpha: float = 1.0,
color='black', label: str = "", alpha: float = 1.0,
start_timestamp: typing.Optional[float] = None,
length_unit: Unit = Unit.meters) -> None:
"""
Expand Down Expand Up @@ -599,7 +603,7 @@ def traj_xyz(axarr: np.ndarray, traj: trajectory.PosePath3D, style: str = '-',


def traj_rpy(axarr: np.ndarray, traj: trajectory.PosePath3D, style: str = '-',
color: str = 'black', label: str = "", alpha: float = 1.0,
color='black', label: str = "", alpha: float = 1.0,
start_timestamp: typing.Optional[float] = None) -> None:
"""
plot a path/trajectory's Euler RPY angles into an axis
Expand Down Expand Up @@ -636,7 +640,7 @@ def traj_rpy(axarr: np.ndarray, traj: trajectory.PosePath3D, style: str = '-',
axarr[0].legend(frameon=True)


def trajectories(fig: plt.Figure, trajectories: typing.Union[
def trajectories(fig: Figure, trajectories: typing.Union[
trajectory.PosePath3D, typing.Sequence[trajectory.PosePath3D],
typing.Dict[str, trajectory.PosePath3D]], plot_mode=PlotMode.xy,
title: str = "", subplot_arg: int = 111,
Expand Down Expand Up @@ -665,10 +669,12 @@ def trajectories(fig: plt.Figure, trajectories: typing.Union[
cmap = getattr(cm, SETTINGS.plot_multi_cmap)
cmap_colors = iter(cmap(np.linspace(0, 1, len(trajectories))))

color_palette = itertools.cycle(sns.color_palette())

# helper function
def draw(t, name=""):
if cmap_colors is None:
color = next(ax._get_lines.prop_cycler)['color']
color = next(color_palette)
else:
color = next(cmap_colors)
if SETTINGS.plot_usetex:
Expand All @@ -686,12 +692,12 @@ def draw(t, name=""):
draw(t)


def error_array(ax: plt.Axes, err_array: ListOrArray,
def error_array(ax: Axes, err_array: ListOrArray,
x_array: typing.Optional[ListOrArray] = None,
statistics: typing.Optional[typing.Dict[str, float]] = None,
threshold: typing.Optional[float] = None,
cumulative: bool = False, color: str = 'grey',
name: str = "error", title: str = "", xlabel: str = "index",
cumulative: bool = False, color='grey', name: str = "error",
title: str = "", xlabel: str = "index",
ylabel: typing.Optional[str] = None, subplot_arg: int = 111,
linestyle: str = "-", marker: typing.Optional[str] = None):
"""
Expand Down Expand Up @@ -724,9 +730,10 @@ def error_array(ax: plt.Axes, err_array: ListOrArray,
else:
ax.plot(err_array, linestyle=linestyle, marker=marker, color=color,
label=name)
color_pallete = itertools.cycle(sns.color_palette())
if statistics is not None:
for stat_name, value in statistics.items():
color = next(ax._get_lines.prop_cycler)['color']
color = next(color_pallete)
if stat_name == "std" and "mean" in statistics:
mean, std = statistics["mean"], statistics["std"]
ax.axhspan(mean - std / 2, mean + std / 2, color=color,
Expand All @@ -744,7 +751,7 @@ def error_array(ax: plt.Axes, err_array: ListOrArray,


def ros_map(
ax: plt.Axes, yaml_path: str, plot_mode: PlotMode,
ax: Axes, yaml_path: str, plot_mode: PlotMode,
cmap: str = SETTINGS.ros_map_cmap,
mask_unknown_value: typing.Optional[int] = (
SETTINGS.ros_map_unknown_cell_value if SETTINGS.ros_map_enable_masking
Expand Down Expand Up @@ -815,7 +822,7 @@ def ros_map(
n_rows, n_cols = image.shape[x_idx], image.shape[y_idx]
metric_width = n_cols * resolution
metric_height = n_rows * resolution
extent = [0, metric_width, 0, metric_height]
extent = (0, metric_width, 0, metric_height)
if plot_mode == PlotMode.yx:
image = np.rot90(image)
image = np.fliplr(image)
Expand Down

0 comments on commit aa115ab

Please sign in to comment.