diff --git a/meshtastic/tcp_interface.py b/meshtastic/tcp_interface.py index 7ba6306..3118f4d 100644 --- a/meshtastic/tcp_interface.py +++ b/meshtastic/tcp_interface.py @@ -4,6 +4,7 @@ import contextlib import logging import socket +import threading import time from typing import Optional @@ -12,6 +13,7 @@ from meshtastic.stream_interface import StreamInterface DEFAULT_TCP_PORT = 4403 logger = logging.getLogger(__name__) + class TCPInterface(StreamInterface): """Interface class for meshtastic devices over a TCP link""" @@ -19,10 +21,10 @@ class TCPInterface(StreamInterface): self, hostname: str, debugOut=None, - noProto: bool=False, - connectNow: bool=True, - portNumber: int=DEFAULT_TCP_PORT, - noNodes:bool=False, + noProto: bool = False, + connectNow: bool = True, + portNumber: int = DEFAULT_TCP_PORT, + noNodes: bool = False, timeout: int = 300, ): """Constructor, opens a connection to a specified IP address/hostname @@ -35,8 +37,15 @@ class TCPInterface(StreamInterface): self.portNumber: int = portNumber self.socket: Optional[socket.socket] = None + self.reconnectLock = threading.Lock() - super().__init__(debugOut=debugOut, noProto=noProto, connectNow=connectNow, noNodes=noNodes, timeout=timeout) + super().__init__( + debugOut=debugOut, + noProto=noProto, + connectNow=connectNow, + noNodes=noNodes, + timeout=timeout, + ) def __repr__(self): rep = f"TCPInterface({self.hostname!r}" @@ -67,18 +76,20 @@ class TCPInterface(StreamInterface): def myConnect(self) -> None: """Connect to socket (without attempting to start the interface's receive thread)""" - logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe] + logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe] server_address = (self.hostname, self.portNumber) self.socket = socket.create_connection(server_address) def close(self) -> None: - """Close a connection to the device""" + """Close a connection to the device.""" logger.debug("Closing TCP stream") # Sometimes the socket read might be blocked in the reader thread. # Therefore force a shutdown first to unblock reader thread reads. self._wantExit = True if self.socket is not None: - with contextlib.suppress(Exception): # Ignore errors in shutdown, because we might have a race with the server + with contextlib.suppress( + Exception + ): # Ignore errors in shutdown, because we might have a race with the server self._socket_shutdown() with contextlib.suppress(Exception): self.socket.close() @@ -87,9 +98,15 @@ class TCPInterface(StreamInterface): super().close() def _writeBytes(self, b: bytes) -> None: - """Write an array of bytes to our stream and flush""" + """Write an array of bytes to our stream""" if self.socket is not None: - self.socket.send(b) + try: + self.socket.sendall(b) + except OSError as e: + logger.error(f"Socket send error, reconnecting: {e}") + if not self._wantExit: + self._reconnect() + raise def _readBytes(self, length) -> Optional[bytes]: """Read an array of bytes from our stream""" @@ -97,19 +114,36 @@ class TCPInterface(StreamInterface): data = self.socket.recv(length) # empty byte indicates a disconnected socket, # we need to handle it to avoid an infinite loop reading from null socket - if data == b'': - logger.debug("dead socket, re-connecting") - # cleanup and reconnect socket without breaking reader thread - with contextlib.suppress(Exception): - self._socket_shutdown() - self.socket.close() - self.socket = None - time.sleep(1) - self.myConnect() - self._startConfig() - return None + if data == b"": + logger.debug("Closed socket, re-connecting") + if not self._wantExit: + self._reconnect() return data # no socket, break reader thread self._wantExit = True return None + + def _reconnect(self) -> None: + """Reconnect to the socket""" + # Save the socket reference before attempting to acquire the lock. + sock = self.socket + start_config = False + with self.reconnectLock: + if self._wantExit: + return + # Don't reconnect: someone else already did it. + if sock is not self.socket: + return + + with contextlib.suppress(Exception): + self._socket_shutdown() + if self.socket is not None: + self.socket.close() + self.socket = None + time.sleep(1) + self.myConnect() + start_config = True + + if start_config and not self._wantExit and self.socket is not None: + self._startConfig() diff --git a/meshtastic/tests/test_tcp_interface.py b/meshtastic/tests/test_tcp_interface.py index 129ad24..4f0fec9 100644 --- a/meshtastic/tests/test_tcp_interface.py +++ b/meshtastic/tests/test_tcp_interface.py @@ -76,3 +76,44 @@ def test_TCPInterface_close_shutdowns_socket_before_super_close(): assert call_order == ["shutdown", "super_close"] sock.close.assert_called_once() assert iface.socket is None + + +@pytest.mark.unit +def test_TCPInterface_reconnect(): + """Test that _reconnect correctly reconnects""" + with patch("socket.socket") as mock_socket: + with patch("time.sleep"): + iface = TCPInterface(hostname="localhost", noProto=True) + old_socket = iface.socket + assert old_socket is not None + + iface._reconnect() + + assert old_socket.close.called + # We expect socket class to be instantiated at least twice (init + reconnect) + assert mock_socket.call_count >= 2 + + +@pytest.mark.unit +def test_TCPInterface_writeBytes_reconnects(): + """Test that _writeBytes reconnects and re-raises on OSError.""" + with patch("socket.socket"): + iface = TCPInterface(hostname="localhost", noProto=True) + iface.socket.sendall.side_effect = OSError("Broken pipe") + + with patch.object(iface, "_reconnect") as mock_reconnect: + with pytest.raises(OSError, match="Broken pipe"): + iface._writeBytes(b"some data") + mock_reconnect.assert_called_once() + + +@pytest.mark.unit +def test_TCPInterface_readBytes_reconnects(): + """Test that _readBytes calls _reconnect on empty bytes""" + iface = TCPInterface(hostname="localhost", noProto=True, connectNow=False) + iface.socket = MagicMock() + iface.socket.recv.return_value = b"" + + with patch.object(iface, "_reconnect") as mock_reconnect: + iface._readBytes(10) + mock_reconnect.assert_called_once()