Skip to content

Commit

Permalink
Simplify return values of _calculate_griddata
Browse files Browse the repository at this point in the history
  • Loading branch information
not522 committed Dec 16, 2024
1 parent 3a2298b commit 2e143f2
Showing 1 changed file with 6 additions and 48 deletions.
54 changes: 6 additions & 48 deletions optuna/visualization/matplotlib/_contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,19 +190,7 @@ def _calculate_axis_data(
return ci, cat_param_labels, cat_param_pos, list(returned_values)


def _calculate_griddata(
info: _SubContourInfo,
) -> tuple[
np.ndarray,
np.ndarray,
np.ndarray,
list[int],
list[str],
list[int],
list[str],
_PlotValues,
_PlotValues,
]:
def _calculate_griddata(info: _SubContourInfo) -> tuple[np.ndarray, _PlotValues, _PlotValues]:
xaxis = info.xaxis
yaxis = info.yaxis
z_values_dict = info.z_values
Expand All @@ -220,17 +208,7 @@ def _calculate_griddata(

# Return empty values when x or y has no value.
if len(x_values) == 0 or len(y_values) == 0:
return (
np.array([]),
np.array([]),
np.array([]),
[],
[],
[],
[],
_PlotValues([], []),
_PlotValues([], []),
)
return np.array([]), _PlotValues([], []), _PlotValues([], [])

xi, cat_param_labels_x, cat_param_pos_x, transformed_x_values = _calculate_axis_data(
xaxis,
Expand Down Expand Up @@ -261,17 +239,7 @@ def _calculate_griddata(
infeasible.x.append(x_value)
infeasible.y.append(y_value)

return (
xi,
yi,
zi,
cat_param_pos_x,
cat_param_labels_x,
cat_param_pos_y,
cat_param_labels_y,
feasible,
infeasible,
)
return zi, feasible, infeasible


def _generate_contour_subplot(
Expand All @@ -286,14 +254,14 @@ def _generate_contour_subplot(
ax.set_xlim(info.xaxis.range[0], info.xaxis.range[1])
ax.set_ylim(info.yaxis.range[0], info.yaxis.range[1])
x_values, y_values = _filter_missing_values(info.xaxis, info.yaxis)
xi, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values)
yi, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values)
if info.xaxis.is_cat:
_, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values)
ax.set_xticks(x_cat_param_pos)
ax.set_xticklabels(x_cat_param_label)
else:
ax.set_xscale("log" if info.xaxis.is_log else "linear")
if info.yaxis.is_cat:
_, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values)
ax.set_yticks(y_cat_param_pos)
ax.set_yticklabels(y_cat_param_label)
else:
Expand All @@ -302,17 +270,7 @@ def _generate_contour_subplot(
if info.xaxis.name == info.yaxis.name:
return None

(
xi,
yi,
zi,
x_cat_param_pos,
x_cat_param_label,
y_cat_param_pos,
y_cat_param_label,
feasible_plot_values,
infeasible_plot_values,
) = _calculate_griddata(info)
zi, feasible_plot_values, infeasible_plot_values = _calculate_griddata(info)
cs = None
if len(zi) > 0:
# Contour the gridded data.
Expand Down

0 comments on commit 2e143f2

Please sign in to comment.