diff --git a/server/api_server/api_server_start.py b/server/api_server/api_server_start.py index eb330dbe..298530f8 100755 --- a/server/api_server/api_server_start.py +++ b/server/api_server/api_server_start.py @@ -100,6 +100,18 @@ from .sse_endpoint import ( # noqa: E402 [flake8 lint suppression] app = Flask(__name__) +@app.errorhandler(500) +@app.errorhandler(Exception) +def handle_500_error(e): + """Global error handler for uncaught exceptions.""" + mylog("none", [f"[API] Uncaught exception: {e}"]) + return jsonify({ + "success": False, + "error": "Internal Server Error", + "message": "Something went wrong on the server" + }), 500 + + # Parse CORS origins from environment or use safe defaults _cors_origins_env = os.environ.get("CORS_ORIGINS", "") _cors_origins = [ @@ -599,7 +611,7 @@ def api_device_open_ports(payload=None): @validate_request( operation_id="get_all_devices", summary="Get All Devices", - description="Retrieve a list of all devices in the system.", + description="Retrieve a list of all devices in the system. Returns all records. No pagination supported.", response_model=DeviceListWrapperResponse, tags=["devices"], auth_callable=is_authorized @@ -662,7 +674,7 @@ def api_delete_unknown_devices(payload=None): @app.route("/devices/export", methods=["GET"]) @app.route("/devices/export/", methods=["GET"]) @validate_request( - operation_id="export_devices", + operation_id="export_devices_all", summary="Export Devices", description="Export all devices in CSV or JSON format.", query_params=[{ @@ -679,7 +691,8 @@ def api_delete_unknown_devices(payload=None): }], response_model=DeviceExportResponse, tags=["devices"], - auth_callable=is_authorized + auth_callable=is_authorized, + response_content_types=["application/json", "text/csv"] ) def api_export_devices(format=None, payload=None): export_format = (format or request.args.get("format", "csv")).lower() @@ -747,7 +760,7 @@ def api_devices_totals(payload=None): @app.route('/mcp/sse/devices/by-status', methods=['GET', 'POST']) @app.route("/devices/by-status", methods=["GET", "POST"]) @validate_request( - operation_id="list_devices_by_status", + operation_id="list_devices_by_status_api", summary="List Devices by Status", description="List devices filtered by their online/offline status.", request_model=DeviceListRequest, @@ -763,7 +776,30 @@ def api_devices_totals(payload=None): "connected", "down", "favorites", "new", "archived", "all", "my", "offline" ]} - }] + }], + links={ + "GetOpenPorts": { + "operationId": "get_open_ports", + "parameters": { + "target": "$response.body#/0/devLastIP" + }, + "description": "The `target` parameter for `get_open_ports` requires an IP address. Use the `devLastIP` from the first device in the list." + }, + "WakeOnLan": { + "operationId": "wake_on_lan", + "parameters": { + "devMac": "$response.body#/0/devMac" + }, + "description": "The `devMac` parameter for `wake_on_lan` requires a MAC address. Use the `devMac` from the first device in the list." + }, + "UpdateDevice": { + "operationId": "update_device", + "parameters": { + "mac": "$response.body#/0/devMac" + }, + "description": "The `mac` parameter for `update_device` is a path parameter. Use the `devMac` from the first device in the list." + } + } ) def api_devices_by_status(payload: DeviceListRequest = None): status = payload.status if payload else request.args.get("status") @@ -774,13 +810,43 @@ def api_devices_by_status(payload: DeviceListRequest = None): @app.route('/mcp/sse/devices/search', methods=['POST']) @app.route('/devices/search', methods=['POST']) @validate_request( - operation_id="search_devices", + operation_id="search_devices_api", summary="Search Devices", description="Search for devices based on various criteria like name, IP, MAC, or vendor.", request_model=DeviceSearchRequest, response_model=DeviceSearchResponse, tags=["devices"], - auth_callable=is_authorized + auth_callable=is_authorized, + links={ + "GetOpenPorts": { + "operationId": "get_open_ports", + "parameters": { + "target": "$response.body#/devices/0/devLastIP" + }, + "description": "The `target` parameter for `get_open_ports` requires an IP address. Use the `devLastIP` from the first device in the search results." + }, + "WakeOnLan": { + "operationId": "wake_on_lan", + "parameters": { + "devMac": "$response.body#/devices/0/devMac" + }, + "description": "The `devMac` parameter for `wake_on_lan` requires a MAC address. Use the `devMac` from the first device in the search results." + }, + "NmapScan": { + "operationId": "run_nmap_scan", + "parameters": { + "scan": "$response.body#/devices/0/devLastIP" + }, + "description": "The `scan` parameter for `run_nmap_scan` requires an IP or range. Use the `devLastIP` from the first device in the search results." + }, + "UpdateDevice": { + "operationId": "update_device", + "parameters": { + "mac": "$response.body#/devices/0/devMac" + }, + "description": "The `mac` parameter for `update_device` is a path parameter. Use the `devMac` from the first device in the search results." + } + } ) def api_devices_search(payload=None): """Device search: accepts 'query' in JSON and maps to device info/search.""" @@ -1011,7 +1077,7 @@ def api_network_interfaces(payload=None): @app.route('/mcp/sse/nettools/trigger-scan', methods=['POST']) -@app.route("/nettools/trigger-scan", methods=["GET"]) +@app.route("/nettools/trigger-scan", methods=["GET", "POST"]) @validate_request( operation_id="trigger_network_scan", summary="Trigger Network Scan", @@ -1300,13 +1366,25 @@ def api_create_event(mac, payload=None): @app.route("/events/", methods=["DELETE"]) @validate_request( - operation_id="delete_events_by_mac", - summary="Delete Events by MAC", - description="Delete all events for a specific device MAC address.", + operation_id="delete_events", + summary="Delete Events", + description="Delete events by device MAC address or older than a specified number of days.", path_params=[{ "name": "mac", - "description": "Device MAC address", - "schema": {"type": "string"} + "description": "Device MAC address or number of days", + "schema": { + "oneOf": [ + { + "type": "integer", + "description": "Number of days (e.g., 30) to delete events older than this value." + }, + { + "type": "string", + "pattern": "^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$", + "description": "Device MAC address to delete all events for a specific device." + } + ] + } }], response_model=BaseResponse, tags=["events"], @@ -1315,6 +1393,7 @@ def api_create_event(mac, payload=None): def api_events_by_mac(mac, payload=None): """Delete events for a specific device MAC; string converter keeps this distinct from /events/.""" device_handler = DeviceInstance() + result = device_handler.deleteDeviceEvents(mac) return jsonify(result) @@ -1338,7 +1417,7 @@ def api_delete_all_events(payload=None): @validate_request( operation_id="get_all_events", summary="Get Events", - description="Retrieve a list of events, optionally filtered by MAC.", + description="Retrieve a list of events, optionally filtered by MAC. Returns all matching records. No pagination supported.", query_params=[{ "name": "mac", "description": "Filter by Device MAC", @@ -1372,7 +1451,8 @@ def api_get_events(payload=None): }], response_model=BaseResponse, tags=["events"], - auth_callable=is_authorized + auth_callable=is_authorized, + exclude_from_spec=True ) def api_delete_old_events(days: int, payload=None): """ @@ -1406,7 +1486,7 @@ def api_get_events_totals(payload=None): @app.route('/mcp/sse/events/recent', methods=['GET', 'POST']) -@app.route('/events/recent', methods=['GET']) +@app.route('/events/recent', methods=['GET', 'POST']) @validate_request( operation_id="get_recent_events", summary="Get Recent Events", @@ -1426,7 +1506,7 @@ def api_events_default_24h(payload=None): @app.route('/mcp/sse/events/last', methods=['GET', 'POST']) -@app.route('/events/last', methods=['GET']) +@app.route('/events/last', methods=['GET', 'POST']) @validate_request( operation_id="get_last_events", summary="Get Last Events", @@ -1763,7 +1843,7 @@ def sync_endpoint_post(payload=None): @validate_request( operation_id="check_auth", summary="Check Authentication", - description="Check if the current API token is valid.", + description="Check if the current API token is valid. Note: tokens must be generated externally via the UI or CLI.", response_model=BaseResponse, tags=["auth"], auth_callable=is_authorized diff --git a/server/api_server/openapi/introspection.py b/server/api_server/openapi/introspection.py index 2c1454de..d2b74c22 100644 --- a/server/api_server/openapi/introspection.py +++ b/server/api_server/openapi/introspection.py @@ -1,10 +1,12 @@ from __future__ import annotations import re -from typing import Any +from typing import Any, Dict, Optional import graphene from .registry import register_tool, _operation_ids +from .schemas import GraphQLRequest +from .schema_converter import pydantic_to_json_schema, resolve_schema_refs def introspect_graphql_schema(schema: graphene.Schema): @@ -26,6 +28,7 @@ def introspect_graphql_schema(schema: graphene.Schema): operation_id="graphql_query", summary="GraphQL Endpoint", description="Execute arbitrary GraphQL queries against the system schema.", + request_model=GraphQLRequest, tags=["graphql"] ) @@ -36,6 +39,20 @@ def _flask_to_openapi_path(flask_path: str) -> str: return re.sub(r'<(?:\w+:)?(\w+)>', r'{\1}', flask_path) +def _get_openapi_metadata(func: Any) -> Optional[Dict[str, Any]]: + """Recursively find _openapi_metadata in wrapped functions.""" + # Check current function + metadata = getattr(func, "_openapi_metadata", None) + if metadata: + return metadata + + # Check __wrapped__ (standard for @wraps) + if hasattr(func, "__wrapped__"): + return _get_openapi_metadata(func.__wrapped__) + + return None + + def introspect_flask_app(app: Any): """ Introspect the Flask application to find routes decorated with @validate_request @@ -47,14 +64,13 @@ def introspect_flask_app(app: Any): if not view_func: continue - # Check for our decorator's metadata - metadata = getattr(view_func, "_openapi_metadata", None) - if not metadata: - # Fallback for wrapped functions - if hasattr(view_func, "__wrapped__"): - metadata = getattr(view_func.__wrapped__, "_openapi_metadata", None) + # Check for our decorator's metadata recursively + metadata = _get_openapi_metadata(view_func) if metadata: + if metadata.get("exclude_from_spec"): + continue + op_id = metadata["operation_id"] # Register the tool with real path and method from Flask @@ -75,11 +91,8 @@ def introspect_flask_app(app: Any): # Determine tags - create a copy to avoid mutating shared metadata tags = list(metadata.get("tags") or ["rest"]) if path.startswith("/mcp/"): - # Move specific tags to secondary position or just add MCP - if "rest" in tags: - tags.remove("rest") - if "mcp" not in tags: - tags.append("mcp") + # For MCP endpoints, we want them exclusively in the 'mcp' tag section + tags = ["mcp"] # Ensure unique operationId original_op_id = op_id @@ -89,6 +102,38 @@ def introspect_flask_app(app: Any): unique_op_id = f"{op_id}_{count}" count += 1 + # Filter path_params to only include those that are actually in the path + path_params = metadata.get("path_params") + if path_params: + path_params = [ + p for p in path_params + if f"{{{p['name']}}}" in path + ] + + # Auto-generate query_params from request_model for GET requests + query_params = metadata.get("query_params") + if method == 'GET' and not query_params and metadata.get("request_model"): + try: + schema = pydantic_to_json_schema(metadata["request_model"]) + properties = schema.get("properties", {}) + query_params = [] + for name, prop in properties.items(): + is_required = name in schema.get("required", []) + # Create param definition, preserving enum/schema + param_def = { + "name": name, + "in": "query", + "required": is_required, + "description": prop.get("description", ""), + "schema": prop + } + # Remove description from schema to avoid duplication + if "description" in param_def["schema"]: + del param_def["schema"]["description"] + query_params.append(param_def) + except Exception: + pass # Fallback to empty if schema generation fails + register_tool( path=path, method=method, @@ -98,9 +143,11 @@ def introspect_flask_app(app: Any): description=metadata["description"], request_model=metadata.get("request_model"), response_model=metadata.get("response_model"), - path_params=metadata.get("path_params"), - query_params=metadata.get("query_params"), + path_params=path_params, + query_params=query_params, tags=tags, - allow_multipart_payload=metadata.get("allow_multipart_payload", False) + allow_multipart_payload=metadata.get("allow_multipart_payload", False), + response_content_types=metadata.get("response_content_types"), + links=metadata.get("links") ) registered_ops.add(op_key) diff --git a/server/api_server/openapi/registry.py b/server/api_server/openapi/registry.py index fcd2fa91..6d8759b3 100644 --- a/server/api_server/openapi/registry.py +++ b/server/api_server/openapi/registry.py @@ -96,7 +96,9 @@ def register_tool( tags: Optional[List[str]] = None, deprecated: bool = False, original_operation_id: Optional[str] = None, - allow_multipart_payload: bool = False + allow_multipart_payload: bool = False, + response_content_types: Optional[List[str]] = None, + links: Optional[Dict[str, Any]] = None ) -> None: """ Register an API endpoint for OpenAPI spec generation. @@ -115,6 +117,8 @@ def register_tool( deprecated: Whether this endpoint is deprecated original_operation_id: The base ID before suffixing (for disablement mapping) allow_multipart_payload: Whether to allow multipart/form-data payloads + response_content_types: List of supported response media types (e.g. ["application/json", "text/csv"]) + links: Dictionary of OpenAPI links to include in the response definition. Raises: DuplicateOperationIdError: If operation_id already exists in registry @@ -140,7 +144,9 @@ def register_tool( "query_params": query_params or [], "tags": tags or ["default"], "deprecated": deprecated, - "allow_multipart_payload": allow_multipart_payload + "allow_multipart_payload": allow_multipart_payload, + "response_content_types": response_content_types or ["application/json"], + "links": links }) diff --git a/server/api_server/openapi/schema_converter.py b/server/api_server/openapi/schema_converter.py index c6979527..60e9d1b5 100644 --- a/server/api_server/openapi/schema_converter.py +++ b/server/api_server/openapi/schema_converter.py @@ -2,6 +2,7 @@ from __future__ import annotations from typing import Dict, Any, Optional, Type, List from pydantic import BaseModel +from .schemas import ErrorResponse, BaseResponse def pydantic_to_json_schema(model: Type[BaseModel], mode: str = "validation") -> Dict[str, Any]: @@ -161,57 +162,124 @@ def strip_validation(schema: Dict[str, Any]) -> Dict[str, Any]: return clean_schema +def resolve_schema_refs(schema: Dict[str, Any], definitions: Dict[str, Any]) -> Dict[str, Any]: + """ + Recursively resolve $ref in schema by inlining the definition. + Useful for standalone schema parts like query parameters where global definitions aren't available. + """ + if not isinstance(schema, dict): + return schema + + if "$ref" in schema: + ref = schema["$ref"] + # Handle #/$defs/Name syntax + if ref.startswith("#/$defs/"): + def_name = ref.split("/")[-1] + if def_name in definitions: + # Inline the definition (and resolve its refs recursively) + inlined = resolve_schema_refs(definitions[def_name], definitions) + # Merge any extra keys from the original schema (e.g. description override) + # Schema keys take precedence over definition keys + return {**inlined, **{k: v for k, v in schema.items() if k != "$ref"}} + + # Recursively resolve properties + resolved = {} + for k, v in schema.items(): + if k == "items": + resolved[k] = resolve_schema_refs(v, definitions) + elif k == "properties": + resolved[k] = {pk: resolve_schema_refs(pv, definitions) for pk, pv in v.items()} + elif k in ("allOf", "anyOf", "oneOf"): + resolved[k] = [resolve_schema_refs(i, definitions) for i in v] + else: + resolved[k] = v + + return resolved + + def build_responses( - response_model: Optional[Type[BaseModel]], definitions: Dict[str, Any] + response_model: Optional[Type[BaseModel]], + definitions: Dict[str, Any], + response_content_types: Optional[List[str]] = None, + links: Optional[Dict[str, Any]] = None, + method: str = "post" ) -> Dict[str, Any]: """Build OpenAPI responses object.""" responses = {} - # Success response (200) - if response_model: - # Strip validation from response schema to save tokens - schema = strip_validation(pydantic_to_json_schema(response_model, mode="serialization")) - schema = extract_definitions(schema, definitions) - responses["200"] = { - "description": "Successful response", - "content": { - "application/json": { - "schema": schema - } - } - } + # Use a fresh list for response content types to avoid a shared mutable default. + if response_content_types is None: + response_content_types = ["application/json"] else: - responses["200"] = { - "description": "Successful response", - "content": { - "application/json": { - "schema": { - "type": "object", - "properties": { - "success": {"type": "boolean"}, - "message": {"type": "string"} - } - } - } - } - } + # Copy provided list to ensure each call gets its own list + response_content_types = list(response_content_types) - # Standard error responses - MINIMIZED context - # Annotate that these errors can occur, but provide no schema/content to save tokens. - # The LLM knows what "Bad Request" or "Not Found" means. - error_codes = { - "400": "Bad Request", - "401": "Unauthorized", - "403": "Forbidden", - "404": "Not Found", - "422": "Validation Error", - "500": "Internal Server Error" + # Success response (200) + effective_model = response_model or BaseResponse + schema = strip_validation(pydantic_to_json_schema(effective_model, mode="serialization")) + schema = extract_definitions(schema, definitions) + + content = {} + for ct in response_content_types: + if ct == "application/json": + content[ct] = {"schema": schema} + else: + # For non-JSON types like CSV, we don't necessarily use the JSON schema + content[ct] = {"schema": {"type": "string", "format": "binary"}} + + response_obj = { + "description": "Successful response", + "content": content + } + if links: + response_obj["links"] = links + responses["200"] = response_obj + + # Standard error responses + error_configs = { + "400": ("Invalid JSON", "Request body must be valid JSON"), + "401": ("Unauthorized", None), + "403": ("Forbidden", "ERROR: Not authorized"), + "404": ("API route not found", "The requested URL /example/path was not found on the server."), + "422": ("Validation Error", None), + "500": ("Internal Server Error", "Something went wrong on the server") } - for code, desc in error_codes.items(): + for code, (error_val, message_val) in error_configs.items(): + # Generate a fresh schema for each error to customize examples + error_schema_raw = strip_validation(pydantic_to_json_schema(ErrorResponse, mode="serialization")) + error_schema = extract_definitions(error_schema_raw, definitions) + + # Inject status-specific example + if "examples" in error_schema and len(error_schema["examples"]) > 0: + example = { + "success": False, + "error": error_val + } + if message_val: + example["message"] = message_val + + if code == "422": + example["error"] = "Validation Error: Input should be a valid string" + example["details"] = [ + { + "input": "invalid_value", + "loc": ["field_name"], + "msg": "Input should be a valid string", + "type": "string_type", + "url": "https://errors.pydantic.dev/2.12/v/string_type" + } + ] + + error_schema["examples"] = [example] + responses[code] = { - "description": desc - # No "content" schema provided + "description": error_val, + "content": { + "application/json": { + "schema": error_schema + } + } } return responses diff --git a/server/api_server/openapi/schemas.py b/server/api_server/openapi/schemas.py index 57db072b..c3fdc6c4 100644 --- a/server/api_server/openapi/schemas.py +++ b/server/api_server/openapi/schemas.py @@ -52,6 +52,16 @@ ALLOWED_LOG_FILES = Literal[ "app.php_errors.log", "execution_queue.log", "db_is_locked.log" ] +ALLOWED_SCAN_TYPES = Literal["ARPSCAN", "NMAPDEV", "NMAP", "INTRNT", "AVAHISCAN", "NBTSCAN"] + +ALLOWED_SESSION_CONNECTION_TYPES = Literal["Connected", "Reconnected", "New Device", "Down Reconnected"] +ALLOWED_SESSION_DISCONNECTION_TYPES = Literal["Disconnected", "Device Down", "Timeout"] + +ALLOWED_EVENT_TYPES = Literal[ + "Device Down", "Device Up", "New Device", "Connected", "Disconnected", + "IP Changed", "Down Reconnected" +] + def validate_mac(value: str) -> str: """Validate and normalize MAC address format.""" @@ -89,14 +99,42 @@ def validate_column_identifier(value: str) -> str: class BaseResponse(BaseModel): - """Standard API response wrapper.""" - model_config = ConfigDict(extra="allow") + """ + Standard API response wrapper. + Note: The API often returns 200 OK for most operations; clients MUST parse the 'success' + boolean field to determine if the operation was actually successful. + """ + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "success": True + }] + } + ) success: bool = Field(..., description="Whether the operation succeeded") message: Optional[str] = Field(None, description="Human-readable message") error: Optional[str] = Field(None, description="Error message if success=False") +class ErrorResponse(BaseResponse): + """Standard error response model with details.""" + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "success": False, + "error": "Error message" + }] + } + ) + + success: bool = Field(False, description="Always False for errors") + details: Optional[Any] = Field(None, description="Detailed error information (e.g., validation errors)") + code: Optional[str] = Field(None, description="Internal error code") + + class PaginatedResponse(BaseResponse): """Response with pagination metadata.""" total: int = Field(0, description="Total number of items") @@ -130,7 +168,19 @@ class DeviceSearchRequest(BaseModel): class DeviceInfo(BaseModel): """Detailed device information model (Raw record).""" - model_config = ConfigDict(extra="allow") + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "devMac": "00:11:22:33:44:55", + "devName": "My iPhone", + "devLastIP": "192.168.1.10", + "devVendor": "Apple", + "devStatus": "online", + "devFavorite": 0 + }] + } + ) devMac: str = Field(..., description="Device MAC address") devName: Optional[str] = Field(None, description="Device display name/alias") @@ -138,13 +188,27 @@ class DeviceInfo(BaseModel): devPrimaryIPv4: Optional[str] = Field(None, description="Primary IPv4 address") devPrimaryIPv6: Optional[str] = Field(None, description="Primary IPv6 address") devVlan: Optional[str] = Field(None, description="VLAN identifier") - devForceStatus: Optional[str] = Field(None, description="Force device status (online/offline/dont_force)") + devForceStatus: Optional[Literal["online", "offline", "dont_force"]] = Field( + "dont_force", + description="Force device status (online/offline/dont_force)" + ) devVendor: Optional[str] = Field(None, description="Hardware vendor from OUI lookup") devOwner: Optional[str] = Field(None, description="Device owner") devType: Optional[str] = Field(None, description="Device type classification") - devFavorite: Optional[int] = Field(0, description="Favorite flag (0 or 1)") - devPresentLastScan: Optional[int] = Field(None, description="Present in last scan (0 or 1)") - devStatus: Optional[str] = Field(None, description="Online/Offline status") + devFavorite: Optional[int] = Field( + 0, + description="Favorite flag (0=False, 1=True). Legacy boolean representation.", + json_schema_extra={"enum": [0, 1]} + ) + devPresentLastScan: Optional[int] = Field( + None, + description="Present in last scan (0 or 1)", + json_schema_extra={"enum": [0, 1]} + ) + devStatus: Optional[Literal["online", "offline"]] = Field( + None, + description="Online/Offline status" + ) devMacSource: Optional[str] = Field(None, description="Source of devMac (USER, LOCKED, or plugin prefix)") devNameSource: Optional[str] = Field(None, description="Source of devName") devFQDNSource: Optional[str] = Field(None, description="Source of devFQDN") @@ -270,7 +334,18 @@ class CopyDeviceRequest(BaseModel): class UpdateDeviceColumnRequest(BaseModel): """Request to update a specific device database column.""" columnName: ALLOWED_DEVICE_COLUMNS = Field(..., description="Database column name") - columnValue: Any = Field(..., description="New value for the column") + columnValue: Union[str, int, bool, None] = Field( + ..., + description="New value for the column", + json_schema_extra={ + "oneOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "boolean"}, + {"type": "null"} + ] + } + ) class LockDeviceFieldRequest(BaseModel): @@ -301,7 +376,13 @@ class DeviceUpdateRequest(BaseModel): devName: Optional[str] = Field(None, description="Device name") devOwner: Optional[str] = Field(None, description="Device owner") - devType: Optional[str] = Field(None, description="Device type") + devType: Optional[str] = Field( + None, + description="Device type", + json_schema_extra={ + "examples": ["Phone", "Laptop", "Desktop", "Router", "IoT", "Camera", "Server", "TV"] + } + ) devVendor: Optional[str] = Field(None, description="Device vendor") devGroup: Optional[str] = Field(None, description="Device group") devLocation: Optional[str] = Field(None, description="Device location") @@ -340,10 +421,9 @@ class DeleteDevicesRequest(BaseModel): class TriggerScanRequest(BaseModel): """Request to trigger a network scan.""" - type: str = Field( + type: ALLOWED_SCAN_TYPES = Field( "ARPSCAN", - description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)", - json_schema_extra={"examples": ["ARPSCAN", "NMAPDEV", "NMAP"]} + description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)" ) @@ -420,7 +500,11 @@ class WakeOnLanRequest(BaseModel): class WakeOnLanResponse(BaseResponse): """Response for Wake-on-LAN operation.""" - output: Optional[str] = Field(None, description="Command output") + output: Optional[str] = Field( + None, + description="Command output", + json_schema_extra={"examples": ["Sent magic packet to AA:BB:CC:DD:EE:FF"]} + ) class TracerouteRequest(BaseModel): @@ -507,7 +591,17 @@ class NetworkInterfacesResponse(BaseResponse): class EventInfo(BaseModel): """Event/alert information.""" - model_config = ConfigDict(extra="allow") + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "eveMAC": "00:11:22:33:44:55", + "eveIP": "192.168.1.10", + "eveDateTime": "2024-01-29 10:00:00", + "eveEventType": "Device Down" + }] + } + ) eveRowid: Optional[int] = Field(None, description="Event row ID") eveMAC: Optional[str] = Field(None, description="Device MAC address") @@ -547,9 +641,19 @@ class LastEventsResponse(BaseResponse): class CreateEventRequest(BaseModel): """Request to create a device event.""" ip: Optional[str] = Field("0.0.0.0", description="Device IP") - event_type: str = Field("Device Down", description="Event type") + event_type: str = Field( + "Device Down", + description="Event type", + json_schema_extra={ + "examples": ["Device Down", "Device Up", "New Device", "Connected", "Disconnected"] + } + ) additional_info: Optional[str] = Field("", description="Additional info") - pending_alert: int = Field(1, description="Pending alert flag") + pending_alert: int = Field( + 1, + description="Pending alert flag (0 or 1)", + json_schema_extra={"enum": [0, 1]} + ) event_time: Optional[str] = Field(None, description="Event timestamp (ISO)") @field_validator("ip", mode="before") @@ -564,11 +668,18 @@ class CreateEventRequest(BaseModel): # ============================================================================= # SESSIONS SCHEMAS # ============================================================================= - - -class SessionInfo(BaseModel): """Session information.""" - model_config = ConfigDict(extra="allow") + model_config = ConfigDict( + extra="allow", + json_schema_extra={ + "examples": [{ + "sesMac": "00:11:22:33:44:55", + "sesDateTimeConnection": "2024-01-29 08:00:00", + "sesDateTimeDisconnection": "2024-01-29 09:00:00", + "sesIPAddress": "192.168.1.10" + }] + } + ) sesRowid: Optional[int] = Field(None, description="Session row ID") sesMac: Optional[str] = Field(None, description="Device MAC address") @@ -583,8 +694,20 @@ class CreateSessionRequest(BaseModel): ip: str = Field(..., description="Device IP") start_time: str = Field(..., description="Start time") end_time: Optional[str] = Field(None, description="End time") - event_type_conn: str = Field("Connected", description="Connection event type") - event_type_disc: str = Field("Disconnected", description="Disconnection event type") + event_type_conn: str = Field( + "Connected", + description="Connection event type", + json_schema_extra={ + "examples": ["Connected", "Reconnected", "New Device", "Down Reconnected"] + } + ) + event_type_disc: str = Field( + "Disconnected", + description="Disconnection event type", + json_schema_extra={ + "examples": ["Disconnected", "Device Down", "Timeout"] + } + ) @field_validator("mac") @classmethod @@ -620,7 +743,11 @@ class InAppNotification(BaseModel): guid: Optional[str] = Field(None, description="Unique notification GUID") text: str = Field(..., description="Notification text content") level: NOTIFICATION_LEVELS = Field("info", description="Notification level") - read: Optional[int] = Field(0, description="Read status (0 or 1)") + read: Optional[int] = Field( + 0, + description="Read status (0 or 1)", + json_schema_extra={"enum": [0, 1]} + ) created_at: Optional[str] = Field(None, description="Creation timestamp") @@ -665,10 +792,12 @@ class DbQueryRequest(BaseModel): """ Request for raw database query. WARNING: This is a highly privileged operation. + Can be used to read settings by querying the 'Settings' table. """ rawSql: str = Field( ..., - description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)" + description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)", + json_schema_extra={"examples": ["U0VMRUNUICogRlJPTSBTZXR0aW5ncw=="]} ) # Legacy compatibility: removed strict safety check # TODO: SECURITY CRITICAL - Re-enable strict safety checks. @@ -690,9 +819,23 @@ class DbQueryRequest(BaseModel): class DbQueryUpdateRequest(BaseModel): - """Request for DB update query.""" + """ + Request for DB update query. + Can be used to update settings by targeting the 'Settings' table. + """ columnName: str = Field(..., description="Column to filter by") - id: List[Any] = Field(..., description="List of IDs to update") + id: List[Union[str, int]] = Field( + ..., + description="List of IDs to update (strings for MACs, integers for row IDs)", + json_schema_extra={ + "items": { + "oneOf": [ + {"type": "string", "description": "A string identifier (e.g., MAC address)"}, + {"type": "integer", "description": "A numeric row ID"} + ] + } + } + ) dbtable: ALLOWED_TABLES = Field(..., description="Table name") columns: List[str] = Field(..., description="Columns to update") values: List[Any] = Field(..., description="New values") @@ -715,9 +858,23 @@ class DbQueryUpdateRequest(BaseModel): class DbQueryDeleteRequest(BaseModel): - """Request for DB delete query.""" + """ + Request for DB delete query. + Can be used to delete settings by targeting the 'Settings' table. + """ columnName: str = Field(..., description="Column to filter by") - id: List[Any] = Field(..., description="List of IDs to delete") + id: List[Union[str, int]] = Field( + ..., + description="List of IDs to delete (strings for MACs, integers for row IDs)", + json_schema_extra={ + "items": { + "oneOf": [ + {"type": "string", "description": "A string identifier (e.g., MAC address)"}, + {"type": "integer", "description": "A numeric row ID"} + ] + } + } + ) dbtable: ALLOWED_TABLES = Field(..., description="Table name") @field_validator("columnName") @@ -772,3 +929,14 @@ class SettingValue(BaseModel): class GetSettingResponse(BaseResponse): """Response for getting a setting value.""" value: Any = Field(None, description="The setting value") + + +# ============================================================================= +# GRAPHQL SCHEMAS +# ============================================================================= + + +class GraphQLRequest(BaseModel): + """Request payload for GraphQL queries.""" + query: str = Field(..., description="GraphQL query string", json_schema_extra={"examples": ["{ devices { devMac devName } }"]}) + variables: Optional[Dict[str, Any]] = Field(None, description="Variables for the GraphQL query") diff --git a/server/api_server/openapi/spec_generator.py b/server/api_server/openapi/spec_generator.py index 12154624..9ecb461d 100644 --- a/server/api_server/openapi/spec_generator.py +++ b/server/api_server/openapi/spec_generator.py @@ -52,7 +52,7 @@ _rebuild_lock = threading.Lock() def generate_openapi_spec( title: str = "NetAlertX API", version: str = "2.0.0", - description: str = "NetAlertX Network Monitoring API - MCP Compatible", + description: str = "NetAlertX Network Monitoring API - Official Documentation - MCP Compatible", servers: Optional[List[Dict[str, str]]] = None, flask_app: Optional[Any] = None ) -> Dict[str, Any]: @@ -80,12 +80,21 @@ def generate_openapi_spec( "title": title, "version": version, "description": description, + "termsOfService": "https://github.com/netalertx/NetAlertX/blob/main/LICENSE.txt", "contact": { - "name": "NetAlertX", - "url": "https://github.com/jokob-sk/NetAlertX" + "name": "Open Source Project - NetAlertX - Github", + "url": "https://github.com/netalertx/NetAlertX" + }, + "license": { + "name": "Licensed under GPLv3", + "url": "https://www.gnu.org/licenses/gpl-3.0.html" } }, - "servers": servers or [{"url": "/", "description": "Local server"}], + "externalDocs": { + "description": "NetAlertX Official Documentation", + "url": "https://docs.netalertx.com/" + }, + "servers": servers or [{"url": "/", "description": "This NetAlertX instance"}], "security": [ {"BearerAuth": []} ], @@ -152,7 +161,11 @@ def generate_openapi_spec( # Add responses operation["responses"] = build_responses( - entry.get("response_model"), definitions + entry.get("response_model"), + definitions, + response_content_types=entry.get("response_content_types", ["application/json"]), + links=entry.get("links"), + method=method ) spec["paths"][path][method] = operation diff --git a/server/api_server/openapi/validation.py b/server/api_server/openapi/validation.py index 33f1adcc..97617dd1 100644 --- a/server/api_server/openapi/validation.py +++ b/server/api_server/openapi/validation.py @@ -44,7 +44,11 @@ def validate_request( query_params: Optional[list[dict]] = None, validation_error_code: int = 422, auth_callable: Optional[Callable[[], bool]] = None, - allow_multipart_payload: bool = False + allow_multipart_payload: bool = False, + exclude_from_spec: bool = False, + response_content_types: Optional[list[str]] = None, + links: Optional[dict] = None, + error_responses: Optional[dict] = None ): """ Decorator to register a Flask route with the OpenAPI registry and validate incoming requests. @@ -56,6 +60,10 @@ def validate_request( - Supports auth_callable to check permissions before validation. - Returns 422 (default) if validation fails. - allow_multipart_payload: If True, allows multipart/form-data and attempts validation from form fields. + - exclude_from_spec: If True, this endpoint will be omitted from the generated OpenAPI specification. + - response_content_types: List of supported response media types (e.g. ["application/json", "text/csv"]). + - links: Dictionary of OpenAPI links to include in the response definition. + - error_responses: Dictionary of custom error examples (e.g. {"404": "Device not found"}). """ def decorator(f: Callable) -> Callable: @@ -73,7 +81,11 @@ def validate_request( "tags": tags, "path_params": path_params, "query_params": query_params, - "allow_multipart_payload": allow_multipart_payload + "allow_multipart_payload": allow_multipart_payload, + "exclude_from_spec": exclude_from_spec, + "response_content_types": response_content_types, + "links": links, + "error_responses": error_responses } @wraps(f) @@ -150,6 +162,7 @@ def validate_request( data = request.args.to_dict() validated_instance = request_model(**data) except ValidationError as e: + # Use configured validation error code (default 422) return _handle_validation_error(e, operation_id, validation_error_code) except (TypeError, ValueError, KeyError) as e: mylog("verbose", [f"[Validation] Query param validation failed for {operation_id}: {e}"]) diff --git a/server/messaging/in_app.py b/server/messaging/in_app.py index fc47afdf..000e849f 100755 --- a/server/messaging/in_app.py +++ b/server/messaging/in_app.py @@ -50,28 +50,33 @@ def write_notification(content, level="alert", timestamp=None): } # If file exists, load existing data, otherwise initialize as empty list - if os.path.exists(NOTIFICATION_API_FILE): - with open(NOTIFICATION_API_FILE, "r") as file: - # Check if the file object is of type _io.TextIOWrapper - if isinstance(file, _io.TextIOWrapper): - file_contents = file.read() # Read file contents - if file_contents == "": - file_contents = "[]" # If file is empty, initialize as empty list - - # mylog('debug', ['[Notification] User Notifications file: ', file_contents]) - notifications = json.loads(file_contents) # Parse JSON data - else: - mylog("none", "[Notification] File is not of type _io.TextIOWrapper") - notifications = [] - else: + try: + if os.path.exists(NOTIFICATION_API_FILE): + with open(NOTIFICATION_API_FILE, "r") as file: + file_contents = file.read().strip() + if file_contents: + notifications = json.loads(file_contents) + if not isinstance(notifications, list): + mylog("error", "[Notification] Invalid format: not a list, resetting") + notifications = [] + else: + notifications = [] + else: + notifications = [] + except Exception as e: + mylog("error", [f"[Notification] Error reading notifications file: {e}"]) notifications = [] # Append new notification notifications.append(notification) # Write updated data back to file - with open(NOTIFICATION_API_FILE, "w") as file: - json.dump(notifications, file, indent=4) + try: + with open(NOTIFICATION_API_FILE, "w") as file: + json.dump(notifications, file, indent=4) + except Exception as e: + mylog("error", [f"[Notification] Error writing to notifications file: {e}"]) + # Don't re-raise, just log. This prevents the API from crashing 500. # Broadcast unread count update try: diff --git a/test/api_endpoints/test_mcp_openapi_spec.py b/test/api_endpoints/test_mcp_openapi_spec.py index f92b1f82..bcc7f81a 100644 --- a/test/api_endpoints/test_mcp_openapi_spec.py +++ b/test/api_endpoints/test_mcp_openapi_spec.py @@ -197,7 +197,7 @@ class TestOpenAPISpecGenerator: f"Path param '{param_name}' not defined: {method.upper()} {path}" def test_standard_error_responses(self): - """Operations should have minimal standard error responses (400, 403, 404, etc) without schema bloat.""" + """Operations should have standard error responses (400, 403, 404, etc).""" spec = generate_openapi_spec() expected_minimal_codes = ["400", "401", "403", "404", "500", "422"] @@ -207,9 +207,9 @@ class TestOpenAPISpecGenerator: continue responses = details.get("responses", {}) for code in expected_minimal_codes: - assert code in responses, f"Missing minimal {code} response in: {method.upper()} {path}." - # Verify no "content" or schema is present (minimalism) - assert "content" not in responses[code], f"Response {code} in {method.upper()} {path} should not have content/schema." + assert code in responses, f"Missing {code} response in: {method.upper()} {path}." + # Content should now be present (BaseResponse/Error schema) + assert "content" in responses[code], f"Response {code} in {method.upper()} {path} should have content/schema." class TestMCPToolMapping: @@ -239,9 +239,9 @@ class TestMCPToolMapping: spec = generate_openapi_spec() tools = map_openapi_to_mcp_tools(spec) - search_tool = next((t for t in tools if t["name"] == "search_devices"), None) + search_tool = next((t for t in tools if t["name"] == "search_devices_api"), None) assert search_tool is not None - assert "query" in search_tool["inputSchema"].get("required", []) + assert "query" in search_tool["inputSchema"]["required"] def test_tool_descriptions_present(self): """All tools should have non-empty descriptions.""" diff --git a/test/api_endpoints/test_schema_converter.py b/test/api_endpoints/test_schema_converter.py new file mode 100644 index 00000000..e69de29b