diff --git a/server/api_server/api_server_start.py b/server/api_server/api_server_start.py index 980dcbd0..f250ca5e 100755 --- a/server/api_server/api_server_start.py +++ b/server/api_server/api_server_start.py @@ -71,6 +71,7 @@ from messaging.in_app import ( # noqa: E402 [flake8 lint suppression] delete_notification, mark_notification_as_read ) +from .tools_routes import tools_bp # noqa: E402 [flake8 lint suppression] # Flask application app = Flask(__name__) @@ -87,7 +88,8 @@ CORS( r"/dbquery/*": {"origins": "*"}, r"/messaging/*": {"origins": "*"}, r"/events/*": {"origins": "*"}, - r"/logs/*": {"origins": "*"} + r"/logs/*": {"origins": "*"}, + r"/api/tools/*": {"origins": "*"} }, supports_credentials=True, allow_headers=["Authorization", "Content-Type"], @@ -97,6 +99,17 @@ CORS( # ------------------------------------------------------------------- # Custom handler for 404 - Route not found # ------------------------------------------------------------------- +@app.before_request +def log_request_info(): + """Log details of every incoming request.""" + # Filter out noisy requests if needed, but user asked for drastic logging + mylog("none", [f"[HTTP] {request.method} {request.path} from {request.remote_addr}"]) + mylog("none", [f"[HTTP] Headers: {dict(request.headers)}"]) + if request.method == "POST": + # Be careful with large bodies, but log first 1000 chars + data = request.get_data(as_text=True) + mylog("none", [f"[HTTP] Body: {data[:1000]}"]) + @app.errorhandler(404) def not_found(error): response = { @@ -775,3 +788,10 @@ def start_server(graphql_port, app_state): # Update the state to indicate the server has started app_state = updateState("Process: Idle", None, None, None, 1) +# Register Blueprints +app.register_blueprint(tools_bp, url_prefix='/api/tools') + +if __name__ == "__main__": + # This block is for running the server directly for testing purposes + # In production, start_server is called from api.py + pass diff --git a/server/api_server/tools_routes.py b/server/api_server/tools_routes.py new file mode 100644 index 00000000..bca3b543 --- /dev/null +++ b/server/api_server/tools_routes.py @@ -0,0 +1,687 @@ +import subprocess +import shutil +import os +import re +from datetime import datetime, timedelta +from flask import Blueprint, request, jsonify +import sqlite3 +from helper import get_setting_value +from database import get_temp_db_connection + +tools_bp = Blueprint('tools', __name__) + + +def check_auth(): + """Check API_TOKEN authorization.""" + token = request.headers.get("Authorization") + expected_token = f"Bearer {get_setting_value('API_TOKEN')}" + return token == expected_token + + +@tools_bp.route('/trigger_scan', methods=['POST']) +def trigger_scan(): + """ + Forces NetAlertX to run a specific scan type immediately. + Arguments: scan_type (Enum: arp, nmap_fast, nmap_deep), target (optional IP/CIDR) + """ + if not check_auth(): + return jsonify({"error": "Unauthorized"}), 401 + + data = request.get_json() + scan_type = data.get('scan_type', 'nmap_fast') + target = data.get('target') + + # Validate scan_type + if scan_type not in ['arp', 'nmap_fast', 'nmap_deep']: + return jsonify({"error": "Invalid scan_type. Must be 'arp', 'nmap_fast', or 'nmap_deep'"}), 400 + + # Determine command + cmd = [] + if scan_type == 'arp': + # ARP scan usually requires sudo or root, assuming container runs as root or has caps + cmd = ["arp-scan", "--localnet", "--interface=eth0"] # Defaulting to eth0, might need detection + if target: + cmd = ["arp-scan", target] + elif scan_type == 'nmap_fast': + cmd = ["nmap", "-F"] + if target: + cmd.append(target) + else: + # Default to local subnet if possible, or error if not easily determined + # For now, let's require target for nmap if not easily deducible, + # or try to get it from settings. + # NetAlertX usually knows its subnet. + # Let's try to get the scan subnet from settings if not provided. + scan_subnets = get_setting_value("SCAN_SUBNETS") + if scan_subnets: + # Take the first one for now + cmd.append(scan_subnets.split(',')[0].strip()) + else: + return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 + elif scan_type == 'nmap_deep': + cmd = ["nmap", "-A", "-T4"] + if target: + cmd.append(target) + else: + scan_subnets = get_setting_value("SCAN_SUBNETS") + if scan_subnets: + cmd.append(scan_subnets.split(',')[0].strip()) + else: + return jsonify({"error": "Target is required and no default SCAN_SUBNETS found"}), 400 + + try: + # Run the command + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + return jsonify({ + "success": True, + "scan_type": scan_type, + "command": " ".join(cmd), + "output": result.stdout.strip().split('\n') + }) + except subprocess.CalledProcessError as e: + return jsonify({ + "success": False, + "error": "Scan failed", + "details": e.stderr.strip() + }), 500 + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@tools_bp.route('/list_devices', methods=['POST']) +def list_devices(): + """List all devices.""" + if not check_auth(): + return jsonify({"error": "Unauthorized"}), 401 + + conn = get_temp_db_connection() + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + try: + cur.execute("SELECT devName, devMac, devLastIP as devIP, devVendor, devFirstConnection, devLastConnection FROM Devices ORDER BY devFirstConnection DESC") + rows = cur.fetchall() + devices = [dict(row) for row in rows] + return jsonify(devices) + except Exception as e: + return jsonify({"error": str(e)}), 500 + finally: + conn.close() + + +@tools_bp.route('/get_device_info', methods=['POST']) +def get_device_info(): + """Get detailed info for a specific device.""" + if not check_auth(): + return jsonify({"error": "Unauthorized"}), 401 + + data = request.get_json() + if not data or 'query' not in data: + return jsonify({"error": "Missing 'query' parameter"}), 400 + + query = data['query'] + + conn = get_temp_db_connection() + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + try: + # Search by MAC, Name, or partial IP + sql = "SELECT * FROM Devices WHERE devMac LIKE ? OR devName LIKE ? OR devLastIP LIKE ?" + cur.execute(sql, (f"%{query}%", f"%{query}%", f"%{query}%")) + rows = cur.fetchall() + + if not rows: + return jsonify({"message": "No devices found"}), 404 + + devices = [dict(row) for row in rows] + return jsonify(devices) + except Exception as e: + return jsonify({"error": str(e)}), 500 + finally: + conn.close() + + +@tools_bp.route('/get_latest_device', methods=['POST']) +def get_latest_device(): + """Get full details of the most recently discovered device.""" + if not check_auth(): + return jsonify({"error": "Unauthorized"}), 401 + + conn = get_temp_db_connection() + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + try: + # Get the device with the most recent devFirstConnection + cur.execute("SELECT * FROM Devices ORDER BY devFirstConnection DESC LIMIT 1") + row = cur.fetchone() + + if not row: + return jsonify({"message": "No devices found"}), 404 + + # Return as a list to be consistent with other endpoints + return jsonify([dict(row)]) + except Exception as e: + return jsonify({"error": str(e)}), 500 + finally: + conn.close() + + +@tools_bp.route('/get_open_ports', methods=['POST']) +def get_open_ports(): + """ + Specific query for the port-scan results of a target. + Arguments: target (IP or MAC) + """ + if not check_auth(): + return jsonify({"error": "Unauthorized"}), 401 + + data = request.get_json() + target = data.get('target') + + if not target: + return jsonify({"error": "Target is required"}), 400 + + # If MAC is provided, try to resolve to IP + if re.match(r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$", target): + conn = get_temp_db_connection() + conn.row_factory = sqlite3.Row + cur = conn.cursor() + try: + cur.execute("SELECT devLastIP FROM Devices WHERE devMac = ?", (target,)) + row = cur.fetchone() + if row and row['devLastIP']: + target = row['devLastIP'] + else: + return jsonify({"error": f"Could not resolve IP for MAC {target}"}), 404 + finally: + conn.close() + + try: + # Run nmap -F for fast port scan + cmd = ["nmap", "-F", target] + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True + ) + + # Parse output for open ports + open_ports = [] + for line in result.stdout.split('\n'): + if '/tcp' in line and 'open' in line: + parts = line.split('/') + port = parts[0].strip() + service = line.split()[2] if len(line.split()) > 2 else "unknown" + open_ports.append({"port": int(port), "service": service}) + + return jsonify({ + "success": True, + "target": target, + "open_ports": open_ports, + "raw_output": result.stdout.strip().split('\n') + }) + + except subprocess.CalledProcessError as e: + return jsonify({"success": False, "error": "Port scan failed", "details": e.stderr.strip()}), 500 + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@tools_bp.route('/get_network_topology', methods=['GET']) +def get_network_topology(): + """ + Returns the "Parent/Child" relationships. + """ + if not check_auth(): + return jsonify({"error": "Unauthorized"}), 401 + + conn = get_temp_db_connection() + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + try: + cur.execute("SELECT devName, devMac, devParentMAC, devParentPort, devVendor FROM Devices") + rows = cur.fetchall() + + nodes = [] + links = [] + + for row in rows: + nodes.append({ + "id": row['devMac'], + "name": row['devName'], + "vendor": row['devVendor'] + }) + if row['devParentMAC']: + links.append({ + "source": row['devParentMAC'], + "target": row['devMac'], + "port": row['devParentPort'] + }) + + return jsonify({ + "nodes": nodes, + "links": links + }) + except Exception as e: + return jsonify({"error": str(e)}), 500 + finally: + conn.close() + + +@tools_bp.route('/get_recent_alerts', methods=['POST']) +def get_recent_alerts(): + """ + Fetches the last N system alerts. + Arguments: hours (lookback period, default 24) + """ + if not check_auth(): + return jsonify({"error": "Unauthorized"}), 401 + + data = request.get_json() + hours = data.get('hours', 24) + + conn = get_temp_db_connection() + conn.row_factory = sqlite3.Row + cur = conn.cursor() + + try: + # Calculate cutoff time + cutoff = datetime.now() - timedelta(hours=int(hours)) + cutoff_str = cutoff.strftime('%Y-%m-%d %H:%M:%S') + + cur.execute(""" + SELECT eve_DateTime, eve_EventType, eve_MAC, eve_IP, devName + FROM Events + LEFT JOIN Devices ON Events.eve_MAC = Devices.devMac + WHERE eve_DateTime > ? + ORDER BY eve_DateTime DESC + """, (cutoff_str,)) + + rows = cur.fetchall() + alerts = [dict(row) for row in rows] + + return jsonify(alerts) + except Exception as e: + return jsonify({"error": str(e)}), 500 + finally: + conn.close() + + +@tools_bp.route('/set_device_alias', methods=['POST']) +def set_device_alias(): + """ + Updates the name (alias) of a device. + Arguments: mac, alias + """ + if not check_auth(): + return jsonify({"error": "Unauthorized"}), 401 + + data = request.get_json() + mac = data.get('mac') + alias = data.get('alias') + + if not mac or not alias: + return jsonify({"error": "MAC and Alias are required"}), 400 + + conn = get_temp_db_connection() + cur = conn.cursor() + + try: + cur.execute("UPDATE Devices SET devName = ? WHERE devMac = ?", (alias, mac)) + conn.commit() + + if cur.rowcount == 0: + return jsonify({"error": "Device not found"}), 404 + + return jsonify({"success": True, "message": f"Device {mac} renamed to {alias}"}) + except Exception as e: + return jsonify({"error": str(e)}), 500 + finally: + conn.close() + + +@tools_bp.route('/wol_wake_device', methods=['POST']) +def wol_wake_device(): + """ + Sends a Wake-on-LAN magic packet. + Arguments: mac OR ip + """ + if not check_auth(): + return jsonify({"error": "Unauthorized"}), 401 + + data = request.get_json() + mac = data.get('mac') + ip = data.get('ip') + + if not mac and not ip: + return jsonify({"error": "MAC address or IP address is required"}), 400 + + # Resolve IP to MAC if MAC is missing + if not mac and ip: + conn = get_temp_db_connection() + conn.row_factory = sqlite3.Row + cur = conn.cursor() + try: + # Try to find device by IP (devLastIP) + cur.execute("SELECT devMac FROM Devices WHERE devLastIP = ?", (ip,)) + row = cur.fetchone() + if row and row['devMac']: + mac = row['devMac'] + else: + return jsonify({"error": f"Could not resolve MAC for IP {ip}"}), 404 + except Exception as e: + return jsonify({"error": f"Database error: {str(e)}"}), 500 + finally: + conn.close() + + # Validate MAC + if not re.match(r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$", mac): + return jsonify({"success": False, "error": f"Invalid MAC: {mac}"}), 400 + + try: + # Using wakeonlan command + result = subprocess.run( + ["wakeonlan", mac], capture_output=True, text=True, check=True + ) + return jsonify( + { + "success": True, + "message": f"WOL packet sent to {mac}", + "output": result.stdout.strip(), + } + ) + except subprocess.CalledProcessError as e: + return jsonify( + { + "success": False, + "error": "Failed to send WOL packet", + "details": e.stderr.strip(), + } + ), 500 + + +@tools_bp.route('/openapi.json', methods=['GET']) +def openapi_spec(): + """Return OpenAPI specification for tools.""" + # No auth required for spec to allow easy import, or require it if preferred. + # Open WebUI usually needs to fetch spec without auth first or handles it. + # We'll allow public access to spec for simplicity of import. + + spec = { + "openapi": "3.0.0", + "info": { + "title": "NetAlertX Tools", + "description": "API for NetAlertX device management tools", + "version": "1.1.0" + }, + "servers": [ + {"url": "/api/tools"} + ], + "paths": { + "/list_devices": { + "post": { + "summary": "List all devices (Summary)", + "description": ( + "Retrieve a SUMMARY list of all devices, sorted by newest first. " + "IMPORTANT: This only provides basic info (Name, IP, Vendor). " + "For FULL details (like custom props, alerts, etc.), you MUST use 'get_device_info' or 'get_latest_device'." + ), + "operationId": "list_devices", + "responses": { + "200": { + "description": "List of devices (Summary)", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "type": "object", + "properties": { + "devName": {"type": "string"}, + "devMac": {"type": "string"}, + "devIP": {"type": "string"}, + "devVendor": {"type": "string"}, + "devStatus": {"type": "string"}, + "devFirstConnection": {"type": "string"}, + "devLastConnection": {"type": "string"} + } + } + } + } + } + } + } + } + }, + "/get_device_info": { + "post": { + "summary": "Get device info (Full Details)", + "description": ( + "Get COMPREHENSIVE information about a specific device by MAC, Name, or partial IP. " + "Use this to see all available properties, alerts, and metadata not shown in the list." + ), + "operationId": "get_device_info", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "MAC address, Device Name, or partial IP to search for" + } + }, + "required": ["query"] + } + } + } + }, + "responses": { + "200": { + "description": "Device details (Full)", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": {"type": "object"} + } + } + } + }, + "404": {"description": "Device not found"} + } + } + }, + "/get_latest_device": { + "post": { + "summary": "Get latest device (Full Details)", + "description": "Get COMPREHENSIVE information about the most recently discovered device (latest devFirstConnection).", + "operationId": "get_latest_device", + "responses": { + "200": { + "description": "Latest device details (Full)", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": {"type": "object"} + } + } + } + }, + "404": {"description": "No devices found"} + } + } + }, + "/trigger_scan": { + "post": { + "summary": "Trigger Active Scan", + "description": "Forces NetAlertX to run a specific scan type immediately.", + "operationId": "trigger_scan", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "scan_type": { + "type": "string", + "enum": ["arp", "nmap_fast", "nmap_deep"], + "default": "nmap_fast" + }, + "target": { + "type": "string", + "description": "IP address or CIDR to scan" + } + } + } + } + } + }, + "responses": { + "200": {"description": "Scan started/completed successfully"}, + "400": {"description": "Invalid input"} + } + } + }, + "/get_open_ports": { + "post": { + "summary": "Get Open Ports", + "description": "Specific query for the port-scan results of a target.", + "operationId": "get_open_ports", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "target": { + "type": "string", + "description": "IP or MAC address" + } + }, + "required": ["target"] + } + } + } + }, + "responses": { + "200": {"description": "List of open ports"}, + "404": {"description": "Target not found"} + } + } + }, + "/get_network_topology": { + "get": { + "summary": "Get Network Topology", + "description": "Returns the Parent/Child relationships for network visualization.", + "operationId": "get_network_topology", + "responses": { + "200": {"description": "Graph data (nodes and links)"} + } + } + }, + "/get_recent_alerts": { + "post": { + "summary": "Get Recent Alerts", + "description": "Fetches the last N system alerts.", + "operationId": "get_recent_alerts", + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "hours": { + "type": "integer", + "default": 24 + } + } + } + } + } + }, + "responses": { + "200": {"description": "List of alerts"} + } + } + }, + "/set_device_alias": { + "post": { + "summary": "Set Device Alias", + "description": "Updates the name (alias) of a device.", + "operationId": "set_device_alias", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "mac": {"type": "string"}, + "alias": {"type": "string"} + }, + "required": ["mac", "alias"] + } + } + } + }, + "responses": { + "200": {"description": "Alias updated"}, + "404": {"description": "Device not found"} + } + } + }, + "/wol_wake_device": { + "post": { + "summary": "Wake on LAN", + "description": "Sends a Wake-on-LAN magic packet to the target MAC or IP. If IP is provided, it resolves to MAC first.", + "operationId": "wol_wake_device", + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "mac": {"type": "string", "description": "Target MAC address"}, + "ip": {"type": "string", "description": "Target IP address (resolves to MAC)"} + } + } + } + } + }, + "responses": { + "200": {"description": "WOL packet sent"}, + "404": {"description": "IP not found"} + } + } + } + }, + "components": { + "securitySchemes": { + "bearerAuth": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT" + } + } + }, + "security": [ + {"bearerAuth": []} + ] + } + return jsonify(spec) diff --git a/test/api_endpoints/test_mcp_tools_endpoints.py b/test/api_endpoints/test_mcp_tools_endpoints.py new file mode 100644 index 00000000..22bd136d --- /dev/null +++ b/test/api_endpoints/test_mcp_tools_endpoints.py @@ -0,0 +1,279 @@ +import sys +import os +import pytest +from unittest.mock import patch, MagicMock +import subprocess + +INSTALL_PATH = os.getenv('NETALERTX_APP', '/app') +sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) + +from helper import get_setting_value # noqa: E402 +from api_server.api_server_start import app # noqa: E402 + +@pytest.fixture(scope="session") +def api_token(): + return get_setting_value("API_TOKEN") + +@pytest.fixture +def client(): + with app.test_client() as client: + yield client + +def auth_headers(token): + return {"Authorization": f"Bearer {token}"} + +# --- get_device_info Tests --- + +@patch('api_server.tools_routes.get_temp_db_connection') +def test_get_device_info_ip_partial(mock_db_conn, client, api_token): + """Test get_device_info with partial IP search.""" + mock_cursor = MagicMock() + # Mock return of a device with IP ending in .50 + mock_cursor.fetchall.return_value = [ + {"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"} + ] + mock_db_conn.return_value.cursor.return_value = mock_cursor + + payload = {"query": ".50"} + response = client.post('/api/tools/get_device_info', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 200 + devices = response.get_json() + assert len(devices) == 1 + assert devices[0]["devLastIP"] == "192.168.1.50" + + # Verify SQL query included 3 params (MAC, Name, IP) + args, _ = mock_cursor.execute.call_args + assert args[0].count("?") == 3 + assert len(args[1]) == 3 + +# --- trigger_scan Tests --- + +@patch('subprocess.run') +def test_trigger_scan_nmap_fast(mock_run, client, api_token): + """Test trigger_scan with nmap_fast.""" + mock_run.return_value = MagicMock(stdout="Scan completed", returncode=0) + + payload = {"scan_type": "nmap_fast", "target": "192.168.1.1"} + response = client.post('/api/tools/trigger_scan', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + assert "nmap -F 192.168.1.1" in data["command"] + mock_run.assert_called_once() + +@patch('subprocess.run') +def test_trigger_scan_invalid_type(mock_run, client, api_token): + """Test trigger_scan with invalid scan_type.""" + payload = {"scan_type": "invalid_type", "target": "192.168.1.1"} + response = client.post('/api/tools/trigger_scan', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 400 + mock_run.assert_not_called() + +# --- get_open_ports Tests --- + +@patch('subprocess.run') +def test_get_open_ports_ip(mock_run, client, api_token): + """Test get_open_ports with an IP address.""" + mock_output = """ +Starting Nmap 7.80 ( https://nmap.org ) at 2023-10-27 10:00 UTC +Nmap scan report for 192.168.1.1 +Host is up (0.0010s latency). +Not shown: 98 closed ports +PORT STATE SERVICE +22/tcp open ssh +80/tcp open http +Nmap done: 1 IP address (1 host up) scanned in 0.10 seconds +""" + mock_run.return_value = MagicMock(stdout=mock_output, returncode=0) + + payload = {"target": "192.168.1.1"} + response = client.post('/api/tools/get_open_ports', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + assert len(data["open_ports"]) == 2 + assert data["open_ports"][0]["port"] == 22 + assert data["open_ports"][1]["service"] == "http" + +@patch('api_server.tools_routes.get_temp_db_connection') +@patch('subprocess.run') +def test_get_open_ports_mac_resolve(mock_run, mock_db_conn, client, api_token): + """Test get_open_ports with a MAC address that resolves to an IP.""" + # Mock DB to resolve MAC to IP + mock_cursor = MagicMock() + mock_cursor.fetchone.return_value = {"devLastIP": "192.168.1.50"} + mock_db_conn.return_value.cursor.return_value = mock_cursor + + # Mock Nmap output + mock_run.return_value = MagicMock(stdout="80/tcp open http", returncode=0) + + payload = {"target": "AA:BB:CC:DD:EE:FF"} + response = client.post('/api/tools/get_open_ports', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 200 + data = response.get_json() + assert data["target"] == "192.168.1.50" # Should be resolved IP + mock_run.assert_called_once() + args, _ = mock_run.call_args + assert "192.168.1.50" in args[0] + +# --- get_network_topology Tests --- + +@patch('api_server.tools_routes.get_temp_db_connection') +def test_get_network_topology(mock_db_conn, client, api_token): + """Test get_network_topology.""" + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [ + {"devName": "Router", "devMac": "AA:AA:AA:AA:AA:AA", "devParentMAC": None, "devParentPort": None, "devVendor": "VendorA"}, + {"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"} + ] + mock_db_conn.return_value.cursor.return_value = mock_cursor + + response = client.get('/api/tools/get_network_topology', + headers=auth_headers(api_token)) + + assert response.status_code == 200 + data = response.get_json() + assert len(data["nodes"]) == 2 + assert len(data["links"]) == 1 + assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA" + assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB" + +# --- get_recent_alerts Tests --- + +@patch('api_server.tools_routes.get_temp_db_connection') +def test_get_recent_alerts(mock_db_conn, client, api_token): + """Test get_recent_alerts.""" + mock_cursor = MagicMock() + mock_cursor.fetchall.return_value = [ + {"eve_DateTime": "2023-10-27 10:00:00", "eve_EventType": "New Device", "eve_MAC": "CC:CC:CC:CC:CC:CC", "eve_IP": "192.168.1.100", "devName": "Unknown"} + ] + mock_db_conn.return_value.cursor.return_value = mock_cursor + + payload = {"hours": 24} + response = client.post('/api/tools/get_recent_alerts', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 200 + data = response.get_json() + assert len(data) == 1 + assert data[0]["eve_EventType"] == "New Device" + +# --- set_device_alias Tests --- + +@patch('api_server.tools_routes.get_temp_db_connection') +def test_set_device_alias(mock_db_conn, client, api_token): + """Test set_device_alias.""" + mock_cursor = MagicMock() + mock_cursor.rowcount = 1 # Simulate successful update + mock_db_conn.return_value.cursor.return_value = mock_cursor + + payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"} + response = client.post('/api/tools/set_device_alias', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + +@patch('api_server.tools_routes.get_temp_db_connection') +def test_set_device_alias_not_found(mock_db_conn, client, api_token): + """Test set_device_alias when device is not found.""" + mock_cursor = MagicMock() + mock_cursor.rowcount = 0 # Simulate no rows updated + mock_db_conn.return_value.cursor.return_value = mock_cursor + + payload = {"mac": "AA:BB:CC:DD:EE:FF", "alias": "New Name"} + response = client.post('/api/tools/set_device_alias', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 404 + +# --- wol_wake_device Tests --- + +@patch('subprocess.run') +def test_wol_wake_device(mock_subprocess, client, api_token): + """Test wol_wake_device.""" + mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF" + mock_subprocess.return_value.returncode = 0 + + payload = {"mac": "AA:BB:CC:DD:EE:FF"} + response = client.post('/api/tools/wol_wake_device', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True) + +@patch('api_server.tools_routes.get_temp_db_connection') +@patch('subprocess.run') +def test_wol_wake_device_by_ip(mock_subprocess, mock_db_conn, client, api_token): + """Test wol_wake_device with IP address.""" + # Mock DB for IP resolution + mock_cursor = MagicMock() + mock_cursor.fetchone.return_value = {"devMac": "AA:BB:CC:DD:EE:FF"} + mock_db_conn.return_value.cursor.return_value = mock_cursor + + # Mock subprocess + mock_subprocess.return_value.stdout = "Sending magic packet to 255.255.255.255:9 with AA:BB:CC:DD:EE:FF" + mock_subprocess.return_value.returncode = 0 + + payload = {"ip": "192.168.1.50"} + response = client.post('/api/tools/wol_wake_device', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 200 + data = response.get_json() + assert data["success"] is True + assert "AA:BB:CC:DD:EE:FF" in data["message"] + + # Verify DB lookup + mock_cursor.execute.assert_called_with("SELECT devMac FROM Devices WHERE devLastIP = ?", ("192.168.1.50",)) + + # Verify subprocess call + mock_subprocess.assert_called_with(["wakeonlan", "AA:BB:CC:DD:EE:FF"], capture_output=True, text=True, check=True) + +def test_wol_wake_device_invalid_mac(client, api_token): + """Test wol_wake_device with invalid MAC.""" + payload = {"mac": "invalid-mac"} + response = client.post('/api/tools/wol_wake_device', + json=payload, + headers=auth_headers(api_token)) + + assert response.status_code == 400 + +# --- openapi_spec Tests --- + +def test_openapi_spec(client): + """Test openapi_spec endpoint contains new paths.""" + response = client.get('/api/tools/openapi.json') + assert response.status_code == 200 + spec = response.get_json() + + # Check for new endpoints + assert "/trigger_scan" in spec["paths"] + assert "/get_open_ports" in spec["paths"] + assert "/get_network_topology" in spec["paths"] + assert "/get_recent_alerts" in spec["paths"] + assert "/set_device_alias" in spec["paths"] + assert "/wol_wake_device" in spec["paths"] diff --git a/test/api_endpoints/test_tools_endpoints.py b/test/api_endpoints/test_tools_endpoints.py new file mode 100644 index 00000000..297f11b6 --- /dev/null +++ b/test/api_endpoints/test_tools_endpoints.py @@ -0,0 +1,79 @@ +import sys +import os +import pytest + +INSTALL_PATH = os.getenv('NETALERTX_APP', '/app') +sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) + +from helper import get_setting_value # noqa: E402 [flake8 lint suppression] +from api_server.api_server_start import app # noqa: E402 [flake8 lint suppression] + + +@pytest.fixture(scope="session") +def api_token(): + return get_setting_value("API_TOKEN") + + +@pytest.fixture +def client(): + with app.test_client() as client: + yield client + + +def auth_headers(token): + return {"Authorization": f"Bearer {token}"} + + +def test_openapi_spec(client): + """Test OpenAPI spec endpoint.""" + response = client.get('/api/tools/openapi.json') + assert response.status_code == 200 + spec = response.get_json() + assert "openapi" in spec + assert "info" in spec + assert "paths" in spec + assert "/list_devices" in spec["paths"] + assert "/get_device_info" in spec["paths"] + + +def test_list_devices(client, api_token): + """Test list_devices endpoint.""" + response = client.post('/api/tools/list_devices', headers=auth_headers(api_token)) + assert response.status_code == 200 + devices = response.get_json() + assert isinstance(devices, list) + # If there are devices, check structure + if devices: + device = devices[0] + assert "devName" in device + assert "devMac" in device + + +def test_get_device_info(client, api_token): + """Test get_device_info endpoint.""" + # Test with a query that might not exist + payload = {"query": "nonexistent_device"} + response = client.post('/api/tools/get_device_info', + json=payload, + headers=auth_headers(api_token)) + # Should return 404 if no match, or 200 with results + assert response.status_code in [200, 404] + if response.status_code == 200: + devices = response.get_json() + assert isinstance(devices, list) + elif response.status_code == 404: + # Expected for no matches + pass + + +def test_list_devices_unauthorized(client): + """Test list_devices without authorization.""" + response = client.post('/api/tools/list_devices') + assert response.status_code == 401 + + +def test_get_device_info_unauthorized(client): + """Test get_device_info without authorization.""" + payload = {"query": "test"} + response = client.post('/api/tools/get_device_info', json=payload) + assert response.status_code == 401