diff --git a/server/api_server/api_server_start.py b/server/api_server/api_server_start.py index eb330dbe..de68c90d 100755 --- a/server/api_server/api_server_start.py +++ b/server/api_server/api_server_start.py @@ -7,6 +7,7 @@ import os from flask import Flask, request, jsonify, Response from models.device_instance import DeviceInstance # noqa: E402 from flask_cors import CORS +from werkzeug.exceptions import HTTPException # Register NetAlertX directories INSTALL_PATH = os.getenv("NETALERTX_APP", "/app") @@ -59,7 +60,8 @@ from .mcp_endpoint import ( mcp_sse, mcp_messages, openapi_spec, -) # noqa: E402 [flake8 lint suppression] + get_openapi_spec, +) # validation and schemas for MCP v2 from .openapi.validation import validate_request # noqa: E402 [flake8 lint suppression] from .openapi.schemas import ( # noqa: E402 [flake8 lint suppression] @@ -100,6 +102,20 @@ from .sse_endpoint import ( # noqa: E402 [flake8 lint suppression] app = Flask(__name__) +@app.errorhandler(500) +@app.errorhandler(Exception) +def handle_500_error(e): + """Global error handler for uncaught exceptions.""" + if isinstance(e, HTTPException): + return e + mylog("none", [f"[API] Uncaught exception: {e}"]) + return jsonify({ + "success": False, + "error": "Internal Server Error", + "message": "Something went wrong on the server" + }), 500 + + # Parse CORS origins from environment or use safe defaults _cors_origins_env = os.environ.get("CORS_ORIGINS", "") _cors_origins = [ @@ -599,7 +615,7 @@ def api_device_open_ports(payload=None): @validate_request( operation_id="get_all_devices", summary="Get All Devices", - description="Retrieve a list of all devices in the system.", + description="Retrieve a list of all devices in the system. Returns all records. No pagination supported.", response_model=DeviceListWrapperResponse, tags=["devices"], auth_callable=is_authorized @@ -662,7 +678,7 @@ def api_delete_unknown_devices(payload=None): @app.route("/devices/export", methods=["GET"]) @app.route("/devices/export/", methods=["GET"]) @validate_request( - operation_id="export_devices", + operation_id="export_devices_all", summary="Export Devices", description="Export all devices in CSV or JSON format.", query_params=[{ @@ -679,7 +695,8 @@ def api_delete_unknown_devices(payload=None): }], response_model=DeviceExportResponse, tags=["devices"], - auth_callable=is_authorized + auth_callable=is_authorized, + response_content_types=["application/json", "text/csv"] ) def api_export_devices(format=None, payload=None): export_format = (format or request.args.get("format", "csv")).lower() @@ -747,7 +764,7 @@ def api_devices_totals(payload=None): @app.route('/mcp/sse/devices/by-status', methods=['GET', 'POST']) @app.route("/devices/by-status", methods=["GET", "POST"]) @validate_request( - operation_id="list_devices_by_status", + operation_id="list_devices_by_status_api", summary="List Devices by Status", description="List devices filtered by their online/offline status.", request_model=DeviceListRequest, @@ -763,7 +780,30 @@ def api_devices_totals(payload=None): "connected", "down", "favorites", "new", "archived", "all", "my", "offline" ]} - }] + }], + links={ + "GetOpenPorts": { + "operationId": "get_open_ports", + "parameters": { + "target": "$response.body#/0/devLastIP" + }, + "description": "The `target` parameter for `get_open_ports` requires an IP address. Use the `devLastIP` from the first device in the list." + }, + "WakeOnLan": { + "operationId": "wake_on_lan", + "parameters": { + "devMac": "$response.body#/0/devMac" + }, + "description": "The `devMac` parameter for `wake_on_lan` requires a MAC address. Use the `devMac` from the first device in the list." + }, + "UpdateDevice": { + "operationId": "update_device", + "parameters": { + "mac": "$response.body#/0/devMac" + }, + "description": "The `mac` parameter for `update_device` is a path parameter. Use the `devMac` from the first device in the list." + } + } ) def api_devices_by_status(payload: DeviceListRequest = None): status = payload.status if payload else request.args.get("status") @@ -774,13 +814,43 @@ def api_devices_by_status(payload: DeviceListRequest = None): @app.route('/mcp/sse/devices/search', methods=['POST']) @app.route('/devices/search', methods=['POST']) @validate_request( - operation_id="search_devices", + operation_id="search_devices_api", summary="Search Devices", description="Search for devices based on various criteria like name, IP, MAC, or vendor.", request_model=DeviceSearchRequest, response_model=DeviceSearchResponse, tags=["devices"], - auth_callable=is_authorized + auth_callable=is_authorized, + links={ + "GetOpenPorts": { + "operationId": "get_open_ports", + "parameters": { + "target": "$response.body#/devices/0/devLastIP" + }, + "description": "The `target` parameter for `get_open_ports` requires an IP address. Use the `devLastIP` from the first device in the search results." + }, + "WakeOnLan": { + "operationId": "wake_on_lan", + "parameters": { + "devMac": "$response.body#/devices/0/devMac" + }, + "description": "The `devMac` parameter for `wake_on_lan` requires a MAC address. Use the `devMac` from the first device in the search results." + }, + "NmapScan": { + "operationId": "run_nmap_scan", + "parameters": { + "scan": "$response.body#/devices/0/devLastIP" + }, + "description": "The `scan` parameter for `run_nmap_scan` requires an IP or range. Use the `devLastIP` from the first device in the search results." + }, + "UpdateDevice": { + "operationId": "update_device", + "parameters": { + "mac": "$response.body#/devices/0/devMac" + }, + "description": "The `mac` parameter for `update_device` is a path parameter. Use the `devMac` from the first device in the search results." + } + } ) def api_devices_search(payload=None): """Device search: accepts 'query' in JSON and maps to device info/search.""" @@ -884,9 +954,13 @@ def api_devices_network_topology(payload=None): auth_callable=is_authorized ) def api_wakeonlan(payload=None): - data = request.get_json(silent=True) or {} - mac = data.get("devMac") - ip = data.get("devLastIP") or data.get('ip') + if payload: + mac = payload.mac + ip = payload.devLastIP + else: + data = request.get_json(silent=True) or {} + mac = data.get("mac") or data.get("devMac") + ip = data.get("devLastIP") or data.get('ip') if not mac and ip: @@ -1011,7 +1085,7 @@ def api_network_interfaces(payload=None): @app.route('/mcp/sse/nettools/trigger-scan', methods=['POST']) -@app.route("/nettools/trigger-scan", methods=["GET"]) +@app.route("/nettools/trigger-scan", methods=["GET", "POST"]) @validate_request( operation_id="trigger_network_scan", summary="Trigger Network Scan", @@ -1300,13 +1374,25 @@ def api_create_event(mac, payload=None): @app.route("/events/", methods=["DELETE"]) @validate_request( - operation_id="delete_events_by_mac", - summary="Delete Events by MAC", - description="Delete all events for a specific device MAC address.", + operation_id="delete_events", + summary="Delete Events", + description="Delete events by device MAC address or older than a specified number of days.", path_params=[{ "name": "mac", - "description": "Device MAC address", - "schema": {"type": "string"} + "description": "Device MAC address or number of days", + "schema": { + "oneOf": [ + { + "type": "integer", + "description": "Number of days (e.g., 30) to delete events older than this value." + }, + { + "type": "string", + "pattern": "^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$", + "description": "Device MAC address to delete all events for a specific device." + } + ] + } }], response_model=BaseResponse, tags=["events"], @@ -1315,6 +1401,7 @@ def api_create_event(mac, payload=None): def api_events_by_mac(mac, payload=None): """Delete events for a specific device MAC; string converter keeps this distinct from /events/.""" device_handler = DeviceInstance() + result = device_handler.deleteDeviceEvents(mac) return jsonify(result) @@ -1338,7 +1425,7 @@ def api_delete_all_events(payload=None): @validate_request( operation_id="get_all_events", summary="Get Events", - description="Retrieve a list of events, optionally filtered by MAC.", + description="Retrieve a list of events, optionally filtered by MAC. Returns all matching records. No pagination supported.", query_params=[{ "name": "mac", "description": "Filter by Device MAC", @@ -1372,7 +1459,8 @@ def api_get_events(payload=None): }], response_model=BaseResponse, tags=["events"], - auth_callable=is_authorized + auth_callable=is_authorized, + exclude_from_spec=True ) def api_delete_old_events(days: int, payload=None): """ @@ -1406,7 +1494,7 @@ def api_get_events_totals(payload=None): @app.route('/mcp/sse/events/recent', methods=['GET', 'POST']) -@app.route('/events/recent', methods=['GET']) +@app.route('/events/recent', methods=['GET', 'POST']) @validate_request( operation_id="get_recent_events", summary="Get Recent Events", @@ -1426,7 +1514,7 @@ def api_events_default_24h(payload=None): @app.route('/mcp/sse/events/last', methods=['GET', 'POST']) -@app.route('/events/last', methods=['GET']) +@app.route('/events/last', methods=['GET', 'POST']) @validate_request( operation_id="get_last_events", summary="Get Last Events", @@ -1763,7 +1851,7 @@ def sync_endpoint_post(payload=None): @validate_request( operation_id="check_auth", summary="Check Authentication", - description="Check if the current API token is valid.", + description="Check if the current API token is valid. Note: tokens must be generated externally via the UI or CLI.", response_model=BaseResponse, tags=["auth"], auth_callable=is_authorized @@ -1778,6 +1866,14 @@ def check_auth(payload=None): # Mount SSE endpoints after is_authorized is defined (avoid circular import) create_sse_endpoint(app, is_authorized) +# Apply environment-driven MCP disablement by regenerating the OpenAPI spec. +# This populates the registry and applies any operation IDs listed in MCP_DISABLED_TOOLS. +try: + get_openapi_spec(force_refresh=True, flask_app=app) + mylog("verbose", [f"[MCP] Applied MCP_DISABLED_TOOLS: {os.environ.get('MCP_DISABLED_TOOLS', '')}"]) +except Exception as e: + mylog("none", [f"[MCP] Error applying MCP_DISABLED_TOOLS: {e}"]) + def start_server(graphql_port, app_state): """Start the GraphQL server in a background thread.""" diff --git a/server/api_server/mcp_endpoint.py b/server/api_server/mcp_endpoint.py index 005ff1ef..1aa3ba49 100644 --- a/server/api_server/mcp_endpoint.py +++ b/server/api_server/mcp_endpoint.py @@ -309,6 +309,7 @@ def map_openapi_to_mcp_tools(spec: Dict[str, Any]) -> List[Dict[str, Any]]: This function transforms OpenAPI operations into MCP-compatible tool schemas, ensuring proper inputSchema derivation from request bodies and parameters. + It deduplicates tools by their original operationId, preferring /mcp/ routes. Args: spec: OpenAPI specification dictionary @@ -316,10 +317,10 @@ def map_openapi_to_mcp_tools(spec: Dict[str, Any]) -> List[Dict[str, Any]]: Returns: List of MCP tool definitions with name, description, and inputSchema """ - tools = [] + tools_map = {} if not spec or "paths" not in spec: - return tools + return [] for path, methods in spec["paths"].items(): for method, details in methods.items(): @@ -327,6 +328,9 @@ def map_openapi_to_mcp_tools(spec: Dict[str, Any]) -> List[Dict[str, Any]]: continue operation_id = details["operationId"] + # Deduplicate using the original operationId (before suffixing) + # or the unique operationId as fallback. + original_op_id = details.get("x-original-operationId", operation_id) # Build inputSchema from requestBody and parameters input_schema = { @@ -382,31 +386,82 @@ def map_openapi_to_mcp_tools(spec: Dict[str, Any]) -> List[Dict[str, Any]]: tool = { "name": operation_id, "description": details.get("description", details.get("summary", "")), - "inputSchema": input_schema + "inputSchema": input_schema, + "_original_op_id": original_op_id, + "_is_mcp": path.startswith("/mcp/"), + "_is_post": method.upper() == "POST" } - tools.append(tool) + # Preference logic for deduplication: + # 1. Prefer /mcp/ routes over standard ones. + # 2. Prefer POST methods over GET for the same logic (usually more robust body validation). + existing = tools_map.get(original_op_id) + if not existing: + tools_map[original_op_id] = tool + else: + # Upgrade if current is MCP and existing is not + mcp_upgrade = tool["_is_mcp"] and not existing["_is_mcp"] + # Upgrade if same route type but current is POST and existing is GET + method_upgrade = (tool["_is_mcp"] == existing["_is_mcp"]) and tool["_is_post"] and not existing["_is_post"] + + if mcp_upgrade or method_upgrade: + tools_map[original_op_id] = tool - return tools + # Final cleanup: remove internal preference flags and ensure tools have the original names + # unless we explicitly want the suffixed ones. + # The user said "Eliminate Duplicate Tool Names", so we should use original_op_id as the tool name. + final_tools = [] + _tool_name_to_operation_id: Dict[str, str] = {} + for tool in tools_map.values(): + actual_operation_id = tool["name"] # Save before overwriting + tool["name"] = tool["_original_op_id"] + _tool_name_to_operation_id[tool["name"]] = actual_operation_id + del tool["_original_op_id"] + del tool["_is_mcp"] + del tool["_is_post"] + final_tools.append(tool) + + return final_tools def find_route_for_tool(tool_name: str) -> Optional[Dict[str, Any]]: """ Find the registered route for a given tool name (operationId). + Handles exact matches and deduplicated original IDs. Args: - tool_name: The operationId to look up + tool_name: The operationId or original_operation_id to look up Returns: Route dictionary with path, method, and models, or None if not found """ registry = get_registry() + candidates = [] for entry in registry: + # Exact match (priority) - if the client passed the specific suffixed ID if entry["operation_id"] == tool_name: return entry + if entry.get("original_operation_id") == tool_name: + candidates.append(entry) - return None + if not candidates: + return None + + # Apply same preference logic as map_openapi_to_mcp_tools to ensure we pick the + # same route definition that generated the tool schema. + + # Priority 1: MCP routes (they have specialized paths/behavior) + mcp_candidates = [c for c in candidates if c["path"].startswith("/mcp/")] + pool = mcp_candidates if mcp_candidates else candidates + + # Priority 2: POST methods (usually preferred for tools) + post_candidates = [c for c in pool if c["method"].upper() == "POST"] + if post_candidates: + return post_candidates[0] + + # Fallback: return the first from the best pool available + return pool[0] # ============================================================================= diff --git a/server/api_server/openapi/introspection.py b/server/api_server/openapi/introspection.py index 2c1454de..dea99245 100644 --- a/server/api_server/openapi/introspection.py +++ b/server/api_server/openapi/introspection.py @@ -1,10 +1,12 @@ from __future__ import annotations import re -from typing import Any +from typing import Any, Dict, Optional import graphene from .registry import register_tool, _operation_ids +from .schemas import GraphQLRequest +from .schema_converter import pydantic_to_json_schema, resolve_schema_refs def introspect_graphql_schema(schema: graphene.Schema): @@ -26,6 +28,7 @@ def introspect_graphql_schema(schema: graphene.Schema): operation_id="graphql_query", summary="GraphQL Endpoint", description="Execute arbitrary GraphQL queries against the system schema.", + request_model=GraphQLRequest, tags=["graphql"] ) @@ -36,6 +39,20 @@ def _flask_to_openapi_path(flask_path: str) -> str: return re.sub(r'<(?:\w+:)?(\w+)>', r'{\1}', flask_path) +def _get_openapi_metadata(func: Any) -> Optional[Dict[str, Any]]: + """Recursively find _openapi_metadata in wrapped functions.""" + # Check current function + metadata = getattr(func, "_openapi_metadata", None) + if metadata: + return metadata + + # Check __wrapped__ (standard for @wraps) + if hasattr(func, "__wrapped__"): + return _get_openapi_metadata(func.__wrapped__) + + return None + + def introspect_flask_app(app: Any): """ Introspect the Flask application to find routes decorated with @validate_request @@ -47,14 +64,13 @@ def introspect_flask_app(app: Any): if not view_func: continue - # Check for our decorator's metadata - metadata = getattr(view_func, "_openapi_metadata", None) - if not metadata: - # Fallback for wrapped functions - if hasattr(view_func, "__wrapped__"): - metadata = getattr(view_func.__wrapped__, "_openapi_metadata", None) + # Check for our decorator's metadata recursively + metadata = _get_openapi_metadata(view_func) if metadata: + if metadata.get("exclude_from_spec"): + continue + op_id = metadata["operation_id"] # Register the tool with real path and method from Flask @@ -75,20 +91,72 @@ def introspect_flask_app(app: Any): # Determine tags - create a copy to avoid mutating shared metadata tags = list(metadata.get("tags") or ["rest"]) if path.startswith("/mcp/"): - # Move specific tags to secondary position or just add MCP - if "rest" in tags: - tags.remove("rest") - if "mcp" not in tags: - tags.append("mcp") + # For MCP endpoints, we want them exclusively in the 'mcp' tag section + tags = ["mcp"] # Ensure unique operationId original_op_id = op_id unique_op_id = op_id + + # Semantic naming strategy for duplicates + if unique_op_id in _operation_ids: + # Construct a semantic suffix to replace numeric ones + # Priority: /mcp/ prefix and HTTP method + suffix = "" + if path.startswith("/mcp/"): + suffix = "_mcp" + + if method.upper() == "POST": + suffix += "_post" + elif method.upper() == "GET": + suffix += "_get" + + if suffix: + candidate = f"{op_id}{suffix}" + if candidate not in _operation_ids: + unique_op_id = candidate + + # Fallback to numeric suffixes if semantic naming didn't ensure uniqueness count = 1 while unique_op_id in _operation_ids: unique_op_id = f"{op_id}_{count}" count += 1 + # Filter path_params to only include those that are actually in the path + path_params = metadata.get("path_params") + if path_params: + path_params = [ + p for p in path_params + if f"{{{p['name']}}}" in path + ] + + # Auto-generate query_params from request_model for GET requests + query_params = metadata.get("query_params") + if method == 'GET' and not query_params and metadata.get("request_model"): + try: + schema = pydantic_to_json_schema(metadata["request_model"]) + defs = schema.get("$defs", {}) + properties = schema.get("properties", {}) + query_params = [] + for name, prop in properties.items(): + is_required = name in schema.get("required", []) + # Resolve references to inlined definitions (preserving Enums) + resolved_prop = resolve_schema_refs(prop, defs) + # Create param definition + param_def = { + "name": name, + "in": "query", + "required": is_required, + "description": prop.get("description", ""), + "schema": resolved_prop + } + # Remove description from schema to avoid duplication + if "description" in param_def["schema"]: + del param_def["schema"]["description"] + query_params.append(param_def) + except Exception: + pass # Fallback to empty if schema generation fails + register_tool( path=path, method=method, @@ -98,9 +166,11 @@ def introspect_flask_app(app: Any): description=metadata["description"], request_model=metadata.get("request_model"), response_model=metadata.get("response_model"), - path_params=metadata.get("path_params"), - query_params=metadata.get("query_params"), + path_params=path_params, + query_params=query_params, tags=tags, - allow_multipart_payload=metadata.get("allow_multipart_payload", False) + allow_multipart_payload=metadata.get("allow_multipart_payload", False), + response_content_types=metadata.get("response_content_types"), + links=metadata.get("links") ) registered_ops.add(op_key) diff --git a/server/api_server/openapi/registry.py b/server/api_server/openapi/registry.py index fcd2fa91..6d8759b3 100644 --- a/server/api_server/openapi/registry.py +++ b/server/api_server/openapi/registry.py @@ -96,7 +96,9 @@ def register_tool( tags: Optional[List[str]] = None, deprecated: bool = False, original_operation_id: Optional[str] = None, - allow_multipart_payload: bool = False + allow_multipart_payload: bool = False, + response_content_types: Optional[List[str]] = None, + links: Optional[Dict[str, Any]] = None ) -> None: """ Register an API endpoint for OpenAPI spec generation. @@ -115,6 +117,8 @@ def register_tool( deprecated: Whether this endpoint is deprecated original_operation_id: The base ID before suffixing (for disablement mapping) allow_multipart_payload: Whether to allow multipart/form-data payloads + response_content_types: List of supported response media types (e.g. ["application/json", "text/csv"]) + links: Dictionary of OpenAPI links to include in the response definition. Raises: DuplicateOperationIdError: If operation_id already exists in registry @@ -140,7 +144,9 @@ def register_tool( "query_params": query_params or [], "tags": tags or ["default"], "deprecated": deprecated, - "allow_multipart_payload": allow_multipart_payload + "allow_multipart_payload": allow_multipart_payload, + "response_content_types": response_content_types or ["application/json"], + "links": links }) diff --git a/server/api_server/openapi/schema_converter.py b/server/api_server/openapi/schema_converter.py index c6979527..73b6c029 100644 --- a/server/api_server/openapi/schema_converter.py +++ b/server/api_server/openapi/schema_converter.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Dict, Any, Optional, Type, List from pydantic import BaseModel +from .schemas import ErrorResponse, BaseResponse def pydantic_to_json_schema(model: Type[BaseModel], mode: str = "validation") -> Dict[str, Any]: @@ -161,57 +162,124 @@ def strip_validation(schema: Dict[str, Any]) -> Dict[str, Any]: return clean_schema +def resolve_schema_refs(schema: Dict[str, Any], definitions: Dict[str, Any]) -> Dict[str, Any]: + """ + Recursively resolve $ref in schema by inlining the definition. + Useful for standalone schema parts like query parameters where global definitions aren't available. + """ + if not isinstance(schema, dict): + return schema + + if "$ref" in schema: + ref = schema["$ref"] + # Handle #/$defs/Name syntax + if ref.startswith("#/$defs/"): + def_name = ref.split("/")[-1] + if def_name in definitions: + # Inline the definition (and resolve its refs recursively) + inlined = resolve_schema_refs(definitions[def_name], definitions) + # Merge any extra keys from the original schema (e.g. description override) + # Schema keys take precedence over definition keys + return {**inlined, **{k: v for k, v in schema.items() if k != "$ref"}} + + # Recursively resolve properties + resolved = {} + for k, v in schema.items(): + if k == "items": + resolved[k] = resolve_schema_refs(v, definitions) + elif k == "properties": + resolved[k] = {pk: resolve_schema_refs(pv, definitions) for pk, pv in v.items()} + elif k in ("allOf", "anyOf", "oneOf"): + resolved[k] = [resolve_schema_refs(i, definitions) for i in v] + else: + resolved[k] = v + + return resolved + + def build_responses( - response_model: Optional[Type[BaseModel]], definitions: Dict[str, Any] + response_model: Optional[Type[BaseModel]], + definitions: Dict[str, Any], + response_content_types: Optional[List[str]] = None, + links: Optional[Dict[str, Any]] = None, + method: str = "post" ) -> Dict[str, Any]: """Build OpenAPI responses object.""" responses = {} - # Success response (200) - if response_model: - # Strip validation from response schema to save tokens - schema = strip_validation(pydantic_to_json_schema(response_model, mode="serialization")) - schema = extract_definitions(schema, definitions) - responses["200"] = { - "description": "Successful response", - "content": { - "application/json": { - "schema": schema - } - } - } + # Use a fresh list for response content types to avoid a shared mutable default. + if response_content_types is None: + response_content_types = ["application/json"] else: - responses["200"] = { - "description": "Successful response", - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": { - "success": {"type": "boolean"}, - "message": {"type": "string"} - } - } - } - } - } + # Copy provided list to ensure each call gets its own list + response_content_types = list(response_content_types) - # Standard error responses - MINIMIZED context - # Annotate that these errors can occur, but provide no schema/content to save tokens. - # The LLM knows what "Bad Request" or "Not Found" means. - error_codes = { - "400": "Bad Request", - "401": "Unauthorized", - "403": "Forbidden", - "404": "Not Found", - "422": "Validation Error", - "500": "Internal Server Error" + # Success response (200) + effective_model = response_model or BaseResponse + schema = strip_validation(pydantic_to_json_schema(effective_model, mode="serialization")) + schema = extract_definitions(schema, definitions) + + content = {} + for ct in response_content_types: + if ct == "application/json": + content[ct] = {"schema": schema} + else: + # For non-JSON types like CSV, we don't necessarily use the JSON schema + content[ct] = {"schema": {"type": "string", "format": "binary"}} + + response_obj = { + "description": "Successful response", + "content": content + } + if links: + response_obj["links"] = links + responses["200"] = response_obj + + # Standard error responses + error_configs = { + "400": ("Invalid JSON", "Request body must be valid JSON"), + "401": ("Unauthorized", None), + "403": ("Forbidden", "ERROR: Not authorized"), + "404": ("API route not found", "The requested URL /example/path was not found on the server."), + "422": ("Validation Error", None), + "500": ("Internal Server Error", "Something went wrong on the server") } - for code, desc in error_codes.items(): + for code, (error_val, message_val) in error_configs.items(): + # Generate a fresh schema for each error to customize examples + error_schema_raw = strip_validation(pydantic_to_json_schema(ErrorResponse, mode="serialization")) + error_schema = extract_definitions(error_schema_raw, definitions) + + # Inject status-specific example + if "examples" in error_schema and len(error_schema["examples"]) > 0: + example = { + "success": False, + "error": error_val + } + if message_val: + example["message"] = message_val + + if code == "422": + example["error"] = "Validation Error: Input should be a valid string" + example["details"] = [ + { + "input": "invalid_value", + "loc": ["field_name"], + "msg": "Input should be a valid string", + "type": "string_type", + "url": "https://errors.pydantic.dev/2.12/v/string_type" + } + ] + + error_schema["examples"] = [example] + responses[code] = { - "description": desc - # No "content" schema provided + "description": error_val, + "content": { + "application/json": { + "schema": error_schema + } + } } return responses diff --git a/server/api_server/openapi/schemas.py b/server/api_server/openapi/schemas.py index 57db072b..96f862de 100644 --- a/server/api_server/openapi/schemas.py +++ b/server/api_server/openapi/schemas.py @@ -52,6 +52,16 @@ ALLOWED_LOG_FILES = Literal[ "app.php_errors.log", "execution_queue.log", "db_is_locked.log" ] +ALLOWED_SCAN_TYPES = Literal["ARPSCAN", "NMAPDEV", "NMAP", "INTRNT", "AVAHISCAN", "NBTSCAN"] + +ALLOWED_SESSION_CONNECTION_TYPES = Literal["Connected", "Reconnected", "New Device", "Down Reconnected"] +ALLOWED_SESSION_DISCONNECTION_TYPES = Literal["Disconnected", "Device Down", "Timeout"] + +ALLOWED_EVENT_TYPES = Literal[ + "Device Down", "New Device", "Connected", "Disconnected", + "IP Changed", "Down Reconnected", "" +] + def validate_mac(value: str) -> str: """Validate and normalize MAC address format.""" @@ -89,14 +99,42 @@ def validate_column_identifier(value: str) -> str: class BaseResponse(BaseModel): - """Standard API response wrapper.""" - model_config = ConfigDict(extra="allow") + """ + Standard API response wrapper. + Note: The API often returns 200 OK for most operations; clients MUST parse the 'success' + boolean field to determine if the operation was actually successful. + """ + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "success": True + }] + } + ) success: bool = Field(..., description="Whether the operation succeeded") message: Optional[str] = Field(None, description="Human-readable message") error: Optional[str] = Field(None, description="Error message if success=False") +class ErrorResponse(BaseResponse): + """Standard error response model with details.""" + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "success": False, + "error": "Error message" + }] + } + ) + + success: bool = Field(False, description="Always False for errors") + details: Optional[Any] = Field(None, description="Detailed error information (e.g., validation errors)") + code: Optional[str] = Field(None, description="Internal error code") + + class PaginatedResponse(BaseResponse): """Response with pagination metadata.""" total: int = Field(0, description="Total number of items") @@ -130,7 +168,19 @@ class DeviceSearchRequest(BaseModel): class DeviceInfo(BaseModel): """Detailed device information model (Raw record).""" - model_config = ConfigDict(extra="allow") + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "devMac": "00:11:22:33:44:55", + "devName": "My iPhone", + "devLastIP": "192.168.1.10", + "devVendor": "Apple", + "devStatus": "online", + "devFavorite": 0 + }] + } + ) devMac: str = Field(..., description="Device MAC address") devName: Optional[str] = Field(None, description="Device display name/alias") @@ -138,13 +188,27 @@ class DeviceInfo(BaseModel): devPrimaryIPv4: Optional[str] = Field(None, description="Primary IPv4 address") devPrimaryIPv6: Optional[str] = Field(None, description="Primary IPv6 address") devVlan: Optional[str] = Field(None, description="VLAN identifier") - devForceStatus: Optional[str] = Field(None, description="Force device status (online/offline/dont_force)") + devForceStatus: Optional[Literal["online", "offline", "dont_force"]] = Field( + "dont_force", + description="Force device status (online/offline/dont_force)" + ) devVendor: Optional[str] = Field(None, description="Hardware vendor from OUI lookup") devOwner: Optional[str] = Field(None, description="Device owner") devType: Optional[str] = Field(None, description="Device type classification") - devFavorite: Optional[int] = Field(0, description="Favorite flag (0 or 1)") - devPresentLastScan: Optional[int] = Field(None, description="Present in last scan (0 or 1)") - devStatus: Optional[str] = Field(None, description="Online/Offline status") + devFavorite: Optional[int] = Field( + 0, + description="Favorite flag (0=False, 1=True). Legacy boolean representation.", + json_schema_extra={"enum": [0, 1]} + ) + devPresentLastScan: Optional[int] = Field( + None, + description="Present in last scan (0 or 1)", + json_schema_extra={"enum": [0, 1]} + ) + devStatus: Optional[Literal["online", "offline"]] = Field( + None, + description="Online/Offline status" + ) devMacSource: Optional[str] = Field(None, description="Source of devMac (USER, LOCKED, or plugin prefix)") devNameSource: Optional[str] = Field(None, description="Source of devName") devFQDNSource: Optional[str] = Field(None, description="Source of devFQDN") @@ -169,7 +233,17 @@ class DeviceListRequest(BaseModel): "offline" ]] = Field( None, - description="Filter devices by status (connected, down, favorites, new, archived, all, my, offline)" + description=( + "Filter devices by status:\n" + "- connected: Active devices present in the last scan\n" + "- down: Devices with active 'Device Down' alert\n" + "- favorites: Devices marked as favorite\n" + "- new: Devices flagged as new\n" + "- archived: Devices moved to archive\n" + "- all: All active (non-archived) devices\n" + "- my: All active devices (alias for 'all')\n" + "- offline: Devices not present in the last scan" + ) ) @@ -270,12 +344,23 @@ class CopyDeviceRequest(BaseModel): class UpdateDeviceColumnRequest(BaseModel): """Request to update a specific device database column.""" columnName: ALLOWED_DEVICE_COLUMNS = Field(..., description="Database column name") - columnValue: Any = Field(..., description="New value for the column") + columnValue: Union[str, int, bool, None] = Field( + ..., + description="New value for the column. Must match the column's expected data type (e.g., string for devName, integer for devFavorite).", + json_schema_extra={ + "oneOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "boolean"}, + {"type": "null"} + ] + } + ) class LockDeviceFieldRequest(BaseModel): """Request to lock/unlock a device field.""" - fieldName: Optional[str] = Field(None, description="Field name to lock/unlock (devMac, devName, devLastIP, etc.)") + fieldName: str = Field(..., description="Field name to lock/unlock (e.g., devName, devVendor). Required.") lock: bool = Field(True, description="True to lock the field, False to unlock") @@ -301,12 +386,18 @@ class DeviceUpdateRequest(BaseModel): devName: Optional[str] = Field(None, description="Device name") devOwner: Optional[str] = Field(None, description="Device owner") - devType: Optional[str] = Field(None, description="Device type") + devType: Optional[str] = Field( + None, + description="Device type", + json_schema_extra={ + "examples": ["Phone", "Laptop", "Desktop", "Router", "IoT", "Camera", "Server", "TV"] + } + ) devVendor: Optional[str] = Field(None, description="Device vendor") devGroup: Optional[str] = Field(None, description="Device group") devLocation: Optional[str] = Field(None, description="Device location") devComments: Optional[str] = Field(None, description="Comments") - createNew: bool = Field(False, description="Create new device if not exists") + createNew: bool = Field(False, description="If True, creates a new device. Recommended to provide at least devName and devVendor. If False, updates existing device.") @field_validator("devName", "devOwner", "devType", "devVendor", "devGroup", "devLocation", "devComments") @classmethod @@ -340,10 +431,9 @@ class DeleteDevicesRequest(BaseModel): class TriggerScanRequest(BaseModel): """Request to trigger a network scan.""" - type: str = Field( + type: ALLOWED_SCAN_TYPES = Field( "ARPSCAN", - description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)", - json_schema_extra={"examples": ["ARPSCAN", "NMAPDEV", "NMAP"]} + description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)" ) @@ -381,8 +471,9 @@ class OpenPortsResponse(BaseResponse): class WakeOnLanRequest(BaseModel): """Request to send Wake-on-LAN packet.""" - devMac: Optional[str] = Field( + mac: Optional[str] = Field( None, + alias="devMac", description="Target device MAC address", json_schema_extra={"examples": ["00:11:22:33:44:55"]} ) @@ -396,7 +487,7 @@ class WakeOnLanRequest(BaseModel): # But Pydantic V2 with populate_by_name=True allows both "devLastIP" and "ip". model_config = ConfigDict(populate_by_name=True) - @field_validator("devMac") + @field_validator("mac") @classmethod def validate_mac_if_provided(cls, v: Optional[str]) -> Optional[str]: if v is not None: @@ -412,15 +503,19 @@ class WakeOnLanRequest(BaseModel): @model_validator(mode="after") def require_mac_or_ip(self) -> "WakeOnLanRequest": - """Ensure at least one of devMac or devLastIP is provided.""" - if self.devMac is None and self.devLastIP is None: - raise ValueError("Either 'devMac' or 'devLastIP' (alias 'ip') must be provided") + """Ensure at least one of mac or devLastIP is provided.""" + if self.mac is None and self.devLastIP is None: + raise ValueError("Either devMac (aka mac) or devLastIP (aka ip) must be provided") return self class WakeOnLanResponse(BaseResponse): """Response for Wake-on-LAN operation.""" - output: Optional[str] = Field(None, description="Command output") + output: Optional[str] = Field( + None, + description="Command output", + json_schema_extra={"examples": ["Sent magic packet to AA:BB:CC:DD:EE:FF"]} + ) class TracerouteRequest(BaseModel): @@ -446,7 +541,7 @@ class NmapScanRequest(BaseModel): """Request to perform NMAP scan.""" scan: str = Field( ..., - description="Target IP address for NMAP scan" + description="Target IP address for NMAP scan (Single IP only, no CIDR/ranges/hostnames)." ) mode: ALLOWED_NMAP_MODES = Field( ..., @@ -507,7 +602,17 @@ class NetworkInterfacesResponse(BaseResponse): class EventInfo(BaseModel): """Event/alert information.""" - model_config = ConfigDict(extra="allow") + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "eveMAC": "00:11:22:33:44:55", + "eveIP": "192.168.1.10", + "eveDateTime": "2024-01-29 10:00:00", + "eveEventType": "Device Down" + }] + } + ) eveRowid: Optional[int] = Field(None, description="Event row ID") eveMAC: Optional[str] = Field(None, description="Device MAC address") @@ -547,9 +652,19 @@ class LastEventsResponse(BaseResponse): class CreateEventRequest(BaseModel): """Request to create a device event.""" ip: Optional[str] = Field("0.0.0.0", description="Device IP") - event_type: str = Field("Device Down", description="Event type") + event_type: str = Field( + "Device Down", + description="Event type", + json_schema_extra={ + "examples": ["Device Down", "New Device", "Connected", "Disconnected", "IP Changed", "Down Reconnected", ""] + } + ) additional_info: Optional[str] = Field("", description="Additional info") - pending_alert: int = Field(1, description="Pending alert flag") + pending_alert: int = Field( + 1, + description="Pending alert flag (0 or 1)", + json_schema_extra={"enum": [0, 1]} + ) event_time: Optional[str] = Field(None, description="Event timestamp (ISO)") @field_validator("ip", mode="before") @@ -564,11 +679,19 @@ class CreateEventRequest(BaseModel): # ============================================================================= # SESSIONS SCHEMAS # ============================================================================= - - class SessionInfo(BaseModel): """Session information.""" - model_config = ConfigDict(extra="allow") + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "sesMac": "00:11:22:33:44:55", + "sesDateTimeConnection": "2024-01-29 08:00:00", + "sesDateTimeDisconnection": "2024-01-29 09:00:00", + "sesIPAddress": "192.168.1.10" + }] + } + ) sesRowid: Optional[int] = Field(None, description="Session row ID") sesMac: Optional[str] = Field(None, description="Device MAC address") @@ -583,8 +706,20 @@ class CreateSessionRequest(BaseModel): ip: str = Field(..., description="Device IP") start_time: str = Field(..., description="Start time") end_time: Optional[str] = Field(None, description="End time") - event_type_conn: str = Field("Connected", description="Connection event type") - event_type_disc: str = Field("Disconnected", description="Disconnection event type") + event_type_conn: str = Field( + "Connected", + description="Connection event type", + json_schema_extra={ + "examples": ["Connected", "Reconnected", "New Device", "Down Reconnected"] + } + ) + event_type_disc: str = Field( + "Disconnected", + description="Disconnection event type", + json_schema_extra={ + "examples": ["Disconnected", "Device Down", "Timeout"] + } + ) @field_validator("mac") @classmethod @@ -620,7 +755,11 @@ class InAppNotification(BaseModel): guid: Optional[str] = Field(None, description="Unique notification GUID") text: str = Field(..., description="Notification text content") level: NOTIFICATION_LEVELS = Field("info", description="Notification level") - read: Optional[int] = Field(0, description="Read status (0 or 1)") + read: Optional[int] = Field( + 0, + description="Read status (0 or 1)", + json_schema_extra={"enum": [0, 1]} + ) created_at: Optional[str] = Field(None, description="Creation timestamp") @@ -665,10 +804,12 @@ class DbQueryRequest(BaseModel): """ Request for raw database query. WARNING: This is a highly privileged operation. + Can be used to read settings by querying the 'Settings' table. """ rawSql: str = Field( ..., - description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)" + description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)", + json_schema_extra={"examples": ["U0VMRUNUICogRlJPTSBTZXR0aW5ncw=="]} ) # Legacy compatibility: removed strict safety check # TODO: SECURITY CRITICAL - Re-enable strict safety checks. @@ -690,9 +831,23 @@ class DbQueryRequest(BaseModel): class DbQueryUpdateRequest(BaseModel): - """Request for DB update query.""" + """ + Request for DB update query. + Can be used to update settings by targeting the 'Settings' table. + """ columnName: str = Field(..., description="Column to filter by") - id: List[Any] = Field(..., description="List of IDs to update") + id: List[Union[str, int]] = Field( + ..., + description="List of IDs to update. Use MAC address strings for 'Devices' table, and integer RowIDs for all other tables.", + json_schema_extra={ + "items": { + "oneOf": [ + {"type": "string", "description": "A string identifier (e.g., MAC address)"}, + {"type": "integer", "description": "A numeric row ID"} + ] + } + } + ) dbtable: ALLOWED_TABLES = Field(..., description="Table name") columns: List[str] = Field(..., description="Columns to update") values: List[Any] = Field(..., description="New values") @@ -715,9 +870,23 @@ class DbQueryUpdateRequest(BaseModel): class DbQueryDeleteRequest(BaseModel): - """Request for DB delete query.""" + """ + Request for DB delete query. + Can be used to delete settings by targeting the 'Settings' table. + """ columnName: str = Field(..., description="Column to filter by") - id: List[Any] = Field(..., description="List of IDs to delete") + id: List[Union[str, int]] = Field( + ..., + description="List of IDs to delete. Use MAC address strings for 'Devices' table, and integer RowIDs for all other tables.", + json_schema_extra={ + "items": { + "oneOf": [ + {"type": "string", "description": "A string identifier (e.g., MAC address)"}, + {"type": "integer", "description": "A numeric row ID"} + ] + } + } + ) dbtable: ALLOWED_TABLES = Field(..., description="Table name") @field_validator("columnName") @@ -772,3 +941,14 @@ class SettingValue(BaseModel): class GetSettingResponse(BaseResponse): """Response for getting a setting value.""" value: Any = Field(None, description="The setting value") + + +# ============================================================================= +# GRAPHQL SCHEMAS +# ============================================================================= + + +class GraphQLRequest(BaseModel): + """Request payload for GraphQL queries.""" + query: str = Field(..., description="GraphQL query string", json_schema_extra={"examples": ["{ devices { devMac devName } }"]}) + variables: Optional[Dict[str, Any]] = Field(None, description="Variables for the GraphQL query") diff --git a/server/api_server/openapi/spec_generator.py b/server/api_server/openapi/spec_generator.py index 12154624..3d42779f 100644 --- a/server/api_server/openapi/spec_generator.py +++ b/server/api_server/openapi/spec_generator.py @@ -29,7 +29,7 @@ Usage: """ from __future__ import annotations - +import os import threading from typing import Optional, List, Dict, Any @@ -52,7 +52,7 @@ _rebuild_lock = threading.Lock() def generate_openapi_spec( title: str = "NetAlertX API", version: str = "2.0.0", - description: str = "NetAlertX Network Monitoring API - MCP Compatible", + description: str = "NetAlertX Network Monitoring API - Official Documentation - MCP Compatible", servers: Optional[List[Dict[str, str]]] = None, flask_app: Optional[Any] = None ) -> Dict[str, Any]: @@ -74,18 +74,58 @@ def generate_openapi_spec( introspect_graphql_schema(devicesSchema) introspect_flask_app(flask_app) + # Apply default disabled tools from setting `MCP_DISABLED_TOOLS`, env var, or hard-coded defaults + # Format: comma-separated operation IDs, e.g. "dbquery_read,dbquery_write" + try: + disabled_env = None + # Prefer setting from app.conf/settings when available + try: + from helper import get_setting_value + setting_val = get_setting_value("MCP_DISABLED_TOOLS") + if setting_val: + disabled_env = str(setting_val).strip() + except Exception: + # If helper is unavailable, fall back to environment + pass + + if not disabled_env: + env_val = os.getenv("MCP_DISABLED_TOOLS") + if env_val: + disabled_env = env_val.strip() + + # If still not set, apply safe hard-coded defaults + if not disabled_env: + disabled_env = "dbquery_read,dbquery_write" + + if disabled_env: + from .registry import set_tool_disabled + for op in [p.strip() for p in disabled_env.split(",") if p.strip()]: + set_tool_disabled(op, True) + except Exception: + # Never fail spec generation due to disablement application issues + pass + spec = { "openapi": "3.1.0", "info": { "title": title, "version": version, "description": description, + "termsOfService": "https://github.com/netalertx/NetAlertX/blob/main/LICENSE.txt", "contact": { - "name": "NetAlertX", - "url": "https://github.com/jokob-sk/NetAlertX" + "name": "Open Source Project - NetAlertX - Github", + "url": "https://github.com/netalertx/NetAlertX" + }, + "license": { + "name": "Licensed under GPLv3", + "url": "https://www.gnu.org/licenses/gpl-3.0.html" } }, - "servers": servers or [{"url": "/", "description": "Local server"}], + "externalDocs": { + "description": "NetAlertX Official Documentation", + "url": "https://docs.netalertx.com/" + }, + "servers": servers or [{"url": "/", "description": "This NetAlertX instance"}], "security": [ {"BearerAuth": []} ], @@ -152,7 +192,11 @@ def generate_openapi_spec( # Add responses operation["responses"] = build_responses( - entry.get("response_model"), definitions + entry.get("response_model"), + definitions, + response_content_types=entry.get("response_content_types", ["application/json"]), + links=entry.get("links"), + method=method ) spec["paths"][path][method] = operation diff --git a/server/api_server/openapi/validation.py b/server/api_server/openapi/validation.py index 33f1adcc..97617dd1 100644 --- a/server/api_server/openapi/validation.py +++ b/server/api_server/openapi/validation.py @@ -44,7 +44,11 @@ def validate_request( query_params: Optional[list[dict]] = None, validation_error_code: int = 422, auth_callable: Optional[Callable[[], bool]] = None, - allow_multipart_payload: bool = False + allow_multipart_payload: bool = False, + exclude_from_spec: bool = False, + response_content_types: Optional[list[str]] = None, + links: Optional[dict] = None, + error_responses: Optional[dict] = None ): """ Decorator to register a Flask route with the OpenAPI registry and validate incoming requests. @@ -56,6 +60,10 @@ def validate_request( - Supports auth_callable to check permissions before validation. - Returns 422 (default) if validation fails. - allow_multipart_payload: If True, allows multipart/form-data and attempts validation from form fields. + - exclude_from_spec: If True, this endpoint will be omitted from the generated OpenAPI specification. + - response_content_types: List of supported response media types (e.g. ["application/json", "text/csv"]). + - links: Dictionary of OpenAPI links to include in the response definition. + - error_responses: Dictionary of custom error examples (e.g. {"404": "Device not found"}). """ def decorator(f: Callable) -> Callable: @@ -73,7 +81,11 @@ def validate_request( "tags": tags, "path_params": path_params, "query_params": query_params, - "allow_multipart_payload": allow_multipart_payload + "allow_multipart_payload": allow_multipart_payload, + "exclude_from_spec": exclude_from_spec, + "response_content_types": response_content_types, + "links": links, + "error_responses": error_responses } @wraps(f) @@ -150,6 +162,7 @@ def validate_request( data = request.args.to_dict() validated_instance = request_model(**data) except ValidationError as e: + # Use configured validation error code (default 422) return _handle_validation_error(e, operation_id, validation_error_code) except (TypeError, ValueError, KeyError) as e: mylog("verbose", [f"[Validation] Query param validation failed for {operation_id}: {e}"]) diff --git a/test/api_endpoints/test_mcp_disabled_tools.py b/test/api_endpoints/test_mcp_disabled_tools.py new file mode 100644 index 00000000..fd9e67b2 --- /dev/null +++ b/test/api_endpoints/test_mcp_disabled_tools.py @@ -0,0 +1,63 @@ + +import os +import sys +import pytest +from unittest.mock import patch, MagicMock + +# Use cwd as fallback if env var is not set, assuming running from project root +INSTALL_PATH = os.getenv('NETALERTX_APP', os.getcwd()) +sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) + +from api_server.openapi.spec_generator import generate_openapi_spec +from api_server.api_server_start import app + +class TestMCPDisabledTools: + + def test_disabled_tools_via_env_var(self): + """Test that MCP_DISABLED_TOOLS env var disables specific tools.""" + # Clean registry first to ensure clean state + from api_server.openapi.registry import clear_registry + clear_registry() + + # Mock get_setting_value to return None (simulating no config setting) + # and mock os.getenv to return our target list + with patch("helper.get_setting_value", return_value=None), \ + patch.dict(os.environ, {"MCP_DISABLED_TOOLS": "search_devices_api"}): + + spec = generate_openapi_spec(flask_app=app) + + # Locate the operation + # search_devices_api is usually mapped to /devices/search [POST] or similar + # We search the spec for the operationId + + found = False + for path, methods in spec["paths"].items(): + for method, op in methods.items(): + if op["operationId"] == "search_devices_api": + assert op.get("x-mcp-disabled") is True + found = True + + assert found, "search_devices_api operation not found in spec" + + def test_disabled_tools_default_fallback(self): + """Test fallback to defaults when no setting or env var exists.""" + from api_server.openapi.registry import clear_registry + clear_registry() + + with patch("helper.get_setting_value", return_value=None), \ + patch.dict(os.environ, {}, clear=True): # Clear env to ensure no MCP_DISABLED_TOOLS + + spec = generate_openapi_spec(flask_app=app) + + # Default is "dbquery_read,dbquery_write" + + # Check dbquery_read + found_read = False + for path, methods in spec["paths"].items(): + for method, op in methods.items(): + if op["operationId"] == "dbquery_read": + assert op.get("x-mcp-disabled") is True + found_read = True + + assert found_read, "dbquery_read should be disabled by default" + diff --git a/test/api_endpoints/test_mcp_openapi_spec.py b/test/api_endpoints/test_mcp_openapi_spec.py index f92b1f82..c92348b8 100644 --- a/test/api_endpoints/test_mcp_openapi_spec.py +++ b/test/api_endpoints/test_mcp_openapi_spec.py @@ -66,7 +66,7 @@ class TestPydanticSchemas: """WakeOnLanRequest should validate MAC format.""" # Valid MAC req = WakeOnLanRequest(devMac="00:11:22:33:44:55") - assert req.devMac == "00:11:22:33:44:55" + assert req.mac == "00:11:22:33:44:55" # Invalid MAC # with pytest.raises(ValidationError): @@ -76,7 +76,7 @@ class TestPydanticSchemas: """WakeOnLanRequest should accept either MAC or IP.""" req_mac = WakeOnLanRequest(devMac="00:11:22:33:44:55") req_ip = WakeOnLanRequest(devLastIP="192.168.1.50") - assert req_mac.devMac is not None + assert req_mac.mac is not None assert req_ip.devLastIP == "192.168.1.50" def test_traceroute_request_ip_validation(self): @@ -197,7 +197,7 @@ class TestOpenAPISpecGenerator: f"Path param '{param_name}' not defined: {method.upper()} {path}" def test_standard_error_responses(self): - """Operations should have minimal standard error responses (400, 403, 404, etc) without schema bloat.""" + """Operations should have standard error responses (400, 403, 404, etc).""" spec = generate_openapi_spec() expected_minimal_codes = ["400", "401", "403", "404", "500", "422"] @@ -207,21 +207,28 @@ class TestOpenAPISpecGenerator: continue responses = details.get("responses", {}) for code in expected_minimal_codes: - assert code in responses, f"Missing minimal {code} response in: {method.upper()} {path}." - # Verify no "content" or schema is present (minimalism) - assert "content" not in responses[code], f"Response {code} in {method.upper()} {path} should not have content/schema." + assert code in responses, f"Missing {code} response in: {method.upper()} {path}." + # Content should now be present (BaseResponse/Error schema) + assert "content" in responses[code], f"Response {code} in {method.upper()} {path} should have content/schema." class TestMCPToolMapping: """Test MCP tool generation from OpenAPI spec.""" def test_tools_match_registry_count(self): - """Number of MCP tools should match registered endpoints.""" + """Number of MCP tools should match unique original operation IDs in registry.""" spec = generate_openapi_spec() tools = map_openapi_to_mcp_tools(spec) registry = get_registry() - assert len(tools) == len(registry) + # Count unique operation IDs (accounting for our deduplication logic) + unique_ops = set() + for entry in registry: + # We used x-original-operationId for deduplication logic, or operation_id if not present + op_id = entry.get("original_operation_id") or entry["operation_id"] + unique_ops.add(op_id) + + assert len(tools) == len(unique_ops) def test_tools_have_input_schema(self): """All MCP tools should have inputSchema.""" @@ -239,9 +246,9 @@ class TestMCPToolMapping: spec = generate_openapi_spec() tools = map_openapi_to_mcp_tools(spec) - search_tool = next((t for t in tools if t["name"] == "search_devices"), None) + search_tool = next((t for t in tools if t["name"] == "search_devices_api"), None) assert search_tool is not None - assert "query" in search_tool["inputSchema"].get("required", []) + assert "query" in search_tool["inputSchema"]["required"] def test_tool_descriptions_present(self): """All tools should have non-empty descriptions.""" diff --git a/test/api_endpoints/test_schema_converter.py b/test/api_endpoints/test_schema_converter.py new file mode 100644 index 00000000..e69de29b diff --git a/test/authoritative_fields/test_device_field_lock.py b/test/authoritative_fields/test_device_field_lock.py index 168d6875..7703daf1 100644 --- a/test/authoritative_fields/test_device_field_lock.py +++ b/test/authoritative_fields/test_device_field_lock.py @@ -99,8 +99,9 @@ class TestDeviceFieldLock: json=payload, headers=auth_headers ) - assert resp.status_code == 400 - assert "fieldName is required" in resp.json.get("error", "") + assert resp.status_code == 422 + # Pydantic error message format for missing fields + assert "Missing required 'fieldName'" in resp.json.get("error", "") def test_lock_field_invalid_field_name(self, client, test_mac, auth_headers): """Lock endpoint rejects untracked fields."""