diff --git a/optuna/visualization/matplotlib/_contour.py b/optuna/visualization/matplotlib/_contour.py index 359c56b9cba..6a4737d9db6 100644 --- a/optuna/visualization/matplotlib/_contour.py +++ b/optuna/visualization/matplotlib/_contour.py @@ -190,19 +190,7 @@ def _calculate_axis_data( return ci, cat_param_labels, cat_param_pos, list(returned_values) -def _calculate_griddata( - info: _SubContourInfo, -) -> tuple[ - np.ndarray, - np.ndarray, - np.ndarray, - list[int], - list[str], - list[int], - list[str], - _PlotValues, - _PlotValues, -]: +def _calculate_griddata(info: _SubContourInfo) -> tuple[np.ndarray, _PlotValues, _PlotValues]: xaxis = info.xaxis yaxis = info.yaxis z_values_dict = info.z_values @@ -220,17 +208,7 @@ def _calculate_griddata( # Return empty values when x or y has no value. if len(x_values) == 0 or len(y_values) == 0: - return ( - np.array([]), - np.array([]), - np.array([]), - [], - [], - [], - [], - _PlotValues([], []), - _PlotValues([], []), - ) + return np.array([]), _PlotValues([], []), _PlotValues([], []) xi, cat_param_labels_x, cat_param_pos_x, transformed_x_values = _calculate_axis_data( xaxis, @@ -261,90 +239,64 @@ def _calculate_griddata( infeasible.x.append(x_value) infeasible.y.append(y_value) - return ( - xi, - yi, - zi, - cat_param_pos_x, - cat_param_labels_x, - cat_param_pos_y, - cat_param_labels_y, - feasible, - infeasible, - ) + return zi, feasible, infeasible def _generate_contour_subplot( info: _SubContourInfo, ax: "Axes", cmap: "Colormap" ) -> "ContourSet" | None: + ax.label_outer() + if len(info.xaxis.indices) < 2 or len(info.yaxis.indices) < 2: - ax.label_outer() return None ax.set(xlabel=info.xaxis.name, ylabel=info.yaxis.name) ax.set_xlim(info.xaxis.range[0], info.xaxis.range[1]) ax.set_ylim(info.yaxis.range[0], info.yaxis.range[1]) x_values, y_values = _filter_missing_values(info.xaxis, info.yaxis) + xi, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values) + yi, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values) if info.xaxis.is_cat: - _, x_cat_param_label, x_cat_param_pos, _ = _calculate_axis_data(info.xaxis, x_values) ax.set_xticks(x_cat_param_pos) ax.set_xticklabels(x_cat_param_label) else: ax.set_xscale("log" if info.xaxis.is_log else "linear") if info.yaxis.is_cat: - _, y_cat_param_label, y_cat_param_pos, _ = _calculate_axis_data(info.yaxis, y_values) ax.set_yticks(y_cat_param_pos) ax.set_yticklabels(y_cat_param_label) else: ax.set_yscale("log" if info.yaxis.is_log else "linear") if info.xaxis.name == info.yaxis.name: - ax.label_outer() return None - ( - xi, - yi, - zi, - x_cat_param_pos, - x_cat_param_label, - y_cat_param_pos, - y_cat_param_label, - feasible_plot_values, - infeasible_plot_values, - ) = _calculate_griddata(info) + zi, feasible_plot_values, infeasible_plot_values = _calculate_griddata(info) cs = None if len(zi) > 0: - if info.xaxis.is_log: - ax.set_xscale("log") - if info.yaxis.is_log: - ax.set_yscale("log") - if info.xaxis.name != info.yaxis.name: - # Contour the gridded data. - ax.contour(xi, yi, zi, 15, linewidths=0.5, colors="k") - cs = ax.contourf(xi, yi, zi, 15, cmap=cmap.reversed()) - assert isinstance(cs, ContourSet) - # Plot data points. - ax.scatter( - feasible_plot_values.x, - feasible_plot_values.y, - marker="o", - c="black", - s=20, - edgecolors="grey", - linewidth=2.0, - ) - ax.scatter( - infeasible_plot_values.x, - infeasible_plot_values.y, - marker="o", - c="#cccccc", - s=20, - edgecolors="grey", - linewidth=2.0, - ) + # Contour the gridded data. + ax.contour(xi, yi, zi, 15, linewidths=0.5, colors="k") + cs = ax.contourf(xi, yi, zi, 15, cmap=cmap.reversed()) + assert isinstance(cs, ContourSet) + # Plot data points. + ax.scatter( + feasible_plot_values.x, + feasible_plot_values.y, + marker="o", + c="black", + s=20, + edgecolors="grey", + linewidth=2.0, + ) + ax.scatter( + infeasible_plot_values.x, + infeasible_plot_values.y, + marker="o", + c="#cccccc", + s=20, + edgecolors="grey", + linewidth=2.0, + ) - ax.label_outer() return cs