Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Stanley Kudrow committed Oct 15, 2024
1 parent 1ecc5e1 commit 3d8d546
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 60 deletions.
10 changes: 7 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
REPO_OWNER=nats-io
PROJECT_NAME=nats.py
SOURCE_CODE=nats
TEST_CODE=tests


help:
Expand All @@ -22,14 +23,17 @@ deps:

format:
yapf -i --recursive $(SOURCE_CODE)
yapf -i --recursive tests
yapf -i --recursive $(TEST_CODE)


test:
lint:
yapf --recursive --diff $(SOURCE_CODE)
yapf --recursive --diff tests
yapf --recursive --diff $(TEST_CODE)
mypy
flake8 ./nats/js/


test:
pytest


Expand Down
89 changes: 44 additions & 45 deletions nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
import ipaddress
import json
import logging
import os
import ssl
import time
import string
from dataclasses import dataclass
from email.parser import BytesParser
from enum import Enum
from random import shuffle
from secrets import token_hex
from typing import (
Expand All @@ -32,10 +34,8 @@
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
TypedDict,
Union,
)
from urllib.parse import ParseResult, urlparse
Expand All @@ -54,7 +54,6 @@
from nats.protocol.parser import (
AUTHORIZATION_VIOLATION,
PERMISSIONS_ERR,
PING,
PONG,
STALE_CONNECTION,
Parser,
Expand Down Expand Up @@ -196,14 +195,10 @@ async def _default_error_callback(ex: Exception) -> None:
_logger.error("nats: encountered error", exc_info=ex)


class Client:
"""
Asyncio based client for NATS.
"""
# Client section

msg_class: type[Msg] = Msg

# FIXME: Use an enum instead.
class ClientStates(Enum):
DISCONNECTED = 0
CONNECTED = 1
CLOSED = 2
Expand All @@ -212,6 +207,12 @@ class Client:
DRAINING_SUBS = 5
DRAINING_PUBS = 6


class Client:
"""Asyncio-based client for NATS."""

msg_class: type[Msg] = Msg

def __repr__(self) -> str:
return f"<nats client v{__version__}>"

Expand Down Expand Up @@ -242,7 +243,7 @@ def __init__(self) -> None:
self._client_id: Optional[int] = None
self._sid: int = 0
self._subs: Dict[int, Subscription] = {}
self._status: int = Client.DISCONNECTED
self._status = ClientStates.DISCONNECTED
self._ps: Parser = Parser(self)

# pending queue of commands that will be flushed to the server.
Expand Down Expand Up @@ -523,7 +524,7 @@ async def subscribe_handler(msg):
if not self.options["allow_reconnect"]:
raise e

await self._close(Client.DISCONNECTED, False)
await self._close(ClientStates.DISCONNECTED, False)
if self._current_server is not None:
self._current_server.last_attempt = time.monotonic()
self._current_server.reconnects += 1
Expand All @@ -536,7 +537,6 @@ def _setup_nkeys_connect(self) -> None:

def _setup_nkeys_jwt_connect(self) -> None:
assert self._user_credentials, "_user_credentials required"
import os
import nkeys

creds: Credentials = self._user_credentials
Expand Down Expand Up @@ -641,16 +641,15 @@ def _setup_nkeys_seed_connect(self) -> None:
assert (
self._nkeys_seed or self._nkeys_seed_str
), "Client.connect must be called first"
import os
import nkeys

def _get_nkeys_seed() -> nkeys.KeyPair:
import os

if self._nkeys_seed_str:
seed = bytearray(self._nkeys_seed_str.encode())
else:
creds = self._nkeys_seed
if creds is None:
raise ValueError("cannot extract nkeys seed")
with open(creds, "rb") as f:
seed = bytearray(os.fstat(f.fileno()).st_size)
f.readinto(seed) # type: ignore[attr-defined]
Expand Down Expand Up @@ -681,13 +680,12 @@ async def close(self) -> None:
sets the client to be in the CLOSED state.
No further reconnections occur once reaching this point.
"""
await self._close(Client.CLOSED)
await self._close(ClientStates.CLOSED)

async def _close(self, status: int, do_cbs: bool = True) -> None:
async def _close(self, status: ClientStates, do_cbs: bool = True) -> None:
if self.is_closed:
self._status = status
return
self._status = Client.CLOSED

# Kick the flusher once again so that Task breaks and avoid pending futures.
await self._flush_pending()
Expand Down Expand Up @@ -759,6 +757,8 @@ async def _close(self, status: int, do_cbs: bool = True) -> None:
if self._closed_cb is not None:
await self._closed_cb()

self._status = ClientStates.CLOSED

# Set the client_id and subscription prefix back to None
self._client_id = None
self._resp_sub_prefix = None
Expand Down Expand Up @@ -791,7 +791,7 @@ async def drain(self) -> None:
# Relinquish CPU to allow drain tasks to start in the background,
# before setting state to draining.
await asyncio.sleep(0)
self._status = Client.DRAINING_SUBS
self._status = ClientStates.DRAINING_SUBS

try:
await asyncio.wait_for(
Expand All @@ -804,9 +804,9 @@ async def drain(self) -> None:
except asyncio.CancelledError:
pass
finally:
self._status = Client.DRAINING_PUBS
self._status = ClientStates.DRAINING_PUBS
await self.flush()
await self._close(Client.CLOSED)
await self._close(ClientStates.CLOSED)

async def publish(
self,
Expand Down Expand Up @@ -1191,30 +1191,30 @@ def pending_data_size(self) -> int:

@property
def is_closed(self) -> bool:
return self._status == Client.CLOSED
return self._status == ClientStates.CLOSED

@property
def is_reconnecting(self) -> bool:
return self._status == Client.RECONNECTING
return self._status == ClientStates.RECONNECTING

@property
def is_connected(self) -> bool:
return (self._status == Client.CONNECTED) or self.is_draining
return (self._status == ClientStates.CONNECTED) or self.is_draining

@property
def is_connecting(self) -> bool:
return self._status == Client.CONNECTING
return self._status == ClientStates.CONNECTING

@property
def is_draining(self) -> bool:
return (
self._status == Client.DRAINING_SUBS
or self._status == Client.DRAINING_PUBS
self._status == ClientStates.DRAINING_SUBS
or self._status == ClientStates.DRAINING_PUBS
)

@property
def is_draining_pubs(self) -> bool:
return self._status == Client.DRAINING_PUBS
return self._status == ClientStates.DRAINING_PUBS

@property
def connected_server_version(self) -> ServerVersion:
Expand Down Expand Up @@ -1272,7 +1272,7 @@ async def _flush_pending(
except asyncio.CancelledError:
pass

def _setup_server_pool(self, connect_url: Union[List[str]]) -> None:
def _setup_server_pool(self, connect_url: Union[str | List[str]]) -> None:
if isinstance(connect_url, str):
try:
if "nats://" in connect_url or "tls://" in connect_url:
Expand Down Expand Up @@ -1404,7 +1404,7 @@ async def _process_err(self, err_msg: str) -> None:
# FIXME: Some errors such as 'Invalid Subscription'
# do not cause the server to close the connection.
# For now we handle similar as other clients and close.
asyncio.create_task(self._close(Client.CLOSED, do_cbs))
asyncio.create_task(self._close(ClientStates.CLOSED, do_cbs))

async def _process_op_err(self, e: Exception) -> None:
"""
Expand All @@ -1417,7 +1417,7 @@ async def _process_op_err(self, e: Exception) -> None:
return

if self.options["allow_reconnect"] and self.is_connected:
self._status = Client.RECONNECTING
self._status = ClientStates.RECONNECTING
self._ps.reset()

if (self._reconnection_task is not None
Expand All @@ -1431,7 +1431,7 @@ async def _process_op_err(self, e: Exception) -> None:
else:
self._process_disconnect()
self._err = e
await self._close(Client.CLOSED, True)
await self._close(ClientStates.CLOSED, True)

async def _attempt_reconnect(self) -> None:
assert self._current_server, "Client.connect must be called first"
Expand Down Expand Up @@ -1517,7 +1517,7 @@ async def _attempt_reconnect(self) -> None:
# to bail earlier in case there are errors in the connection.
# await self._flush_pending(force_flush=True)
await self._flush_pending()
self._status = Client.CONNECTED
self._status = ClientStates.CONNECTED
await self.flush()
if self._reconnected_cb is not None:
await self._reconnected_cb()
Expand All @@ -1530,7 +1530,7 @@ async def _attempt_reconnect(self) -> None:
except (OSError, errors.Error, asyncio.TimeoutError) as e:
self._err = e
await self._error_cb(e)
self._status = Client.RECONNECTING
self._status = ClientStates.RECONNECTING
self._current_server.last_attempt = time.monotonic()
self._current_server.reconnects += 1
except asyncio.CancelledError:
Expand Down Expand Up @@ -1593,9 +1593,8 @@ def _connect_command(self) -> bytes:
return b"".join([CONNECT_OP + _SPC_ + connect_opts.encode() + _CRLF_])

async def _process_ping(self) -> None:
"""
Process PING sent by server.
"""
"""Process PING sent by server."""

await self._send_command(PONG)
await self._flush_pending()

Expand All @@ -1622,7 +1621,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]:
if not headers:
return None

hdr: Optional[Dict[str, str]] = None
hdr: Dict[str, str] = {}
raw_headers = headers[NATS_HDR_LINE_SIZE:]

# If the first character is an empty space, then this is
Expand Down Expand Up @@ -1653,7 +1652,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]:
i = raw_headers.find(_CRLF_)
raw_headers = raw_headers[i + _CRLF_LEN_:]

