Skip to content

Commit

Permalink
Add weighted estimation in function interface
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Dec 6, 2023
1 parent 2bb945c commit c4854df
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 30 deletions.
2 changes: 1 addition & 1 deletion seaborn/_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def __call__(self, data, var):
return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max})


class WeightedEstimateAggregator:
class WeightedAggregator:

def __init__(self, estimator, errorbar=None, **boot_kws):
"""
Expand Down
4 changes: 2 additions & 2 deletions seaborn/_stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from seaborn._stats.base import Stat
from seaborn._statistics import (
EstimateAggregator,
WeightedEstimateAggregator,
WeightedAggregator,
)
from seaborn._core.typing import Vector

Expand Down Expand Up @@ -105,7 +105,7 @@ def __call__(

boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
if "weight" in data:
engine = WeightedEstimateAggregator(self.func, self.errorbar, **boot_kws)
engine = WeightedAggregator(self.func, self.errorbar, **boot_kws)
else:
engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)

Expand Down
44 changes: 27 additions & 17 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
_version_predates,
)
from seaborn._compat import MarkerStyle
from seaborn._statistics import EstimateAggregator, LetterValues
from seaborn._statistics import (
EstimateAggregator,
LetterValues,
WeightedAggregator,
)
from seaborn.palettes import light_palette
from seaborn.axisgrid import FacetGrid, _facet_docs

Expand Down Expand Up @@ -1385,11 +1389,16 @@ class _CategoricalAggPlotter(_CategoricalPlotter):
.. versionadded:: v0.12.0
n_boot : int
Number of bootstrap samples used to compute confidence intervals.
seed : int, `numpy.random.Generator`, or `numpy.random.RandomState`
Seed or random number generator for reproducible bootstrapping.
units : name of variable in `data` or vector data
Identifier of sampling units; used by the errorbar function to
perform a multilevel bootstrap and account for repeated measures.
seed : int, `numpy.random.Generator`, or `numpy.random.RandomState`
Seed or random number generator for reproducible bootstrapping.\
weights : name of variable in `data` or vector data
Data values or column used to compute weighted statistics.
Note that the use of weights may limit other statistical options.
.. versionadded:: v0.13.1\
"""),
ci=dedent("""\
ci : float
Expand Down Expand Up @@ -2308,10 +2317,10 @@ def swarmplot(

def barplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
orient=None, color=None, palette=None, saturation=.75, fill=True, hue_norm=None,
width=.8, dodge="auto", gap=0, log_scale=None, native_scale=False, formatter=None,
legend="auto", capsize=0, err_kws=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
weights=None, orient=None, color=None, palette=None, saturation=.75,
fill=True, hue_norm=None, width=.8, dodge="auto", gap=0, log_scale=None,
native_scale=False, formatter=None, legend="auto", capsize=0, err_kws=None,
ci=deprecated, errcolor=deprecated, errwidth=deprecated, ax=None, **kwargs,
):

Expand All @@ -2324,7 +2333,7 @@ def barplot(

p = _CategoricalAggPlotter(
data=data,
variables=dict(x=x, y=y, hue=hue, units=units),
variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
order=order,
orient=orient,
color=color,
Expand Down Expand Up @@ -2354,7 +2363,8 @@ def barplot(
p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation)
color = _default_color(ax.bar, hue, color, kwargs, saturation=saturation)

aggregator = EstimateAggregator(estimator, errorbar, n_boot=n_boot, seed=seed)
agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
err_kws = {} if err_kws is None else _normalize_kwargs(err_kws, mpl.lines.Line2D)

# Deprecations to remove in v0.15.0.
Expand Down Expand Up @@ -2449,20 +2459,19 @@ def barplot(

def pointplot(
data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None,
color=None, palette=None, hue_norm=None, markers=default, linestyles=default,
dodge=False, log_scale=None, native_scale=False, orient=None, capsize=0,
formatter=None, legend="auto", err_kws=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
weights=None, color=None, palette=None, hue_norm=None, markers=default,
linestyles=default, dodge=False, log_scale=None, native_scale=False,
orient=None, capsize=0, formatter=None, legend="auto", err_kws=None,
ci=deprecated, errwidth=deprecated, join=deprecated, scale=deprecated,
ax=None,
**kwargs,
ax=None, **kwargs,
):

errorbar = utils._deprecate_ci(errorbar, ci)

p = _CategoricalAggPlotter(
data=data,
variables=dict(x=x, y=y, hue=hue, units=units),
variables=dict(x=x, y=y, hue=hue, units=units, weight=weights),
order=order,
orient=orient,
# Handle special backwards compatibility where pointplot originally
Expand All @@ -2489,7 +2498,8 @@ def pointplot(
p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
color = _default_color(ax.plot, hue, color, kwargs)

aggregator = EstimateAggregator(estimator, errorbar, n_boot=n_boot, seed=seed)
agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator
aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)
err_kws = {} if err_kws is None else _normalize_kwargs(err_kws, mpl.lines.Line2D)

# Deprecations to remove in v0.15.0.
Expand Down
15 changes: 11 additions & 4 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
_normalize_kwargs,
_scatter_legend_artist,
)
from ._statistics import EstimateAggregator
from ._statistics import EstimateAggregator, WeightedAggregator
from .axisgrid import FacetGrid, _facet_docs
from ._docstrings import DocstringComponents, _core_docs

Expand Down Expand Up @@ -252,7 +252,8 @@ def plot(self, ax, kws):
raise ValueError(err.format(self.err_style))

# Initialize the aggregation object
agg = EstimateAggregator(
weighted = "weight" in self.plot_data
agg = (WeightedAggregator if weighted else EstimateAggregator)(
self.estimator, self.errorbar, n_boot=self.n_boot, seed=self.seed,
)

Expand Down Expand Up @@ -464,7 +465,7 @@ def plot(self, ax, kws):

def lineplot(
data=None, *,
x=None, y=None, hue=None, size=None, style=None, units=None,
x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
palette=None, hue_order=None, hue_norm=None,
sizes=None, size_order=None, size_norm=None,
dashes=True, markers=None, style_order=None,
Expand All @@ -478,7 +479,9 @@ def lineplot(

p = _LinePlotter(
data=data,
variables=dict(x=x, y=y, hue=hue, size=size, style=style, units=units),
variables=dict(
x=x, y=y, hue=hue, size=size, style=style, units=units, weight=weights
),
estimator=estimator, n_boot=n_boot, seed=seed, errorbar=errorbar,
sort=sort, orient=orient, err_style=err_style, err_kws=err_kws,
legend=legend,
Expand Down Expand Up @@ -536,6 +539,10 @@ def lineplot(
and/or markers. Can have a numeric dtype but will always be treated
as categorical.
{params.rel.units}
weights : vector or key in `data`
Data values or column used to compute weighted estimates.
Note that use of weights currently limits the choice of statistics
to a 'mean' estimator and 'ci' errorbar.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
Expand Down
14 changes: 14 additions & 0 deletions tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,13 @@ def test_estimate_func(self, long_df):
for i, bar in enumerate(ax.patches):
assert bar.get_height() == approx(agg_df[order[i]])

def test_weighted_estimate(self, long_df):
    """Check that `barplot` with `weights` draws the weighted mean of `y`.

    The height of the first bar patch is compared against
    `np.average(y, weights=x)` computed directly from the data.
    """
    ax = barplot(long_df, y="y", weights="x")
    height = ax.patches[0].get_height()
    expected = np.average(long_df["y"], weights=long_df["x"])
    # Compare with a tolerance (as the neighboring estimate tests do) rather
    # than requiring the plotted float to be bit-identical to `np.average`.
    assert height == approx(expected)

def test_estimate_log_transform(self, long_df):

ax = mpl.figure.Figure().subplots()
Expand Down Expand Up @@ -2490,6 +2497,13 @@ def test_estimate(self, long_df, estimator):
for i, xy in enumerate(ax.lines[0].get_xydata()):
assert tuple(xy) == approx((i, agg_df[order[i]]))

def test_weighted_estimate(self, long_df):
    """Check that `pointplot` with `weights` draws the weighted mean of `y`.

    The y data of the first line (which must hold exactly one value for
    `.item()` to succeed) is compared against `np.average(y, weights=x)`.
    """
    ax = pointplot(long_df, y="y", weights="x")
    val = ax.lines[0].get_ydata().item()
    expected = np.average(long_df["y"], weights=long_df["x"])
    # Compare with a tolerance (as the neighboring estimate tests do) rather
    # than requiring the plotted float to be bit-identical to `np.average`.
    assert val == approx(expected)

def test_estimate_log_transform(self, long_df):

ax = mpl.figure.Figure().subplots()
Expand Down
9 changes: 9 additions & 0 deletions tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,15 @@ def test_plot(self, long_df, repeated_df):
ax.clear()
p.plot(ax, {})

def test_weights(self, long_df):
    """Check that `lineplot` with `weights` draws per-level weighted means.

    For every x tick label, the corresponding y value of the first line is
    compared against the weighted average of `y` over the rows of that level.
    """
    ax = lineplot(long_df, x="a", y="y", weights="x")
    drawn = ax.lines[0].get_ydata()
    for idx, label in enumerate(ax.get_xticklabels()):
        in_level = long_df["a"] == label.get_text()
        level_rows = long_df.loc[in_level]
        weighted_mean = np.average(level_rows["y"], weights=level_rows["x"])
        assert drawn[idx] == pytest.approx(weighted_mean)

def test_non_aggregated_data(self):

x = [1, 2, 3, 4]
Expand Down
12 changes: 6 additions & 6 deletions tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ECDF,
EstimateAggregator,
LetterValues,
WeightedEstimateAggregator,
WeightedAggregator,
_validate_errorbar_arg,
_no_scipy,
)
Expand Down Expand Up @@ -633,12 +633,12 @@ def test_errorbar_validation(self):
_validate_errorbar_arg(arg)


class TestWeightedEstimateAggregator:
class TestWeightedAggregator:

def test_weighted_mean(self, long_df):

long_df["weight"] = long_df["x"]
est = WeightedEstimateAggregator("mean")
est = WeightedAggregator("mean")
out = est(long_df, "y")
expected = np.average(long_df["y"], weights=long_df["weight"])
assert_array_equal(out["y"], expected)
Expand All @@ -648,7 +648,7 @@ def test_weighted_mean(self, long_df):
def test_weighted_ci(self, long_df):

long_df["weight"] = long_df["x"]
est = WeightedEstimateAggregator("mean", "ci")
est = WeightedAggregator("mean", "ci")
out = est(long_df, "y")
expected = np.average(long_df["y"], weights=long_df["weight"])
assert_array_equal(out["y"], expected)
Expand All @@ -658,12 +658,12 @@ def test_weighted_ci(self, long_df):
def test_limited_estimator(self):

with pytest.raises(ValueError, match="Weighted estimator must be 'mean'"):
WeightedEstimateAggregator("median")
WeightedAggregator("median")

def test_limited_ci(self):

with pytest.raises(ValueError, match="Error bar method must be 'ci'"):
WeightedEstimateAggregator("mean", "sd")
WeightedAggregator("mean", "sd")


class TestLetterValues:
Expand Down

0 comments on commit c4854df

Please sign in to comment.