Give TCPInterface reconnect logic on write errors

* Moving to socket.sendall() is safer, as sendall will send the entire
   buffer, while send() would return the number of bytes sent and
   require being called multiple times if the buffer was full.
 * On exceptions: reconnect to the server.
 * On reconnection: make sure using a lock that there isn't a race
   between the readers and the writers triggering a reconnect.
This commit is contained in:
Stephen Thorne
2026-02-05 21:53:18 +01:00
parent cdf893e618
commit 07172f88f3
3 changed files with 87 additions and 23 deletions

View File

@@ -4,6 +4,7 @@
import contextlib
import logging
import socket
import threading
import time
from typing import Optional
@@ -12,6 +13,7 @@ from meshtastic.stream_interface import StreamInterface
DEFAULT_TCP_PORT = 4403
logger = logging.getLogger(__name__)
class TCPInterface(StreamInterface):
"""Interface class for meshtastic devices over a TCP link"""
@@ -19,10 +21,10 @@ class TCPInterface(StreamInterface):
self,
hostname: str,
debugOut=None,
noProto: bool=False,
connectNow: bool=True,
portNumber: int=DEFAULT_TCP_PORT,
noNodes:bool=False,
noProto: bool = False,
connectNow: bool = True,
portNumber: int = DEFAULT_TCP_PORT,
noNodes: bool = False,
timeout: int = 300,
):
"""Constructor, opens a connection to a specified IP address/hostname
@@ -38,13 +40,20 @@ class TCPInterface(StreamInterface):
self.portNumber: int = portNumber
self.socket: Optional[socket.socket] = None
self.reconnectLock = threading.Lock()
if connectNow:
self.myConnect()
else:
self.socket = None
super().__init__(debugOut=debugOut, noProto=noProto, connectNow=connectNow, noNodes=noNodes, timeout=timeout)
super().__init__(
debugOut=debugOut,
noProto=noProto,
connectNow=connectNow,
noNodes=noNodes,
timeout=timeout,
)
def __repr__(self):
rep = f"TCPInterface({self.hostname!r}"
@@ -69,29 +78,35 @@ class TCPInterface(StreamInterface):
self.socket.shutdown(socket.SHUT_RDWR)
def myConnect(self) -> None:
"""Connect to socket"""
logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe]
"""Connect to socket."""
logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe]
server_address = (self.hostname, self.portNumber)
self.socket = socket.create_connection(server_address)
def close(self) -> None:
"""Close a connection to the device"""
"""Close a connection to the device."""
logger.debug("Closing TCP stream")
super().close()
# Sometimes the socket read might be blocked in the reader thread.
# Therefore we force the shutdown by closing the socket here
self._wantExit = True
if self.socket is not None:
with contextlib.suppress(Exception): # Ignore errors in shutdown, because we might have a race with the server
with contextlib.suppress(
Exception
): # Ignore errors in shutdown, because we might have a race with the server
self._socket_shutdown()
self.socket.close()
self.socket = None
def _writeBytes(self, b: bytes) -> None:
"""Write an array of bytes to our stream and flush"""
"""Write an array of bytes to our stream"""
if self.socket is not None:
self.socket.send(b)
try:
self.socket.sendall(b)
except OSError as e:
logger.error(f"Socket send error, reconnecting: {e}")
self._reconnect()
def _readBytes(self, length) -> Optional[bytes]:
"""Read an array of bytes from our stream"""
@@ -99,19 +114,28 @@ class TCPInterface(StreamInterface):
data = self.socket.recv(length)
# empty byte indicates a disconnected socket,
# we need to handle it to avoid an infinite loop reading from null socket
if data == b'':
logger.debug("dead socket, re-connecting")
# cleanup and reconnect socket without breaking reader thread
with contextlib.suppress(Exception):
self._socket_shutdown()
self.socket.close()
self.socket = None
time.sleep(1)
self.myConnect()
self._startConfig()
return None
if data == b"":
logger.debug("Closed socket, re-connecting")
self._reconnect()
return data
# no socket, break reader thread
self._wantExit = True
return None
def _reconnect(self) -> None:
"""Reconnect to the socket"""
# Save the socket reference before attempting to acquire the lock.
sock = self.socket
with self.reconnectLock:
# Don't reconnect: someone else already did it.
if sock is not self.socket:
return
with contextlib.suppress(Exception):
self._socket_shutdown()
self.socket.close()
self.socket = None
time.sleep(1)
self.myConnect()
self._startConfig()

View File

@@ -54,3 +54,44 @@ def test_TCPInterface_without_connecting():
with patch("socket.socket"):
iface = TCPInterface(hostname="localhost", noProto=True, connectNow=False)
assert iface.socket is None
@pytest.mark.unit
def test_TCPInterface_reconnect():
"""Test that _reconnect correctly reconnects"""
with patch("socket.socket") as mock_socket:
with patch("time.sleep"):
iface = TCPInterface(hostname="localhost", noProto=True)
old_socket = iface.socket
assert old_socket is not None
iface._reconnect()
assert old_socket.close.called
# We expect socket class to be instantiated at least twice (init + reconnect)
assert mock_socket.call_count >= 2
@pytest.mark.unit
def test_TCPInterface_writeBytes_reconnects():
"""Test that _writeBytes calls _reconnect on OSError"""
with patch("socket.socket"):
iface = TCPInterface(hostname="localhost", noProto=True)
iface.socket.sendall.side_effect = OSError("Broken pipe")
with patch.object(iface, '_reconnect') as mock_reconnect:
iface._writeBytes(b"some data")
mock_reconnect.assert_called_once()
@pytest.mark.unit
def test_TCPInterface_readBytes_reconnects():
"""Test that _readBytes calls _reconnect on empty bytes"""
with patch("socket.socket"):
iface = TCPInterface(hostname="localhost", noProto=True)
# Mock the socket instance on the interface
iface.socket.recv.return_value = b''
with patch.object(iface, '_reconnect') as mock_reconnect:
iface._readBytes(10)
mock_reconnect.assert_called_once()

Submodule protobufs deleted from 77c8329a59