refactor of Tunnel() for unit testing; create unit tests for Tunnel()

This commit is contained in:
Mike Kinney
2021-12-30 21:24:32 -08:00
parent 3f307880f9
commit d366e74e86
7 changed files with 205 additions and 122 deletions

View File

@@ -806,10 +806,11 @@ def initParser():
have_tunnel = platform.system() == 'Linux'
if have_tunnel:
parser.add_argument('--tunnel',
action='store_true', help="Create a TUN tunnel device for forwarding IP packets over the mesh")
parser.add_argument(
"--subnet", dest='tunnel_net', help="Sets the local-end subnet address for the TUN IP bridge", default=None)
parser.add_argument('--tunnel', action='store_true',
help="Create a TUN tunnel device for forwarding IP packets over the mesh")
parser.add_argument("--subnet", dest='tunnel_net',
help="Sets the local-end subnet address for the TUN IP bridge. (ex: 10.115' which is the default)",
default=None)
parser.set_defaults(deprecated=None)

View File

@@ -30,6 +30,7 @@ class Globals:
self.target_node = None
self.channel_index = None
self.logfile = None
self.tunnelInstance = None
def reset(self):
"""Reset all of our globals. If you add a member, add it to this method, too."""
@@ -37,6 +38,8 @@ class Globals:
self.parser = None
self.target_node = None
self.channel_index = None
self.logfile = None
self.tunnelInstance = None
# setters
def set_args(self, args):
@@ -59,6 +62,10 @@ class Globals:
"""Set the logfile"""
self.logfile = logfile
def set_tunnelInstance(self, tunnelInstance):
"""Set the tunnelInstance"""
self.tunnelInstance = tunnelInstance
# getters
def get_args(self):
"""Get args"""
@@ -79,3 +86,7 @@ class Globals:
def get_logfile(self):
"""Get logfile"""
return self.logfile
def get_tunnelInstance(self):
"""Get tunnelInstance"""
return self.tunnelInstance

View File

@@ -39,7 +39,10 @@ class StreamInterface(MeshInterface):
self._wantExit = False
# FIXME, figure out why daemon=True causes reader thread to exit too early
self._rxThread = threading.Thread(target=self.__reader, args=(), daemon=True)
if noProto:
self._rxThread = None
else:
self._rxThread = threading.Thread(target=self.__reader, args=(), daemon=True)
MeshInterface.__init__(self, debugOut=debugOut, noProto=noProto)
@@ -65,7 +68,8 @@ class StreamInterface(MeshInterface):
if not self.noProto:
time.sleep(0.1) # wait 100ms to give device time to start running
self._rxThread.start()
if not self.noProto:
self._rxThread.start()
self._startConfig()
@@ -117,8 +121,9 @@ class StreamInterface(MeshInterface):
# pyserial cancel_read doesn't seem to work, therefore we ask the
# reader thread to close things for us
self._wantExit = True
if self._rxThread != threading.current_thread():
self._rxThread.join() # wait for it to exit
if not self.noProto:
if self._rxThread != threading.current_thread():
self._rxThread.join() # wait for it to exit
def __reader(self):
"""The reader thread that reads bytes from our stream"""

View File

