Use memoryview.nbytes when warning on large graph send (#8268)
crusaderky authored Oct 13, 2023
1 parent 0111087 commit 8d89ef0
Showing 3 changed files with 28 additions and 19 deletions.
8 changes: 5 additions & 3 deletions distributed/client.py
@@ -102,6 +102,7 @@
     import_term,
     is_python_shutting_down,
     log_errors,
+    nbytes,
     sync,
     thread_state,
 )
@@ -3156,10 +3157,11 @@ def _graph_to_futures(
         from distributed.protocol.serialize import ToPickle
 
         header, frames = serialize(ToPickle(dsk), on_error="raise")
-        nbytes = len(header) + sum(map(len, frames))
-        if nbytes > 10_000_000:
+
+        pickled_size = sum(map(nbytes, [header] + frames))
+        if pickled_size > 10_000_000:
             warnings.warn(
-                f"Sending large graph of size {format_bytes(nbytes)}.\n"
+                f"Sending large graph of size {format_bytes(pickled_size)}.\n"
                 "This may cause some slowdown.\n"
                 "Consider scattering data ahead of time and using futures."
             )
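
The core of the fix: len() on a memoryview counts elements, not bytes, so the old len(header) + sum(map(len, frames)) under-measures any frame whose itemsize is larger than one byte. A standalone illustration (not part of the diff; plain Python, no Dask required):

import array

buf = array.array("d", b"0" * 10_000_000)  # 10 MB worth of float64 values
view = memoryview(buf)

print(len(view))    # 1250000  -- number of 8-byte elements
print(view.nbytes)  # 10000000 -- actual payload size in bytes

# A len()-based size check under-reports this frame by 8x, so a genuinely
# large graph could slip past the "Sending large graph" warning.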
14 changes: 5 additions & 9 deletions distributed/spill.py
@@ -14,7 +14,7 @@
 from distributed.protocol import deserialize_bytes, serialize_bytelist
 from distributed.protocol.compression import get_compression_settings
 from distributed.sizeof import safe_sizeof
-from distributed.utils import RateLimiterFilter
+from distributed.utils import RateLimiterFilter, nbytes
 
 logger = logging.getLogger(__name__)
 logger.addFilter(RateLimiterFilter("Spill file on disk reached capacity"))
@@ -208,11 +208,11 @@ def __getitem__(self, key: Key) -> object:
         # Note: don't log from self.fast.__getitem__, because that's called
         # every time a key is evicted, and we don't want to count those events
         # here.
-        nbytes = cast(int, self.fast.weights[key])
+        memory_size = cast(int, self.fast.weights[key])
         # This is logged not only by the internal metrics callback but also by
         # those installed by gather_dep, get_data, and execute
         context_meter.digest_metric("memory-read", 1, "count")
-        context_meter.digest_metric("memory-read", nbytes, "bytes")
+        context_meter.digest_metric("memory-read", memory_size, "bytes")
 
         return super().__getitem__(key)
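
Side note: the local variable rename from nbytes to memory_size is presumably to keep the local from shadowing the nbytes helper that this module now imports from distributed.utils; the metric values themselves are unchanged.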

@@ -326,17 +326,13 @@ def __setitem__(self, key: Key, value: object) -> None:
             # which will then unwrap it.
             raise PickleError(key, e)
 
-        pickled_size = sum(
-            frame.nbytes if isinstance(frame, memoryview) else len(frame)
-            # See note in __init__ about serialize_bytelist
-            for frame in cast(list, pickled)
-        )
-
         # Thanks to Buffer.__setitem__, we never update existing
         # keys in slow, but always delete them and reinsert them.
         assert key not in self.d
         assert key not in self.weight_by_key
 
+        pickled_size = sum(map(nbytes, pickled))
+
         if (
             self.max_weight is not False
             and self.total_weight.disk + pickled_size > self.max_weight
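
Both call sites now delegate per-frame measurement to distributed.utils.nbytes, which handles plain bytes and memoryviews uniformly. The helper is roughly the following (a sketch from memory; see distributed/utils.py for the authoritative definition):

def nbytes(frame, _bytes_like=(bytes, bytearray)):
    """Number of bytes of a frame or memoryview"""
    if isinstance(frame, _bytes_like):
        return len(frame)    # len() is correct for bytes/bytearray
    else:
        return frame.nbytes  # memoryview and other buffer types

This is also why the hand-rolled isinstance check above could be dropped: sum(map(nbytes, pickled)) expresses the same logic in one place.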
25 changes: 18 additions & 7 deletions distributed/tests/test_client.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import array
 import asyncio
 import concurrent.futures
 import functools
@@ -5964,16 +5965,26 @@ async def test_config_scheduler_address(s, a, b):
     assert sio.getvalue() == f"Config value `scheduler-address` found: {s.address}\n"
 
 
 @pytest.mark.filterwarnings("ignore:Large object:UserWarning")
-@gen_cluster(client=True)
-async def test_warn_when_submitting_large_values(c, s, a, b):
-    with pytest.warns(
-        UserWarning,
-        match="Sending large graph of size",
-    ):
+@gen_cluster(client=True, nthreads=[])
+async def test_warn_when_submitting_large_values(c, s):
+    with pytest.warns(UserWarning, match="Sending large graph of size"):
         future = c.submit(lambda x: x + 1, b"0" * 10_000_000)
 
 
+@gen_cluster(client=True, nthreads=[])
+async def test_warn_when_submitting_large_values_memoryview(c, s):
+    """When sending numpy or parquet data, len(memoryview(obj)) returns the number of
+    elements, not the number of bytes. Make sure we're reading memoryview.nbytes.
+    """
+    # The threshold is 10MB
+    a = array.array("d", b"0" * 9_500_000)
+    c.submit(lambda: a)
+
+    a = array.array("d", b"0" * 10_000_000)
+    with pytest.warns(UserWarning, match="Sending large graph of size"):
+        c.submit(lambda: a)
+
+
 @gen_cluster(client=True)
 async def test_unhashable_function(c, s, a, b):
     func = _UnhashableCallable()
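
A note on the numbers in the new test: "d" arrays hold 8-byte doubles, so the 10_000_000-byte buffer contains only 1_250_000 elements. Under the old len()-based accounting that frame would have been measured as roughly 1.25 MB, far below the threshold; memoryview.nbytes reports the true 10 MB, which together with the small serialization header exceeds the strict > 10_000_000 check, while the 9_500_000-byte array stays below it and must not warn.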
