Skip to content

Commit

Permalink
Fix relplot weights
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Dec 6, 2023
1 parent 595dba7 commit 785242b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
11 changes: 8 additions & 3 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,17 +792,18 @@ def relplot(

# Add the grid semantics onto the plotter
grid_variables = dict(
x=x, y=y, row=row, col=col,
hue=hue, size=size, style=style,
x=x, y=y, row=row, col=col, hue=hue, size=size, style=style,
)
if kind == "line":
grid_variables["units"] = units
grid_variables.update(units=units, weights=weights)
p.assign_variables(data, grid_variables)

# Define the named variables for plotting on each facet
# Rename the variables with a leading underscore to avoid
# collisions with faceting variable names
plot_variables = {v: f"_{v}" for v in variables}
if "weight" in plot_variables:
plot_variables["weights"] = plot_variables.pop("weight")
plot_kws.update(plot_variables)

# Pass the row/col variables to FacetGrid with their original
Expand Down Expand Up @@ -930,6 +931,10 @@ def relplot(
Grouping variable that will produce elements with different styles.
Can have a numeric dtype but will always be treated as categorical.
{params.rel.units}
weights : vector or key in `data`
Data values or column used to compute weighted estimation.
Note that use of weights currently limits the choice of statistics
to a 'mean' estimator and 'ci' errorbar.
{params.facets.rowcol}
{params.facets.col_wrap}
row_order, col_order : lists of strings
Expand Down
13 changes: 11 additions & 2 deletions tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,15 @@ def test_relplot_styles(self, long_df):
expected_paths = [paths[val] for val in grp_df["a"]]
assert self.paths_equal(points.get_paths(), expected_paths)

def test_relplot_weighted_estimator(self, long_df):
    """Check a weighted line `relplot` against per-level weighted means of y."""

    grid = relplot(data=long_df, x="a", y="y", weights="x", kind="line")
    axes = grid.ax
    line_heights = axes.lines[0].get_ydata()

    # Each x tick label names one level of "a"; the line's y value at that
    # position should equal the mean of "y" weighted by "x" for that level.
    for idx, tick in enumerate(axes.get_xticklabels()):
        level_rows = long_df[long_df["a"] == tick.get_text()]
        weighted_mean = np.average(level_rows["y"], weights=level_rows["x"])
        assert line_heights[idx] == pytest.approx(weighted_mean)

def test_relplot_stringy_numerics(self, long_df):

long_df["x_str"] = long_df["x"].astype(str)
Expand Down Expand Up @@ -1063,8 +1072,8 @@ def test_weights(self, long_df):

ax = lineplot(long_df, x="a", y="y", weights="x")
vals = ax.lines[0].get_ydata()
for i, a in enumerate(ax.get_xticklabels()):
pos_df = long_df.loc[long_df["a"] == a.get_text()]
for i, label in enumerate(ax.get_xticklabels()):
pos_df = long_df.loc[long_df["a"] == label.get_text()]
expected = np.average(pos_df["y"], weights=pos_df["x"])
assert vals[i] == pytest.approx(expected)

Expand Down

0 comments on commit 785242b

Please sign in to comment.