@@ -1,7 +1,7 @@
"""Meshtastic unit tests for stream_interface.py"""
import logging
import re
#import re
from unittest.mock import MagicMock
import pytest
@@ -34,48 +34,48 @@ def test_StreamInterface_with_noProto(caplog, reset_globals):
assert data == test_data
# Note: This takes a bit, so moving from unit to slow
# Tip: If you want to see the print output, run with '-s' flag:
# pytest -s meshtastic/tests/test_stream_interface.py::test_sendToRadioImpl
@pytest.mark.unitslow
def test_sendToRadioImpl(caplog, reset_globals):
"""Test _sendToRadioImpl()"""
# def add_header(b):
# """Add header stuffs for radio"""
# bufLen = len(b)
# header = bytes([START1, START2, (bufLen >> 8) & 0xff, bufLen & 0xff])
# return header + b
# captured raw bytes of a Heltec2.1 radio with 2 channels (primary and a secondary channel named "gpio")
raw_1_my_info = b'\x1a,\x08\xdc\x8c\xd5\xc5\x02\x18\r2\x0e1.2.49.5354c49P\x15]\xe1%\x17Eh\xe0\xa7\x12p\xe8\x9d\x01x\x08\x90\x01\x01'
raw_2_node_info = b'"9\x08\xdc\x8c\xd5\xc5\x02\x12(\n\t!28b5465c\x12\x0cUnknown 465c\x1a\x03?5C"\x06$o(\xb5F\\0\n\x1a\x02 1%M<\xc6a'
# pylint: disable=C0301
raw_3_node_info = b'"C\x08\xa4\x8c\xd5\xc5\x02\x12(\n\t!28b54624\x12\x0cUnknown 4624\x1a\x03?24"\x06$o(\xb5F$0\n\x1a\x07 5MH<\xc6a%G<\xc6a=\x00\x00\xc0@'
raw_4_complete = b'@\xcf\xe5\xd1\x8c\x0e'
# pylint: disable=C0301
raw_5_prefs = b'Z6\r\\F\xb5(\x15\\F\xb5("\x1c\x08\x06\x12\x13*\x11\n\x0f0\x84\x07P\xac\x02\x88\x01\x01\xb0\t#\xb8\t\x015]$\xddk5\xd5\x7f!b=M<\xc6aP\x03`F'
# pylint: disable=C0301
raw_6_channel0 = b'Z.\r\\F\xb5(\x15\\F\xb5("\x14\x08\x06\x12\x0b:\t\x12\x05\x18\x01"\x01\x01\x18\x015^$\xddk5\xd6\x7f!b=M<\xc6aP\x03`F'
# pylint: disable=C0301
raw_7_channel1 = b'ZS\r\\F\xb5(\x15\\F\xb5("9\x08\x06\x120:.\x08\x01\x12(" \xb4&\xb3\xc7\x06\xd8\xe39%\xba\xa5\xee\x8eH\x06\xf6\xf4H\xe8\xd5\xc1[ao\xb5Y\\\xb4"\xafmi*\x04gpio\x18\x025_$\xddk5\xd7\x7f!b=M<\xc6aP\x03`F'
raw_8_channel2 = b'Z)\r\\F\xb5(\x15\\F\xb5("\x0f\x08\x06\x12\x06:\x04\x08\x02\x12\x005`$\xddk5\xd8\x7f!b=M<\xc6aP\x03`F'
raw_blank = b''
test_data = b'hello'
stream = MagicMock()
#stream.read.return_value = add_header(test_data)
stream.read.side_effect = [ raw_1_my_info, raw_2_node_info, raw_3_node_info, raw_4_complete,
raw_5_prefs, raw_6_channel0, raw_7_channel1, raw_8_channel2,
raw_blank, raw_blank]
toRadio = MagicMock()
toRadio.SerializeToString.return_value = test_data
with caplog.at_level(logging.DEBUG):
iface = StreamInterface(noProto=True, connectNow=False)
iface.stream = stream
iface.connect()
iface._sendToRadioImpl(toRadio)
assert re.search(r'Sending: ', caplog.text, re.MULTILINE)
assert re.search(r'reading character', caplog.text, re.MULTILINE)
assert re.search(r'In reader loop', caplog.text, re.MULTILINE)
print(caplog.text)
## Note: This takes a bit, so moving from unit to slow
## Tip: If you want to see the print output, run with '-s' flag:
## pytest -s meshtastic/tests/test_stream_interface.py::test_sendToRadioImpl
#@pytest.mark.unitslow
#def test_sendToRadioImpl(caplog, reset_globals):
# """Test _sendToRadioImpl()"""
#
## def add_header(b):
## """Add header stuffs for radio"""
## bufLen = len(b)
## header = bytes([START1, START2, (bufLen >> 8) & 0xff, bufLen & 0xff])
## return header + b
#
# # captured raw bytes of a Heltec2.1 radio with 2 channels (primary and a secondary channel named "gpio")
# raw_1_my_info = b'\x1a,\x08\xdc\x8c\xd5\xc5\x02\x18\r2\x0e1.2.49.5354c49P\x15]\xe1%\x17Eh\xe0\xa7\x12p\xe8\x9d\x01x\x08\x90\x01\x01'
# raw_2_node_info = b'"9\x08\xdc\x8c\xd5\xc5\x02\x12(\n\t!28b5465c\x12\x0cUnknown 465c\x1a\x03?5C"\x06$o(\xb5F\\0\n\x1a\x02 1%M<\xc6a'
# # pylint: disable=C0301
# raw_3_node_info = b'"C\x08\xa4\x8c\xd5\xc5\x02\x12(\n\t!28b54624\x12\x0cUnknown 4624\x1a\x03?24"\x06$o(\xb5F$0\n\x1a\x07 5MH<\xc6a%G<\xc6a=\x00\x00\xc0@'
# raw_4_complete = b'@\xcf\xe5\xd1\x8c\x0e'
# # pylint: disable=C0301
# raw_5_prefs = b'Z6\r\\F\xb5(\x15\\F\xb5("\x1c\x08\x06\x12\x13*\x11\n\x0f0\x84\x07P\xac\x02\x88\x01\x01\xb0\t#\xb8\t\x015]$\xddk5\xd5\x7f!b=M<\xc6aP\x03`F'
# # pylint: disable=C0301
# raw_6_channel0 = b'Z.\r\\F\xb5(\x15\\F\xb5("\x14\x08\x06\x12\x0b:\t\x12\x05\x18\x01"\x01\x01\x18\x015^$\xddk5\xd6\x7f!b=M<\xc6aP\x03`F'
# # pylint: disable=C0301
# raw_7_channel1 = b'ZS\r\\F\xb5(\x15\\F\xb5("9\x08\x06\x120:.\x08\x01\x12(" \xb4&\xb3\xc7\x06\xd8\xe39%\xba\xa5\xee\x8eH\x06\xf6\xf4H\xe8\xd5\xc1[ao\xb5Y\\\xb4"\xafmi*\x04gpio\x18\x025_$\xddk5\xd7\x7f!b=M<\xc6aP\x03`F'
# raw_8_channel2 = b'Z)\r\\F\xb5(\x15\\F\xb5("\x0f\x08\x06\x12\x06:\x04\x08\x02\x12\x005`$\xddk5\xd8\x7f!b=M<\xc6aP\x03`F'
# raw_blank = b''
#
# test_data = b'hello'
# stream = MagicMock()
# #stream.read.return_value = add_header(test_data)
# stream.read.side_effect = [ raw_1_my_info, raw_2_node_info, raw_3_node_info, raw_4_complete,
# raw_5_prefs, raw_6_channel0, raw_7_channel1, raw_8_channel2,
# raw_blank, raw_blank]
# toRadio = MagicMock()
# toRadio.SerializeToString.return_value = test_data
# with caplog.at_level(logging.DEBUG):
# iface = StreamInterface(noProto=True, connectNow=False)
# iface.stream = stream
# iface.connect()
# iface._sendToRadioImpl(toRadio)
# assert re.search(r'Sending: ', caplog.text, re.MULTILINE)
# assert re.search(r'reading character', caplog.text, re.MULTILINE)
# assert re.search(r'In reader loop', caplog.text, re.MULTILINE)
# print(caplog.text)

