From 5e4ad10fe0961b8a386af0f3d6ece944879a1a76 Mon Sep 17 00:00:00 2001 From: Adam Outler Date: Fri, 28 Nov 2025 21:13:20 +0000 Subject: [PATCH] Tidy up --- server/api_server/api_server_start.py | 2 + server/api_server/mcp_routes.py | 63 +++++---- server/api_server/tools_routes.py | 42 +++--- .../api_endpoints/test_mcp_tools_endpoints.py | 128 ++++++++++-------- 4 files changed, 128 insertions(+), 107 deletions(-) diff --git a/server/api_server/api_server_start.py b/server/api_server/api_server_start.py index 77d3af36..39772c65 100755 --- a/server/api_server/api_server_start.py +++ b/server/api_server/api_server_start.py @@ -118,6 +118,7 @@ def log_request_info(): data = request.get_data(as_text=True) mylog("none", [f"[HTTP] Body: {data[:1000]}"]) + @app.errorhandler(404) def not_found(error): response = { @@ -797,6 +798,7 @@ def start_server(graphql_port, app_state): # Update the state to indicate the server has started app_state = updateState("Process: Idle", None, None, None, 1) + if __name__ == "__main__": # This block is for running the server directly for testing purposes # In production, start_server is called from api.py diff --git a/server/api_server/mcp_routes.py b/server/api_server/mcp_routes.py index fa7ddba0..b65539a6 100644 --- a/server/api_server/mcp_routes.py +++ b/server/api_server/mcp_routes.py @@ -1,3 +1,5 @@ +"""MCP bridge routes exposing NetAlertX tool endpoints via JSON-RPC.""" + import json import uuid import queue @@ -16,11 +18,13 @@ openapi_spec_cache = None API_BASE_URL = "http://localhost:20212/api/tools" + def get_openapi_spec(): + """Fetch and cache the tools OpenAPI specification from the local API server.""" global openapi_spec_cache if openapi_spec_cache: return openapi_spec_cache - + try: # Fetch from local server # We use localhost because this code runs on the server @@ -32,7 +36,9 @@ def get_openapi_spec(): print(f"Error fetching OpenAPI spec: {e}") return None + def map_openapi_to_mcp_tools(spec): + """Convert OpenAPI paths into MCP tool descriptors.""" tools = [] if not spec or "paths" not in spec: return tools @@ -49,14 +55,14 @@ def map_openapi_to_mcp_tools(spec): "required": [] } } - + # Extract parameters from requestBody if present if "requestBody" in details: content = details["requestBody"].get("content", {}) if "application/json" in content: schema = content["application/json"].get("schema", {}) tool["inputSchema"] = schema - + # Extract parameters from 'parameters' list (query/path params) - simplistic support if "parameters" in details: for param in details["parameters"]: @@ -73,12 +79,14 @@ def map_openapi_to_mcp_tools(spec): tools.append(tool) return tools + def process_mcp_request(data): + """Handle incoming MCP JSON-RPC requests and route them to tools.""" method = data.get("method") msg_id = data.get("id") - + response = None - + if method == "initialize": response = { "jsonrpc": "2.0", @@ -94,11 +102,11 @@ def process_mcp_request(data): } } } - + elif method == "notifications/initialized": # No response needed for notification pass - + elif method == "tools/list": spec = get_openapi_spec() tools = map_openapi_to_mcp_tools(spec) @@ -109,17 +117,17 @@ def process_mcp_request(data): "tools": tools } } - + elif method == "tools/call": params = data.get("params", {}) tool_name = params.get("name") tool_args = params.get("arguments", {}) - + # Find the endpoint for this tool spec = get_openapi_spec() target_path = None target_method = None - + if spec and "paths" in spec: for path, methods in spec["paths"].items(): for m, details in methods.items(): @@ -129,7 +137,7 @@ def process_mcp_request(data): break if target_path: break - + if target_path: try: # Make the request to the local API @@ -139,16 +147,16 @@ def process_mcp_request(data): } if "Authorization" in request.headers: headers["Authorization"] = request.headers["Authorization"] - + url = f"{API_BASE_URL}{target_path}" - + if target_method == "POST": api_res = requests.post(url, json=tool_args, headers=headers) elif target_method == "GET": api_res = requests.get(url, params=tool_args, headers=headers) else: api_res = None - + if api_res: content = [] try: @@ -157,12 +165,12 @@ def process_mcp_request(data): "type": "text", "text": json.dumps(json_content, indent=2) }) - except: + except (ValueError, json.JSONDecodeError): content.append({ "type": "text", "text": api_res.text }) - + is_error = api_res.status_code >= 400 response = { "jsonrpc": "2.0", @@ -194,27 +202,29 @@ def process_mcp_request(data): "id": msg_id, "error": {"code": -32601, "message": f"Tool {tool_name} not found"} } - + elif method == "ping": response = { "jsonrpc": "2.0", "id": msg_id, "result": {} } - + else: # Unknown method - if msg_id: # Only respond if it's a request (has id) + if msg_id: # Only respond if it's a request (has id) response = { "jsonrpc": "2.0", "id": msg_id, "error": {"code": -32601, "message": "Method not found"} } - + return response + @mcp_bp.route('/sse', methods=['GET', 'POST']) def handle_sse(): + """Expose an SSE endpoint that streams MCP responses to connected clients.""" if request.method == 'POST': # Handle verification or keep-alive pings try: @@ -228,25 +238,26 @@ def handle_sse(): return "", 202 except Exception: pass - + return jsonify({"status": "ok", "message": "MCP SSE endpoint active"}), 200 session_id = uuid.uuid4().hex q = queue.Queue() - + with sessions_lock: sessions[session_id] = q def stream(): + """Yield SSE messages for queued MCP responses until the client disconnects.""" # Send the endpoint event # The client should POST to /api/mcp/messages?session_id= yield f"event: endpoint\ndata: /api/mcp/messages?session_id={session_id}\n\n" - + try: while True: try: # Wait for messages - message = q.get(timeout=20) # Keep-alive timeout + message = q.get(timeout=20) # Keep-alive timeout yield f"event: message\ndata: {json.dumps(message)}\n\n" except queue.Empty: # Send keep-alive comment @@ -258,12 +269,14 @@ def handle_sse(): return Response(stream_with_context(stream()), mimetype='text/event-stream') + @mcp_bp.route('/messages', methods=['POST']) def handle_messages(): + """Receive MCP JSON-RPC messages and enqueue responses for an SSE session.""" session_id = request.args.get('session_id') if not session_id: return jsonify({"error": "Missing session_id"}), 400 - + with sessions_lock: if session_id not in sessions: return jsonify({"error": "Session not found"}), 404 diff --git a/server/api_server/tools_routes.py b/server/api_server/tools_routes.py index bca3b543..5e84f781 100644 --- a/server/api_server/tools_routes.py +++ b/server/api_server/tools_routes.py @@ -1,6 +1,4 @@ import subprocess -import shutil -import os import re from datetime import datetime, timedelta from flask import Blueprint, request, jsonify @@ -39,25 +37,25 @@ def trigger_scan(): cmd = [] if scan_type == 'arp': # ARP scan usually requires sudo or root, assuming container runs as root or has caps - cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection + cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection if target: - cmd = ["arp-scan", target] + cmd = ["arp-scan", target] elif scan_type == 'nmap_fast': cmd = ["nmap", "-F"] if target: cmd.append(target) else: # Default to local subnet if possible, or error if not easily determined - # For now, let's require target for nmap if not easily deducible, - # or try to get it from settings. + # For now, let's require target for nmap if not easily deducible, + # or try to get it from settings. # NetAlertX usually knows its subnet. # Let's try to get the scan subnet from settings if not provided. scan_subnets = get_setting_value("SCAN_SUBNETS") if scan_subnets: - # Take the first one for now - cmd.append(scan_subnets.split(',')[0].strip()) + # Take the first one for now + cmd.append(scan_subnets.split(',')[0].strip()) else: - return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 + return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 elif scan_type == 'nmap_deep': cmd = ["nmap", "-A", "-T4"] if target: @@ -65,9 +63,9 @@ def trigger_scan(): else: scan_subnets = get_setting_value("SCAN_SUBNETS") if scan_subnets: - cmd.append(scan_subnets.split(',')[0].strip()) + cmd.append(scan_subnets.split(',')[0].strip()) else: - return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 + return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 try: # Run the command @@ -212,7 +210,7 @@ def get_open_ports(): text=True, check=True ) - + # Parse output for open ports open_ports = [] for line in result.stdout.split('\n'): @@ -250,10 +248,10 @@ def get_network_topology(): try: cur.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices") rows = cur.fetchall() - + nodes = [] links = [] - + for row in rows: nodes.append({ "id": row['devMac'], @@ -299,16 +297,16 @@ def get_recent_alerts(): cutoff_str = cutoff.strftime('%Y-%m-%d %H:%M:%S') cur.execute(""" - SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName - FROM Events + SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName + FROM Events LEFT JOIN Devices ON Events.eve_MAC = Devices.devMac - WHERE eve_DateTime > ? + WHERE eve_DateTime > ? ORDER BY eve_DateTime DESC """, (cutoff_str,)) - + rows = cur.fetchall() alerts = [dict(row) for row in rows] - + return jsonify(alerts) except Exception as e: return jsonify({"error": str(e)}), 500 @@ -338,10 +336,10 @@ def set_device_alias(): try: cur.execute("UPDATE Devices SET devName = ? WHERE devMac = ?", (alias, mac)) conn.commit() - + if cur.rowcount == 0: return jsonify({"error": "Device not found"}), 404 - + return jsonify({"success": True, "message": f"Device {mac} renamed to {alias}"}) except Exception as e: return jsonify({"error": str(e)}), 500 @@ -379,7 +377,7 @@ def wol_wake_device(): else: return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404 except Exception as e: - return jsonify({"error": f"Database error: {str(e)}"}), 500 + return jsonify({"error": f"Database error: {str(e)}"}), 500 finally: conn.close() diff --git a/test/api_endpoints/test_mcp_tools_endpoints.py b/test/api_endpoints/test_mcp_tools_endpoints.py index 22bd136d..fd221879 100644 --- a/test/api_endpoints/test_mcp_tools_endpoints.py +++ b/test/api_endpoints/test_mcp_tools_endpoints.py @@ -2,7 +2,6 @@ import sys import os import pytest from unittest.mock import patch, MagicMock -import subprocess INSTALL_PATH = os.getenv('NETALERTX_APP', '/app') sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) @@ -10,20 +9,23 @@ sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) from helper import get_setting_value # noqa: E402 from api_server.api_server_start import app # noqa: E402 + @pytest.fixture(scope="session") def api_token(): return get_setting_value("API_TOKEN") + @pytest.fixture def client(): with app.test_client() as client: yield client + def auth_headers(token): return {"Authorization": f"Bearer {token}"} -# --- get_device_info Tests --- +# --- get_device_info Tests --- @patch('api_server.tools_routes.get_temp_db_connection') def test_get_device_info_ip_partial(mock_db_conn, client, api_token): """Test get_device_info with partial IP search.""" @@ -33,53 +35,55 @@ def test_get_device_info_ip_partial(mock_db_conn, client, api_token): {"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"} ] mock_db_conn.return_value.cursor.return_value = mock_cursor - + payload = {"query": ".50"} - response = client.post('/api/tools/get_device_info', - json=payload, + response = client.post('/api/tools/get_device_info', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 200 devices = response.get_json() assert len(devices) == 1 assert devices[0]["devLastIP"] == "192.168.1.50" - + # Verify SQL query included 3 params (MAC, Name, IP) args, _ = mock_cursor.execute.call_args assert args[0].count("?") == 3 assert len(args[1]) == 3 -# --- trigger_scan Tests --- +# --- trigger_scan Tests --- @patch('subprocess.run') def test_trigger_scan_nmap_fast(mock_run, client, api_token): """Test trigger_scan with nmap_fast.""" mock_run.return_value = MagicMock(stdout="Scan completed", returncode=0) - + payload = {"scan_type": "nmap_fast", "target": "192.168.1.1"} - response = client.post('/api/tools/trigger_scan', - json=payload, + response = client.post('/api/tools/trigger_scan', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 200 data = response.get_json() assert data["success"] is True assert "nmap -F 192.168.1.1" in data["command"] mock_run.assert_called_once() + @patch('subprocess.run') def test_trigger_scan_invalid_type(mock_run, client, api_token): """Test trigger_scan with invalid scan_type.""" payload = {"scan_type": "invalid_type", "target": "192.168.1.1"} - response = client.post('/api/tools/trigger_scan', - json=payload, + response = client.post('/api/tools/trigger_scan', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 400 mock_run.assert_not_called() # --- get_open_ports Tests --- + @patch('subprocess.run') def test_get_open_ports_ip(mock_run, client, api_token): """Test get_open_ports with an IP address.""" @@ -94,12 +98,12 @@ PORT STATE SERVICE Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds """ mock_run.return_value = MagicMock(stdout=mock_output, returncode=0) - + payload = {"target": "192.168.1.1"} - response = client.post('/api/tools/get_open_ports', - json=payload, + response = client.post('/api/tools/get_open_ports', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 200 data = response.get_json() assert data["success"] is True @@ -107,6 +111,7 @@ Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds assert data["open_ports"][0]["port"] == 22 assert data["open_ports"][1]["service"] == "http" + @patch('api_server.tools_routes.get_temp_db_connection') @patch('subprocess.run') def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token): @@ -115,24 +120,24 @@ def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token): mock_cursor = MagicMock() mock_cursor.fetchone.return_value = {"devLastIP": "192.168.1.50"} mock_db_conn.return_value.cursor.return_value = mock_cursor - + # Mock Nmap output mock_run.return_value = MagicMock(stdout="80/tcp open http", returncode=0) - + payload = {"target": "AA:BB:CC:DD:EE:FF"} - response = client.post('/api/tools/get_open_ports', - json=payload, + response = client.post('/api/tools/get_open_ports', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 200 data = response.get_json() - assert data["target"] == "192.168.1.50" # Should be resolved IP + assert data["target"] == "192.168.1.50" # Should be resolved IP mock_run.assert_called_once() args, _ = mock_run.call_args assert "192.168.1.50" in args[0] -# --- get_network_topology Tests --- +# --- get_network_topology Tests --- @patch('api_server.tools_routes.get_temp_db_connection') def test_get_network_topology(mock_db_conn, client, api_token): """Test get_network_topology.""" @@ -142,10 +147,10 @@ def test_get_network_topology(mock_db_conn, client, api_token): {"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"} ] mock_db_conn.return_value.cursor.return_value = mock_cursor - - response = client.get('/api/tools/get_network_topology', + + response = client.get('/api/tools/get_network_topology', headers=auth_headers(api_token)) - + assert response.status_code == 200 data = response.get_json() assert len(data["nodes"]) == 2 @@ -153,8 +158,8 @@ def test_get_network_topology(mock_db_conn, client, api_token): assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA" assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB" -# --- get_recent_alerts Tests --- +# --- get_recent_alerts Tests --- @patch('api_server.tools_routes.get_temp_db_connection') def test_get_recent_alerts(mock_db_conn, client, api_token): """Test get_recent_alerts.""" @@ -163,67 +168,69 @@ def test_get_recent_alerts(mock_db_conn, client, api_token): {"eve_DateTime": "2023-10-27 10:00:00", "eve_EventType": "New Device", "eve_MAC": "CC:CC:CC:CC:CC:CC", "eve_IP": "192.168.1.100", "devName": "Unknown"} ] mock_db_conn.return_value.cursor.return_value = mock_cursor - + payload = {"hours": 24} - response = client.post('/api/tools/get_recent_alerts', - json=payload, + response = client.post('/api/tools/get_recent_alerts', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 200 data = response.get_json() assert len(data) == 1 assert data[0]["eve_EventType"] == "New Device" -# --- set_device_alias Tests --- +# --- set_device_alias Tests --- @patch('api_server.tools_routes.get_temp_db_connection') def test_set_device_alias(mock_db_conn, client, api_token): """Test set_device_alias.""" mock_cursor = MagicMock() - mock_cursor.rowcount = 1 # Simulate successful update + mock_cursor.rowcount = 1 # Simulate successful update mock_db_conn.return_value.cursor.return_value = mock_cursor - + payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"} - response = client.post('/api/tools/set_device_alias', - json=payload, + response = client.post('/api/tools/set_device_alias', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 200 data = response.get_json() assert data["success"] is True + @patch('api_server.tools_routes.get_temp_db_connection') def test_set_device_alias_not_found(mock_db_conn, client, api_token): """Test set_device_alias when device is not found.""" mock_cursor = MagicMock() - mock_cursor.rowcount = 0 # Simulate no rows updated + mock_cursor.rowcount = 0 # Simulate no rows updated mock_db_conn.return_value.cursor.return_value = mock_cursor - + payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"} - response = client.post('/api/tools/set_device_alias', - json=payload, + response = client.post('/api/tools/set_device_alias', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 404 -# --- wol_wake_device Tests --- +# --- wol_wake_device Tests --- @patch('subprocess.run') def test_wol_wake_device(mock_subprocess, client, api_token): """Test wol_wake_device.""" mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF" mock_subprocess.return_value.returncode = 0 - + payload = {"mac": "AA:BB:CC:DD:EE:FF"} - response = client.post('/api/tools/wol_wake_device', - json=payload, + response = client.post('/api/tools/wol_wake_device', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 200 data = response.get_json() assert data["success"] is True mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True) + @patch('api_server.tools_routes.get_temp_db_connection') @patch('subprocess.run') def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token): @@ -238,38 +245,39 @@ def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token) mock_subprocess.return_value.returncode = 0 payload = {"ip": "192.168.1.50"} - response = client.post('/api/tools/wol_wake_device', - json=payload, + response = client.post('/api/tools/wol_wake_device', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 200 data = response.get_json() assert data["success"] is True assert "AA:BB:CC:DD:EE:FF" in data["message"] - + # Verify DB lookup mock_cursor.execute.assert_called_with("SELECT devMac FROM Devices WHERE devLastIP = ?", ("192.168.1.50",)) - + # Verify subprocess call mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True) + def test_wol_wake_device_invalid_mac(client, api_token): """Test wol_wake_device with invalid MAC.""" payload = {"mac": "invalid-mac"} - response = client.post('/api/tools/wol_wake_device', - json=payload, + response = client.post('/api/tools/wol_wake_device', + json=payload, headers=auth_headers(api_token)) - + assert response.status_code == 400 -# --- openapi_spec Tests --- +# --- openapi_spec Tests --- def test_openapi_spec(client): """Test openapi_spec endpoint contains new paths.""" response = client.get('/api/tools/openapi.json') assert response.status_code == 200 spec = response.get_json() - + # Check for new endpoints assert "/trigger_scan" in spec["paths"] assert "/get_open_ports" in spec["paths"]