if len(desc) > 0:
if len(desc):
# Heartbeat messages can have both headers and inline status,
# check that there are no pending headers to be parsed.
i = desc.find(_CRLF_)
Expand All @@ -1668,7 +1667,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]:
# Just inline status...
hdr[nats.js.api.Header.DESCRIPTION] = desc.decode()

if not len(raw_headers) > _CRLF_LEN_:
if len(raw_headers) <= _CRLF_LEN_:
return hdr

#
Expand Down Expand Up @@ -1861,7 +1860,7 @@ def _process_disconnect(self) -> None:
Process disconnection from the server and set client status
to DISCONNECTED.
"""
self._status = Client.DISCONNECTED
self._status = ClientStates.DISCONNECTED

def _process_info(
self, info: Dict[str, Any], initial_connection: bool = False
Expand Down Expand Up @@ -1925,7 +1924,7 @@ async def _process_connect_init(self) -> None:
"""
assert self._transport, "must be called only from Client.connect"
assert self._current_server, "must be called only from Client.connect"
self._status = Client.CONNECTING
self._status = ClientStates.CONNECTING

# Check whether to reuse the original hostname for an implicit route.
hostname = None
Expand Down Expand Up @@ -2026,7 +2025,7 @@ async def _process_connect_init(self) -> None:
)

