mirror of
https://github.com/meshtastic/python.git
synced 2026-06-02 12:45:00 -04:00
@@ -4,6 +4,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
@@ -12,6 +13,7 @@ from meshtastic.stream_interface import StreamInterface
|
||||
DEFAULT_TCP_PORT = 4403
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TCPInterface(StreamInterface):
|
||||
"""Interface class for meshtastic devices over a TCP link"""
|
||||
|
||||
@@ -19,10 +21,10 @@ class TCPInterface(StreamInterface):
|
||||
self,
|
||||
hostname: str,
|
||||
debugOut=None,
|
||||
noProto: bool=False,
|
||||
connectNow: bool=True,
|
||||
portNumber: int=DEFAULT_TCP_PORT,
|
||||
noNodes:bool=False,
|
||||
noProto: bool = False,
|
||||
connectNow: bool = True,
|
||||
portNumber: int = DEFAULT_TCP_PORT,
|
||||
noNodes: bool = False,
|
||||
timeout: int = 300,
|
||||
):
|
||||
"""Constructor, opens a connection to a specified IP address/hostname
|
||||
@@ -35,8 +37,15 @@ class TCPInterface(StreamInterface):
|
||||
self.portNumber: int = portNumber
|
||||
|
||||
self.socket: Optional[socket.socket] = None
|
||||
self.reconnectLock = threading.Lock()
|
||||
|
||||
super().__init__(debugOut=debugOut, noProto=noProto, connectNow=connectNow, noNodes=noNodes, timeout=timeout)
|
||||
super().__init__(
|
||||
debugOut=debugOut,
|
||||
noProto=noProto,
|
||||
connectNow=connectNow,
|
||||
noNodes=noNodes,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
rep = f"TCPInterface({self.hostname!r}"
|
||||
@@ -67,18 +76,20 @@ class TCPInterface(StreamInterface):
|
||||
|
||||
def myConnect(self) -> None:
|
||||
"""Connect to socket (without attempting to start the interface's receive thread)"""
|
||||
logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe]
|
||||
logger.debug(f"Connecting to {self.hostname}") # type: ignore[str-bytes-safe]
|
||||
server_address = (self.hostname, self.portNumber)
|
||||
self.socket = socket.create_connection(server_address)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close a connection to the device"""
|
||||
"""Close a connection to the device."""
|
||||
logger.debug("Closing TCP stream")
|
||||
# Sometimes the socket read might be blocked in the reader thread.
|
||||
# Therefore force a shutdown first to unblock reader thread reads.
|
||||
self._wantExit = True
|
||||
if self.socket is not None:
|
||||
with contextlib.suppress(Exception): # Ignore errors in shutdown, because we might have a race with the server
|
||||
with contextlib.suppress(
|
||||
Exception
|
||||
): # Ignore errors in shutdown, because we might have a race with the server
|
||||
self._socket_shutdown()
|
||||
with contextlib.suppress(Exception):
|
||||
self.socket.close()
|
||||
@@ -87,9 +98,15 @@ class TCPInterface(StreamInterface):
|
||||
super().close()
|
||||
|
||||
def _writeBytes(self, b: bytes) -> None:
|
||||
"""Write an array of bytes to our stream and flush"""
|
||||
"""Write an array of bytes to our stream"""
|
||||
if self.socket is not None:
|
||||
self.socket.send(b)
|
||||
try:
|
||||
self.socket.sendall(b)
|
||||
except OSError as e:
|
||||
logger.error(f"Socket send error, reconnecting: {e}")
|
||||
if not self._wantExit:
|
||||
self._reconnect()
|
||||
raise
|
||||
|
||||
def _readBytes(self, length) -> Optional[bytes]:
|
||||
"""Read an array of bytes from our stream"""
|
||||
@@ -97,19 +114,36 @@ class TCPInterface(StreamInterface):
|
||||
data = self.socket.recv(length)
|
||||
# empty byte indicates a disconnected socket,
|
||||
# we need to handle it to avoid an infinite loop reading from null socket
|
||||
if data == b'':
|
||||
logger.debug("dead socket, re-connecting")
|
||||
# cleanup and reconnect socket without breaking reader thread
|
||||
with contextlib.suppress(Exception):
|
||||
self._socket_shutdown()
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
time.sleep(1)
|
||||
self.myConnect()
|
||||
self._startConfig()
|
||||
return None
|
||||
if data == b"":
|
||||
logger.debug("Closed socket, re-connecting")
|
||||
if not self._wantExit:
|
||||
self._reconnect()
|
||||
return data
|
||||
|
||||
# no socket, break reader thread
|
||||
self._wantExit = True
|
||||
return None
|
||||
|
||||
def _reconnect(self) -> None:
|
||||
"""Reconnect to the socket"""
|
||||
# Save the socket reference before attempting to acquire the lock.
|
||||
sock = self.socket
|
||||
start_config = False
|
||||
with self.reconnectLock:
|
||||
if self._wantExit:
|
||||
return
|
||||
# Don't reconnect: someone else already did it.
|
||||
if sock is not self.socket:
|
||||
return
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
self._socket_shutdown()
|
||||
if self.socket is not None:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
time.sleep(1)
|
||||
self.myConnect()
|
||||
start_config = True
|
||||
|
||||
if start_config and not self._wantExit and self.socket is not None:
|
||||
self._startConfig()
|
||||
|
||||
@@ -76,3 +76,44 @@ def test_TCPInterface_close_shutdowns_socket_before_super_close():
|
||||
assert call_order == ["shutdown", "super_close"]
|
||||
sock.close.assert_called_once()
|
||||
assert iface.socket is None
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_TCPInterface_reconnect():
|
||||
"""Test that _reconnect correctly reconnects"""
|
||||
with patch("socket.socket") as mock_socket:
|
||||
with patch("time.sleep"):
|
||||
iface = TCPInterface(hostname="localhost", noProto=True)
|
||||
old_socket = iface.socket
|
||||
assert old_socket is not None
|
||||
|
||||
iface._reconnect()
|
||||
|
||||
assert old_socket.close.called
|
||||
# We expect socket class to be instantiated at least twice (init + reconnect)
|
||||
assert mock_socket.call_count >= 2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_TCPInterface_writeBytes_reconnects():
|
||||
"""Test that _writeBytes reconnects and re-raises on OSError."""
|
||||
with patch("socket.socket"):
|
||||
iface = TCPInterface(hostname="localhost", noProto=True)
|
||||
iface.socket.sendall.side_effect = OSError("Broken pipe")
|
||||
|
||||
with patch.object(iface, "_reconnect") as mock_reconnect:
|
||||
with pytest.raises(OSError, match="Broken pipe"):
|
||||
iface._writeBytes(b"some data")
|
||||
mock_reconnect.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_TCPInterface_readBytes_reconnects():
|
||||
"""Test that _readBytes calls _reconnect on empty bytes"""
|
||||
iface = TCPInterface(hostname="localhost", noProto=True, connectNow=False)
|
||||
iface.socket = MagicMock()
|
||||
iface.socket.recv.return_value = b""
|
||||
|
||||
with patch.object(iface, "_reconnect") as mock_reconnect:
|
||||
iface._readBytes(10)
|
||||
mock_reconnect.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user