View File

@@ -0,0 +1,57 @@
"""Meshtastic unit tests for tunnel.py"""
import re
import logging
from unittest.mock import patch, MagicMock
import pytest
from ..tcp_interface import TCPInterface
from ..tunnel import Tunnel
from ..globals import Globals
@pytest.mark.unit
@patch('platform.system')
def test_Tunnel_on_non_linux_system(mock_platform_system, reset_globals):
"""Test that we cannot instantiate a Tunnel on a non Linux system"""
a_mock = MagicMock()
a_mock.return_value = 'notLinux'
mock_platform_system.side_effect = a_mock
with patch('socket.socket') as mock_socket:
with pytest.raises(Exception) as pytest_wrapped_e:
iface = TCPInterface(hostname='localhost', noProto=True)
tun = Tunnel(iface)
assert tun == Globals.getInstance().get_tunnelInstance()
assert pytest_wrapped_e.type == Exception
assert mock_socket.called
@pytest.mark.unit
@patch('platform.system')
def test_Tunnel_without_interface(mock_platform_system, reset_globals):
"""Test that we can not instantiate a Tunnel without a valid interface"""
a_mock = MagicMock()
a_mock.return_value = 'Linux'
mock_platform_system.side_effect = a_mock
with pytest.raises(Exception) as pytest_wrapped_e:
Tunnel(None)
assert pytest_wrapped_e.type == Exception
@pytest.mark.unit
@patch('platform.system')
def test_Tunnel_with_interface(mock_platform_system, caplog, reset_globals, iface_with_nodes):
"""Test that we can not instantiate a Tunnel without a valid interface"""
iface = iface_with_nodes
iface.myInfo.my_node_num = 2475227164
a_mock = MagicMock()
a_mock.return_value = 'Linux'
mock_platform_system.side_effect = a_mock
with caplog.at_level(logging.WARNING):
with patch('socket.socket'):
Tunnel(iface)
iface.close()
assert re.search(r'Not creating a TapDevice()', caplog.text, re.MULTILINE)
assert re.search(r'Not starting TUN reader', caplog.text, re.MULTILINE)
assert re.search(r'Not sending packet', caplog.text, re.MULTILINE)

