mirror of
https://github.com/meshtastic/python.git
synced 2026-06-03 13:19:44 -04:00
Give TCPInterface reconnect logic on write errors
* Moving to socket.sendall() is safer, as sendall will send the entire buffer, while send() would return the number of bytes sent and require being called multiple times if the buffer was full. * On exceptions: reconnect to the server. * On reconnection: make sure using a lock that there isn't a race between the readers and the writers triggering a reconnect.
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
@@ -12,6 +13,7 @@ from meshtastic.stream_interface import StreamInterface
|
||||
DEFAULT_TCP_PORT = 4403
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TCPInterface(StreamInterface):
|
||||
"""Interface class for meshtastic devices over a TCP link"""
|
||||
|
||||
@@ -19,10 +21,10 @@ class TCPInterface(StreamInterface):
|
||||
self,
|
||||
hostname: str,
|
||||
debugOut=None,
|
||||
noProto: bool=False,
|
||||
connectNow: bool=True,
|
||||
portNumber: int=DEFAULT_TCP_PORT,
|
||||
noNodes:bool=False,
|
||||
noProto: bool = False,
|
||||
connectNow: bool = True,
|
||||
portNumber: int = DEFAULT_TCP_PORT,
|
||||
noNodes: bool = False,
|
||||
timeout: int = 300,
|
||||
):
|
||||
"""Constructor, opens a connection to a specified IP address/hostname
|
||||
@@ -38,13 +40,20 @@ class TCPInterface(StreamInterface):
|
||||
self.portNumber: int = portNumber
|
||||
|
||||
self.socket: Optional[socket.socket] = None
|
||||
self.reconnectLock = threading.Lock()
|
||||
|
||||
if connectNow:
|
||||
self.myConnect()
|
||||
else:
|
||||
self.socket = None
|
||||
|
||||
super().__init__(debugOut=debugOut, noProto=noProto, connectNow=connectNow, noNodes=noNodes, timeout=timeout)
|
||||
super().__init__(
|
||||
debugOut=debugOut,
|
||||
noProto=noProto,
|
||||
connectNow=connectNow,
|
||||
noNodes=noNodes,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
rep = f"TCPInterface({self.hostname!r}"
|
||||
@@ -69,29 +78,35 @@ class TCPInterface(StreamInterface):
|
||||
self.socket.shutdown(socket.SHUT_RDWR)
|
||||
|
||||
def myConnect(self) -> None:
|
||||
"""Connect to socket"""
|
||||
logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe]
|
||||
"""Connect to socket."""
|
||||
logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe]
|
||||
server_address = (self.hostname, self.portNumber)
|
||||
self.socket = socket.create_connection(server_address)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close a connection to the device"""
|
||||
"""Close a connection to the device."""
|
||||
logger.debug("Closing TCP stream")
|
||||
super().close()
|
||||
# Sometimes the socket read might be blocked in the reader thread.
|
||||
# Therefore we force the shutdown by closing the socket here
|
||||
self._wantExit = True
|
||||
if self.socket is not None:
|
||||
with contextlib.suppress(Exception): # Ignore errors in shutdown, because we might have a race with the server
|
||||
with contextlib.suppress(
|
||||
Exception
|
||||
): # Ignore errors in shutdown, because we might have a race with the server
|
||||
self._socket_shutdown()
|
||||
self.socket.close()
|
||||
|
||||
self.socket = None
|
||||
|
||||
def _writeBytes(self, b: bytes) -> None:
|
||||
"""Write an array of bytes to our stream and flush"""
|
||||
"""Write an array of bytes to our stream"""
|
||||
if self.socket is not None:
|
||||
self.socket.send(b)
|
||||
try:
|
||||
self.socket.sendall(b)
|
||||
except OSError as e:
|
||||
logger.error(f"Socket send error, reconnecting: {e}")
|
||||
self._reconnect()
|
||||
|
||||
def _readBytes(self, length) -> Optional[bytes]:
|
||||
"""Read an array of bytes from our stream"""
|
||||
@@ -99,19 +114,28 @@ class TCPInterface(StreamInterface):
|
||||
data = self.socket.recv(length)
|
||||
# empty byte indicates a disconnected socket,
|
||||
# we need to handle it to avoid an infinite loop reading from null socket
|
||||
if data == b'':
|
||||
logger.debug("dead socket, re-connecting")
|
||||
# cleanup and reconnect socket without breaking reader thread
|
||||
with contextlib.suppress(Exception):
|
||||
self._socket_shutdown()
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
time.sleep(1)
|
||||
self.myConnect()
|
||||
self._startConfig()
|
||||
return None
|
||||
if data == b"":
|
||||
logger.debug("Closed socket, re-connecting")
|
||||
self._reconnect()
|
||||
return data
|
||||
|
||||
# no socket, break reader thread
|
||||
self._wantExit = True
|
||||
return None
|
||||
|
||||
def _reconnect(self) -> None:
|
||||
"""Reconnect to the socket"""
|
||||
# Save the socket reference before attempting to acquire the lock.
|
||||
sock = self.socket
|
||||
with self.reconnectLock:
|
||||
# Don't reconnect: someone else already did it.
|
||||
if sock is not self.socket:
|
||||
return
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
self._socket_shutdown()
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
time.sleep(1)
|
||||
self.myConnect()
|
||||
self._startConfig()
|
||||
|
||||
@@ -54,3 +54,44 @@ def test_TCPInterface_without_connecting():
|
||||
with patch("socket.socket"):
|
||||
iface = TCPInterface(hostname="localhost", noProto=True, connectNow=False)
|
||||
assert iface.socket is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_TCPInterface_reconnect():
|
||||
"""Test that _reconnect correctly reconnects"""
|
||||
with patch("socket.socket") as mock_socket:
|
||||
with patch("time.sleep"):
|
||||
iface = TCPInterface(hostname="localhost", noProto=True)
|
||||
old_socket = iface.socket
|
||||
assert old_socket is not None
|
||||
|
||||
iface._reconnect()
|
||||
|
||||
assert old_socket.close.called
|
||||
# We expect socket class to be instantiated at least twice (init + reconnect)
|
||||
assert mock_socket.call_count >= 2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_TCPInterface_writeBytes_reconnects():
|
||||
"""Test that _writeBytes calls _reconnect on OSError"""
|
||||
with patch("socket.socket"):
|
||||
iface = TCPInterface(hostname="localhost", noProto=True)
|
||||
iface.socket.sendall.side_effect = OSError("Broken pipe")
|
||||
|
||||
with patch.object(iface, '_reconnect') as mock_reconnect:
|
||||
iface._writeBytes(b"some data")
|
||||
mock_reconnect.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_TCPInterface_readBytes_reconnects():
|
||||
"""Test that _readBytes calls _reconnect on empty bytes"""
|
||||
with patch("socket.socket"):
|
||||
iface = TCPInterface(hostname="localhost", noProto=True)
|
||||
# Mock the socket instance on the interface
|
||||
iface.socket.recv.return_value = b''
|
||||
|
||||
with patch.object(iface, '_reconnect') as mock_reconnect:
|
||||
iface._readBytes(10)
|
||||
mock_reconnect.assert_called_once()
|
||||
|
||||
Submodule protobufs deleted from 77c8329a59
Reference in New Issue
Block a user