diff --git a/google/colab/_quickchart.py b/google/colab/_quickchart.py
index 9a1cbce2..5be09d05 100644
--- a/google/colab/_quickchart.py
+++ b/google/colab/_quickchart.py
@@ -79,7 +79,7 @@ def _ensure_dataframe_registry():
 
   if len(numeric_cols) >= 2:
     chart_sections += [
-        _quickchart_helpers.linked_scatter_section(
+        _quickchart_helpers.scatter_section(
             df,
             _select_first_k_pairs(numeric_cols, k=max_chart_instances),
             _DATAFRAME_REGISTRY,
diff --git a/google/colab/_quickchart_helpers.py b/google/colab/_quickchart_helpers.py
index a54060a9..1d9289af 100644
--- a/google/colab/_quickchart_helpers.py
+++ b/google/colab/_quickchart_helpers.py
@@ -75,7 +75,7 @@ class ChartSectionType:
   FACETED_DISTRIBUTION = 'faceted_distribution'
   HEATMAP = 'heatmap'
   HISTOGRAM = 'histogram'
-  LINKED_SCATTER = 'linked_scatter'
+  SCATTER = 'scatter'
   TIME_SERIES_LINE_PLOT = 'time_series_line_plot'
   VALUE_PLOT = 'value_plot'
 
@@ -330,8 +330,8 @@ def heatmaps_section(df, colname_pairs, df_registry):
   )
 
 
-def linked_scatter_section(df, colname_pairs, df_registry):
-  """Generates a section of linked scatter plots.
+def scatter_section(df, colname_pairs, df_registry):
+  """Generates a section of scatter plots.
 
   Args:
     df: (pd.DataFrame) A dataframe.
@@ -340,13 +340,13 @@ def linked_scatter_section(df, colname_pairs, df_registry):
     df_registry: (DataframeRegistry) Registry to use for dataframe lookups.
 
   Returns:
-    (ChartSection) A chart section containing linked scatter plots.
+    (ChartSection) A chart section containing scatter plots.
   """
   return _chart_section(
-      ChartSectionType.LINKED_SCATTER,
+      ChartSectionType.SCATTER,
       df,
-      _quickchart_lib.scatter_plots,
-      [[list(colname_pairs)]],
+      _quickchart_lib.scatter_plot,
+      colname_pairs,
       {},
       df_registry,
       '2-d distributions',
diff --git a/google/colab/_quickchart_lib.py b/google/colab/_quickchart_lib.py
index 90b2649f..037bc24b 100644
--- a/google/colab/_quickchart_lib.py
+++ b/google/colab/_quickchart_lib.py
@@ -138,13 +138,26 @@ def value_plot(df, y, figscale=1):
   return autoviz.MplChart.from_current_mpl_state()
 
 
-def scatter_plots(df, colname_pairs, figscale=1, alpha=.8):
+def scatter_plot(df, x_colname, y_colname, figscale=1, alpha=.8):
+  """Generates a scatter plot of one dataframe column against another.
+
+  Args:
+    df: (pd.DataFrame) A dataframe.
+    x_colname: Column of df to plot on the x axis.
+    y_colname: Column of df to plot on the y axis.
+    figscale: Multiplier applied to the 6x6 figure size and the marker size.
+    alpha: Marker opacity forwarded to the plot call.
+
+  Returns:
+    The chart produced by autoviz.MplChart.from_current_mpl_state().
+  """
   from matplotlib import pyplot as plt
-  plt.figure(figsize=(len(colname_pairs) * 6 * figscale, 6 * figscale))
-  for plot_i, (x_colname, y_colname) in enumerate(colname_pairs, start=1):
-    ax = plt.subplot(1, len(colname_pairs), plot_i)
-    df.plot(kind='scatter', x=x_colname, y=y_colname, s=(32 * figscale), alpha=alpha, ax=ax)
-    ax.spines[['top', 'right',]].set_visible(False)
+  # DataFrame.plot opens a new default-sized figure unless it is handed an
+  # axes, so draw on the axes of the figure sized here (as the old code did).
+  plt.figure(figsize=(6 * figscale, 6 * figscale))
+  ax = plt.gca()
+  df.plot(kind='scatter', x=x_colname, y=y_colname, s=(32 * figscale), alpha=alpha, ax=ax)
+  ax.spines[['top', 'right',]].set_visible(False)
   plt.tight_layout()
   return autoviz.MplChart.from_current_mpl_state()
 