Skip to content

Commit

Permalink
Add weights to relplot and catplot
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Dec 6, 2023
1 parent c4854df commit 595dba7
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 18 deletions.
33 changes: 20 additions & 13 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2739,12 +2739,12 @@ def countplot(

def catplot(
data=None, *, x=None, y=None, hue=None, row=None, col=None, kind="strip",
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
order=None, hue_order=None, row_order=None, col_order=None, col_wrap=None,
height=5, aspect=1, log_scale=None, native_scale=False, formatter=None,
orient=None, color=None, palette=None, hue_norm=None, legend="auto",
legend_out=True, sharex=True, sharey=True, margin_titles=False, facet_kws=None,
ci=deprecated, **kwargs
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
weights=None, order=None, hue_order=None, row_order=None, col_order=None,
col_wrap=None, height=5, aspect=1, log_scale=None, native_scale=False,
formatter=None, orient=None, color=None, palette=None, hue_norm=None,
legend="auto", legend_out=True, sharex=True, sharey=True,
margin_titles=False, facet_kws=None, ci=deprecated, **kwargs
):

# Check for attempt to plot onto specific axes and warn
Expand Down Expand Up @@ -2774,7 +2774,9 @@ def catplot(

p = Plotter(
data=data,
variables=dict(x=x, y=y, hue=hue, row=row, col=col, units=units),
variables=dict(
x=x, y=y, hue=hue, row=row, col=col, units=units, weight=weights
),
order=order,
orient=orient,
# Handle special backwards compatibility where pointplot originally
Expand Down Expand Up @@ -2850,6 +2852,14 @@ def catplot(
if dodge == "auto":
dodge = p._dodge_needed()

if "weight" in p.plot_data:
if kind not in ["bar", "point"]:
msg = f"The `weights` parameter has no effect with kind={kind!r}."
warnings.warn(msg, stacklevel=2)
agg_cls = WeightedAggregator
else:
agg_cls = EstimateAggregator

if kind == "strip":

jitter = kwargs.pop("jitter", True)
Expand Down Expand Up @@ -2999,9 +3009,7 @@ def catplot(

elif kind == "point":

aggregator = EstimateAggregator(
estimator, errorbar, n_boot=n_boot, seed=seed
)
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)

markers = kwargs.pop("markers", default)
linestyles = kwargs.pop("linestyles", default)
Expand Down Expand Up @@ -3035,9 +3043,8 @@ def catplot(

elif kind == "bar":

aggregator = EstimateAggregator(
estimator, errorbar, n_boot=n_boot, seed=seed
)
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)

err_kws, capsize = p._err_kws_backcompat(
_normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D),
errcolor=kwargs.pop("errcolor", deprecated),
Expand Down
13 changes: 9 additions & 4 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def scatterplot(

def relplot(
data=None, *,
x=None, y=None, hue=None, size=None, style=None, units=None,
x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
row=None, col=None, col_wrap=None, row_order=None, col_order=None,
palette=None, hue_order=None, hue_norm=None,
sizes=None, size_order=None, size_norm=None,
Expand Down Expand Up @@ -732,9 +732,14 @@ def relplot(
variables = dict(x=x, y=y, hue=hue, size=size, style=style)
if kind == "line":
variables["units"] = units
elif units is not None:
msg = "The `units` parameter of `relplot` has no effect with kind='scatter'"
warnings.warn(msg, stacklevel=2)
variables["weight"] = weights
else:
if units is not None:
msg = "The `units` parameter has no effect with kind='scatter'."
warnings.warn(msg, stacklevel=2)
if weights is not None:
msg = "The `weights` parameter has no effect with kind='scatter'."
warnings.warn(msg, stacklevel=2)
p = Plotter(
data=data,
variables=variables,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -3147,6 +3147,12 @@ def test_legend_with_auto(self):
g2 = catplot(self.df, x="g", y="y", hue="g", legend=True)
assert g2._legend is not None

def test_weights_warning(self, long_df):

with pytest.warns(UserWarning, match="The `weights` parameter"):
g = catplot(long_df, x="a", y="y", weights="z")
assert g.ax is not None


class TestBeeswarm:

Expand Down
6 changes: 5 additions & 1 deletion tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,12 +668,16 @@ def test_facet_variable_collision(self, long_df):
)
assert g.axes.shape == (1, len(col_data.unique()))

def test_relplot_scatter_units(self, long_df):
def test_relplot_scatter_unused_variables(self, long_df):

with pytest.warns(UserWarning, match="The `units` parameter"):
g = relplot(long_df, x="x", y="y", units="a")
assert g.ax is not None

with pytest.warns(UserWarning, match="The `weights` parameter"):
g = relplot(long_df, x="x", y="y", weights="x")
assert g.ax is not None

def test_ax_kwarg_removal(self, long_df):

f, ax = plt.subplots()
Expand Down

0 comments on commit 595dba7

Please sign in to comment.