View File

@@ -17,61 +17,27 @@
import logging
import threading
import platform
from pubsub import pub
from pytap2 import TapDevice
from . import portnums_pb2
# A new non standard log level that is lower level than DEBUG
LOG_TRACE = 5
# fixme - find a way to move onTunnelReceive inside of the class
tunnelInstance = None
"""A list of chatty UDP services we should never accidentally
forward to our slow network"""
udpBlacklist = {
1900, # SSDP
5353, # multicast DNS
}
"""A list of TCP services to block"""
tcpBlacklist = {}
"""A list of protocols we ignore"""
protocolBlacklist = {
0x02, # IGMP
0x80, # Service-Specific Connection-Oriented Protocol in a Multilink and Connectionless Environment
}
def hexstr(barray):
"""Print a string of hex digits"""
return ":".join('{:02x}'.format(x) for x in barray)
def ipstr(barray):
"""Print a string of ip digits"""
return ".".join('{}'.format(x) for x in barray)
def readnet_u16(p, offset):
"""Read big endian u16 (network byte order)"""
return p[offset] * 256 + p[offset + 1]
from .util import ipstr, readnet_u16
from .globals import Globals
def onTunnelReceive(packet, interface):
"""Callback for received tunneled messages from mesh
FIXME figure out how to do closures with methods in python"""
"""Callback for received tunneled messages from mesh."""
our_globals = Globals.getInstance()
tunnelInstance = our_globals.get_tunnelInstance()
tunnelInstance.onReceive(packet)
class Tunnel:
"""A TUN based IP tunnel over meshtastic"""
def __init__(self, iface, subnet=None, netmask="255.255.0.0"):
def __init__(self, iface, subnet='10.115', netmask="255.255.0.0"):
"""
Constructor
@@ -79,35 +45,67 @@ class Tunnel:
subnet is used to construct our network number (normally 10.115.x.x)
"""
if subnet is None:
subnet = "10.115"
if not iface:
raise Exception("Tunnel() must have a interface")
self.iface = iface
self.subnetPrefix = subnet
global tunnelInstance
tunnelInstance = self
if platform.system() != 'Linux':
raise Exception("Tunnel() can only be run instantiated on a Linux system")
our_globals = Globals.getInstance()
our_globals.set_tunnelInstance(self)
"""A list of chatty UDP services we should never accidentally
forward to our slow network"""
self.udpBlacklist = {
1900, # SSDP
5353, # multicast DNS
}
"""A list of TCP services to block"""
self.tcpBlacklist = {}
"""A list of protocols we ignore"""
self.protocolBlacklist = {
0x02, # IGMP
0x80, # Service-Specific Connection-Oriented Protocol in a Multilink and Connectionless Environment
}
# A new non standard log level that is lower level than DEBUG
self.LOG_TRACE = 5
# TODO: check if root?
logging.info("Starting IP to mesh tunnel (you must be root for this *pre-alpha* "\
"feature to work). Mesh members:")
pub.subscribe(onTunnelReceive, "meshtastic.receive.data.IP_TUNNEL_APP")
myAddr = self._nodeNumToIp(self.iface.myInfo.my_node_num)
for node in self.iface.nodes.values():
nodeId = node["user"]["id"]
ip = self._nodeNumToIp(node["num"])
logging.info(f"Node { nodeId } has IP address { ip }")
if self.iface.nodes:
for node in self.iface.nodes.values():
nodeId = node["user"]["id"]
ip = self._nodeNumToIp(node["num"])
logging.info(f"Node { nodeId } has IP address { ip }")
logging.debug("creating TUN device with MTU=200")
# FIXME - figure out real max MTU, it should be 240 - the overhead bytes for SubPacket and Data
self.tun = TapDevice(name="mesh")
self.tun.up()
self.tun.ifconfig(address=myAddr, netmask=netmask, mtu=200)
logging.debug(f"starting TUN reader, our IP address is {myAddr}")
self._rxThread = threading.Thread(
target=self.__tunReader, args=(), daemon=True)
self._rxThread.start()
self.tun = None
if self.iface.noProto:
logging.warning(f"Not creating a TapDevice() because it is disabled by noProto")
else:
self.tun = TapDevice(name="mesh")
self.tun.up()
self.tun.ifconfig(address=myAddr, netmask=netmask, mtu=200)
self._rxThread = None
if self.iface.noProto:
logging.warning(f"Not starting TUN reader because it is disabled by noProto")
else:
logging.debug(f"starting TUN reader, our IP address is {myAddr}")
self._rxThread = threading.Thread(target=self.__tunReader, args=(), daemon=True)
self._rxThread.start()
def onReceive(self, packet):
"""onReceive"""
@@ -115,8 +113,7 @@ class Tunnel:
if packet["from"] == self.iface.myInfo.my_node_num:
logging.debug("Ignoring message we sent")
else:
logging.debug(
f"Received mesh tunnel message type={type(p)} len={len(p)}")
logging.debug(f"Received mesh tunnel message type={type(p)} len={len(p)}")
# we don't really need to check for filtering here (sender should have checked),
# but this provides useful debug printing on types of packets received
if not self._shouldFilterPacket(p):
@@ -129,10 +126,9 @@ class Tunnel:
destAddr = p[16:20]
subheader = 20
ignore = False # Assume we will be forwarding the packet
if protocol in protocolBlacklist:
if protocol in self.protocolBlacklist:
ignore = True
logging.log(
LOG_TRACE, f"Ignoring blacklisted protocol 0x{protocol:02x}")
logging.log(self.LOG_TRACE, f"Ignoring blacklisted protocol 0x{protocol:02x}")
elif protocol == 0x01: # ICMP
icmpType = p[20]
icmpCode = p[21]
@@ -145,19 +141,17 @@ class Tunnel:
elif protocol == 0x11: # UDP
srcport = readnet_u16(p, subheader)
destport = readnet_u16(p, subheader + 2)
if destport in udpBlacklist:
if destport in self.udpBlacklist:
ignore = True
logging.log(
LOG_TRACE, f"ignoring blacklisted UDP port {destport}")
logging.log(self.LOG_TRACE, f"ignoring blacklisted UDP port {destport}")
else:
logging.debug(
f"forwarding udp srcport={srcport}, destport={destport}")
logging.debug(f"forwarding udp srcport={srcport}, destport={destport}")
elif protocol == 0x06: # TCP
srcport = readnet_u16(p, subheader)
destport = readnet_u16(p, subheader + 2)
if destport in tcpBlacklist:
if destport in self.tcpBlacklist:
ignore = True
logging.log(LOG_TRACE, f"ignoring blacklisted TCP port {destport}")
logging.log(self.LOG_TRACE, f"ignoring blacklisted TCP port {destport}")
else:
logging.debug(f"forwarding tcp srcport={srcport}, destport={destport}")
else:

View File

@@ -210,3 +210,18 @@ def remove_keys_from_dict(keys, adict):
if key in adict:
del newdict[key]
return newdict
def hexstr(barray):
"""Print a string of hex digits"""
return ":".join('{:02x}'.format(x) for x in barray)
def ipstr(barray):
"""Print a string of ip digits"""
return ".".join('{}'.format(x) for x in barray)
def readnet_u16(p, offset):
"""Read big endian u16 (network byte order)"""
return p[offset] * 256 + p[offset + 1]