Skip to content

Commit

Permalink
Avoid hanging locks in modbus connection
Browse files Browse the repository at this point in the history
  • Loading branch information
albireox committed Nov 9, 2023
1 parent a325041 commit 80e49eb
Showing 1 changed file with 33 additions and 31 deletions.
64 changes: 33 additions & 31 deletions python/lvmecp/modbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,15 @@ def __init__(self, config: dict | pathlib.Path | str | None = None):
async def connect(self):
"""Connects to the client."""

if self.lock:
await self.lock.acquire()

hp = f"{self.host}:{self.port}"
log.debug(f"Trying to connect to modbus server on {hp}")

try:
await asyncio.wait_for(self.client.connect(), timeout=1)
await asyncio.wait_for(self.client.connect(), timeout=5)
except asyncio.TimeoutError:
raise ConnectionError(f"Timed out connecting to server at {hp}.")
except Exception as err:
raise ConnectionError(f"Failed connecting to server at {hp}: {err}.")
finally:
if self.lock and self.lock.locked() and not self.client.connected:
self.lock.release()

log.debug(f"Connected to {hp}.")

Expand All @@ -260,45 +254,53 @@ async def disconnect(self):

log.debug(f"Disonnected from {self.host}:{self.port}.")

if self.lock and self.lock.locked():
self.lock.release()

async def __aenter__(self):
"""Initialises the connection to the server."""

await self.connect()
# Acquire the lock, but also don't allow it to block for too long.
try:
await asyncio.wait_for(self.lock.acquire(), 10)
except asyncio.TimeoutError:
log.warning("Timed out waiting for lock to be released. Forcing release.")
self.lock.release()
await self.lock.acquire()

try:
await self.connect()
except Exception:
if self.lock.locked():
self.lock.release()

raise

async def __aexit__(self, exc_type, exc, tb):
"""Closes the connection to the server."""

await self.disconnect()
try:
await self.disconnect()
finally:
if self.lock.locked():
self.lock.release()

async def get_all(self):
"""Returns a dictionary with all the registers."""

names = results = []

NRETRIES: int = 3
for retry in range(NRETRIES):
async with self:
names = [name for name in self]
tasks = [elem.get(open_connection=False) for elem in self.values()]

results = await asyncio.gather(*tasks, return_exceptions=True)

if any([isinstance(result, Exception) for result in results]):
log.debug("Exceptions received while getting all registers.")
async with self:
names = [name for name in self]
tasks = [elem.get(open_connection=False) for elem in self.values()]

if retry < NRETRIES:
continue
results = await asyncio.gather(*tasks, return_exceptions=True)

for result in results:
if isinstance(result, Exception):
log.warning(
"Failed retrieving all registers. First exception:",
exc_info=result,
)
break
if any([isinstance(result, Exception) for result in results]):
for result in results:
if isinstance(result, Exception):
log.warning(
"Failed retrieving all registers. First exception:",
exc_info=result,
)
break

return {
names[ii]: results[ii]
Expand Down

0 comments on commit 80e49eb

Please sign in to comment.