From 2e143f2e778230a8c4ed514967ec63e6232d2092 Mon Sep 17 00:00:00 2001 From: Naoto Mizuno Date: Mon, 16 Dec 2024 19:36:12 +0900 Subject: [PATCH] Simplify return values of _calculate_griddata --- optuna/visualization/matplotlib/_contour.py | 54 +++------------------ 1 file changed, 6 insertions(+), 48 deletions(-) diff --git a/optuna/visualization/matplotlib/_contour.py b/optuna/visualization/matplotlib/_contour.py index 359898b891e..6a4737d9db6 100644 --- a/optuna/visualization/matplotlib/_contour.py +++ b/optuna/visualization/matplotlib/_contour.py @@ -190,19 +190,7 @@ def _calculate_axis_data( return ci, cat_param_labels, cat_param_pos, list(returned_values) -def _calculate_griddata( - info: _SubContourInfo, -) -> tuple[ - np.ndarray, - np.ndarray, - np.ndarray, - list[int], - list[str], - list[int], - list[str], - _PlotValues, - _PlotValues, -]: +def _calculate_griddata(info: _SubContourInfo) -> tuple[np.ndarray, _PlotValues, _PlotValues]: xaxis = info.xaxis yaxis = info.yaxis z_values_dict = info.z_values @@ -220,17 +208,7 @@ def _calculate_griddata( # Return empty values when x or y has no value. if len(x_values) == 0 or len(y_values) == 0: - return ( - np.array([]), - np.array([]), - np.array([]), - [], - [], - [], - [], - _PlotValues([], []), - _PlotValues([], []), - ) + return np.array([]), _PlotValues([], []), _PlotValues([], []) xi, cat_param_labels_x, cat_param_pos_x, transformed_x_values = _calculate_axis_data( xaxis, @@ -261,17 +239,7 @@ def _calculate_griddata( infeasible.x.append(x_value) infeasible.y.append(y_value) - return ( - xi, - yi, - zi, - cat_param_pos_x, - cat_param_labels_x, - cat_param_pos_y, - cat_param_labels_y, - feasible, - infeasible, - ) + return zi, feasible, infeasible def _generate_contour_subplot( @@ -286,14 +254,14 @@ def _generate_contour_subplot( ax.set_xlim(info.xaxis.range[0], info.xaxis.range[1]) ax.set_ylim(info.yaxis.range[0], info.yaxis.range[1]) x_values, y_values = _filter_missing_values(info.xaxis, info.yaxis) + xi, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values) + yi, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values) if info.xaxis.is_cat: - _, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values) ax.set_xticks(x_cat_param_pos) ax.set_xticklabels(x_cat_param_label) else: ax.set_xscale("log" if info.xaxis.is_log else "linear") if info.yaxis.is_cat: - _, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values) ax.set_yticks(y_cat_param_pos) ax.set_yticklabels(y_cat_param_label) else: @@ -302,17 +270,7 @@ def _generate_contour_subplot( if info.xaxis.name == info.yaxis.name: return None - ( - xi, - yi, - zi, - x_cat_param_pos, - x_cat_param_label, - y_cat_param_pos, - y_cat_param_label, - feasible_plot_values, - infeasible_plot_values, - ) = _calculate_griddata(info) + zi, feasible_plot_values, infeasible_plot_values = _calculate_griddata(info) cs = None if len(zi) > 0: # Contour the gridded data.