diff --git a/seaborn/relational.py b/seaborn/relational.py index bb872887a1..8d0d856021 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -792,17 +792,18 @@ def relplot( # Add the grid semantics onto the plotter grid_variables = dict( - x=x, y=y, row=row, col=col, - hue=hue, size=size, style=style, + x=x, y=y, row=row, col=col, hue=hue, size=size, style=style, ) if kind == "line": - grid_variables["units"] = units + grid_variables.update(units=units, weights=weights) p.assign_variables(data, grid_variables) # Define the named variables for plotting on each facet # Rename the variables with a leading underscore to avoid # collisions with faceting variable names plot_variables = {v: f"_{v}" for v in variables} + if "weight" in plot_variables: + plot_variables["weights"] = plot_variables.pop("weight") plot_kws.update(plot_variables) # Pass the row/col variables to FacetGrid with their original @@ -930,6 +931,10 @@ def relplot( Grouping variable that will produce elements with different styles. Can have a numeric dtype but will always be treated as categorical. {params.rel.units} +weights : vector or key in `data` + Data values or column used to compute weighted estimation. + Note that use of weights currently limits the choice of statistics + to a 'mean' estimator and 'ci' errorbar. {params.facets.rowcol} {params.facets.col_wrap} row_order, col_order : lists of strings diff --git a/tests/test_relational.py b/tests/test_relational.py index 2f5eda1dc6..06d0860a38 100644 --- a/tests/test_relational.py +++ b/tests/test_relational.py @@ -578,6 +578,15 @@ def test_relplot_styles(self, long_df): expected_paths = [paths[val] for val in grp_df["a"]] assert self.paths_equal(points.get_paths(), expected_paths) + def test_relplot_weighted_estimator(self, long_df): + + g = relplot(data=long_df, x="a", y="y", weights="x", kind="line") + ydata = g.ax.lines[0].get_ydata() + for i, label in enumerate(g.ax.get_xticklabels()): + pos_df = long_df[long_df["a"] == label.get_text()] + expected = np.average(pos_df["y"], weights=pos_df["x"]) + assert ydata[i] == pytest.approx(expected) + def test_relplot_stringy_numerics(self, long_df): long_df["x_str"] = long_df["x"].astype(str) @@ -1063,8 +1072,8 @@ def test_weights(self, long_df): ax = lineplot(long_df, x="a", y="y", weights="x") vals = ax.lines[0].get_ydata() - for i, a in enumerate(ax.get_xticklabels()): - pos_df = long_df.loc[long_df["a"] == a.get_text()] + for i, label in enumerate(ax.get_xticklabels()): + pos_df = long_df.loc[long_df["a"] == label.get_text()] expected = np.average(pos_df["y"], weights=pos_df["x"]) assert vals[i] == pytest.approx(expected)