if PONG_PROTO in next_op:
self._status = Client.CONNECTED
self._status = ClientStates.CONNECTED
elif ERR_OP in next_op:
err_line = next_op.decode()
_, err_msg = err_line.split(" ", 1)
Expand All @@ -2037,7 +2036,7 @@ async def _process_connect_init(self) -> None:
raise errors.Error("nats: " + err_msg.rstrip("\r\n"))

if PONG_PROTO in next_op:
self._status = Client.CONNECTED
self._status = ClientStates.CONNECTED

self._reading_task = asyncio.get_running_loop().create_task(
self._read_loop()
Expand Down Expand Up @@ -2150,7 +2149,7 @@ async def __aenter__(self) -> "Client":

async def __aexit__(self, *exc_info) -> None:
"""Close connection to NATS when used in a context manager"""
await self._close(Client.CLOSED, do_cbs=True)
await self._close(ClientStates.CLOSED, do_cbs=True)

def jetstream(self, **opts) -> nats.js.JetStreamContext:
"""
Expand Down
2 changes: 1 addition & 1 deletion nats/js/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ class StreamsListIterator(Iterable):
"""

def __init__(
self, offset: int, total: int, streams: List[Dict[str, any]]
self, offset: int, total: int, streams: List[Dict[str, Any]]
) -> None:
self.offset = offset
self.total = total
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespaces = false # to disable scanning PEP 420 namespaces (true by default)

[tool.mypy]
files = ["nats"]
python_version = "3.7"
python_version = "3.9"
ignore_missing_imports = true
follow_imports = "silent"
show_error_codes = true
Expand All @@ -67,3 +67,8 @@ combine_as_imports = true
multi_line_output = 3
include_trailing_comma = true
src_paths = ["nats", "tests"]

[tool.pytest.ini_options]
minversion = "7.0"
addopts = "--maxfail=1 -rfs -vvv"
testpaths = ["tests"]
Loading

0 comments on commit 3d8d546

Please sign in to comment.