Merge pull request #1429 from adamoutler/mcp-swagger-rebase

feat(api): MCP, OpenAPI & Dynamic Introspection
This commit is contained in:
Jokob @NetAlertX
2026-01-19 13:11:04 +11:00
committed by GitHub
48 changed files with 5495 additions and 1080 deletions

View File

@@ -38,6 +38,19 @@ All application settings can also be initialized via the `APP_CONF_OVERRIDE` doc
There are several ways to check if the GraphQL server is running.
## Flask debug mode (environment)
You can control whether the Flask development debugger is enabled by setting the environment variable `FLASK_DEBUG` (default: `False`). Enabling debug mode will turn on the interactive debugger which may expose a remote code execution (RCE) vector if the server is reachable; **only enable this for local development** and never in production. Valid truthy values are: `1`, `true`, `yes`, `on` (case-insensitive).
In the running container you can set this variable via Docker Compose or your environment, for example:
```yaml
environment:
- FLASK_DEBUG=1
```
When enabled, the GraphQL server startup logs will indicate the debug setting.
### Init Check
You can navigate to System Info -> Init Check to see if `isGraphQLServerRunning` is ticked:

View File

@@ -89,14 +89,22 @@ def is_typical_router_ip(ip_address):
# -------------------------------------------------------------------
# Check if a valid MAC address
def is_mac(input):
input_str = str(input).lower() # Convert to string and lowercase so non-string values won't raise errors
input_str = str(input).lower().strip() # Convert to string and lowercase so non-string values won't raise errors
isMac = bool(re.match("[0-9a-f]{2}([-:]?)[0-9a-f]{2}(\\1[0-9a-f]{2}){4}$", input_str))
# Full MAC (6 octets) e.g. AA:BB:CC:DD:EE:FF
full_mac_re = re.compile(r"^[0-9a-f]{2}([-:]?)[0-9a-f]{2}(\1[0-9a-f]{2}){4}$")
if not isMac: # If it's not a MAC address, log the input
mylog('verbose', [f'[is_mac] not a MAC: {input_str}'])
# Wildcard prefix format: exactly 3 octets followed by a trailing '*' component
# Examples: AA:BB:CC:*
wildcard_re = re.compile(r"^[0-9a-f]{2}[-:]?[0-9a-f]{2}[-:]?[0-9a-f]{2}[-:]?\*$")
return isMac
if full_mac_re.match(input_str) or wildcard_re.match(input_str):
return True
# If it's not a MAC address or allowed wildcard pattern, log the input
mylog('verbose', [f'[is_mac] not a MAC: {input_str}'])
return False
# -------------------------------------------------------------------
@@ -168,20 +176,36 @@ def decode_settings_base64(encoded_str, convert_types=True):
# -------------------------------------------------------------------
def normalize_mac(mac):
# Split the MAC address by colon (:) or hyphen (-) and convert each part to uppercase
parts = mac.upper().split(':')
"""
Normalize a MAC address to the standard format with colon separators.
For example, "aa-bb-cc-dd-ee-ff" will be normalized to "AA:BB:CC:DD:EE:FF".
Wildcard MAC addresses like "AA:BB:CC:*" will be normalized to "AA:BB:CC:*".
# If the MAC address is split by hyphen instead of colon
if len(parts) == 1:
parts = mac.upper().split('-')
:param mac: The MAC address to normalize.
:return: The normalized MAC address.
"""
s = str(mac).upper().strip()
# Normalize each part to have exactly two hexadecimal digits
normalized_parts = [part.zfill(2) for part in parts]
# Determine separator if present, prefer colon, then hyphen
if ':' in s:
parts = s.split(':')
elif '-' in s:
parts = s.split('-')
else:
# No explicit separator; attempt to split every two chars
parts = [s[i:i + 2] for i in range(0, len(s), 2)]
# Join the parts with colon (:)
normalized_mac = ':'.join(normalized_parts)
normalized_parts = []
for part in parts:
part = part.strip()
if part == '*':
normalized_parts.append('*')
else:
# Ensure two hex digits (zfill is fine for alphanumeric input)
normalized_parts.append(part.zfill(2))
return normalized_mac
# Use colon as canonical separator
return ':'.join(normalized_parts)
# -------------------------------------------------------------------

View File

@@ -32,3 +32,4 @@ httplib2
gunicorn
git+https://github.com/foreign-sub/aiofreepybox.git
mcp
pydantic>=2.0,<3.0

View File

@@ -210,7 +210,7 @@ def build_row(
def generate_rows(args: argparse.Namespace, header: list[str]) -> list[dict[str, str]]:
now = dt.datetime.utcnow()
now = dt.datetime.now(dt.timezone.utc)
macs: set[str] = set()
ip_pool = prepare_ip_pool(args.network)

View File

View File

File diff suppressed because it is too large Load Diff

View File

@@ -46,46 +46,46 @@ class PageQueryOptionsInput(InputObjectType):
# Device ObjectType
class Device(ObjectType):
rowid = Int()
devMac = String()
devName = String()
devOwner = String()
devType = String()
devVendor = String()
devFavorite = Int()
devGroup = String()
devComments = String()
devFirstConnection = String()
devLastConnection = String()
devLastIP = String()
devStaticIP = Int()
devScan = Int()
devLogEvents = Int()
devAlertEvents = Int()
devAlertDown = Int()
devSkipRepeated = Int()
devLastNotification = String()
devPresentLastScan = Int()
devIsNew = Int()
devLocation = String()
devIsArchived = Int()
devParentMAC = String()
devParentPort = String()
devIcon = String()
devGUID = String()
devSite = String()
devSSID = String()
devSyncHubNode = String()
devSourcePlugin = String()
devCustomProps = String()
devStatus = String()
devIsRandomMac = Int()
devParentChildrenCount = Int()
devIpLong = Int()
devFilterStatus = String()
devFQDN = String()
devParentRelType = String()
devReqNicsOnline = Int()
rowid = Int(description="Database row ID")
devMac = String(description="Device MAC address (e.g., 00:11:22:33:44:55)")
devName = String(description="Device display name/alias")
devOwner = String(description="Device owner")
devType = String(description="Device type classification")
devVendor = String(description="Hardware vendor from OUI lookup")
devFavorite = Int(description="Favorite flag (0 or 1)")
devGroup = String(description="Device group")
devComments = String(description="User comments")
devFirstConnection = String(description="Timestamp of first discovery")
devLastConnection = String(description="Timestamp of last connection")
devLastIP = String(description="Last known IP address")
devStaticIP = Int(description="Static IP flag (0 or 1)")
devScan = Int(description="Scan flag (0 or 1)")
devLogEvents = Int(description="Log events flag (0 or 1)")
devAlertEvents = Int(description="Alert events flag (0 or 1)")
devAlertDown = Int(description="Alert on down flag (0 or 1)")
devSkipRepeated = Int(description="Skip repeated alerts flag (0 or 1)")
devLastNotification = String(description="Timestamp of last notification")
devPresentLastScan = Int(description="Present in last scan flag (0 or 1)")
devIsNew = Int(description="Is new device flag (0 or 1)")
devLocation = String(description="Device location")
devIsArchived = Int(description="Is archived flag (0 or 1)")
devParentMAC = String(description="Parent device MAC address")
devParentPort = String(description="Parent device port")
devIcon = String(description="Base64-encoded HTML/SVG markup used to render the device icon")
devGUID = String(description="Unique device GUID")
devSite = String(description="Site name")
devSSID = String(description="SSID connected to")
devSyncHubNode = String(description="Sync hub node name")
devSourcePlugin = String(description="Plugin that discovered the device")
devCustomProps = String(description="Base64-encoded custom properties in JSON format")
devStatus = String(description="Online/Offline status")
devIsRandomMac = Int(description="Calculated: Is MAC address randomized?")
devParentChildrenCount = Int(description="Calculated: Number of children attached to this parent")
devIpLong = Int(description="Calculated: IP address in long format")
devFilterStatus = String(description="Calculated: Device status for UI filtering")
devFQDN = String(description="Fully Qualified Domain Name")
devParentRelType = String(description="Relationship type to parent")
devReqNicsOnline = Int(description="Required NICs online flag")
class DeviceResult(ObjectType):
@@ -98,20 +98,20 @@ class DeviceResult(ObjectType):
# Setting ObjectType
class Setting(ObjectType):
setKey = String()
setName = String()
setDescription = String()
setType = String()
setOptions = String()
setGroup = String()
setValue = String()
setEvents = String()
setOverriddenByEnv = Boolean()
setKey = String(description="Unique configuration key")
setName = String(description="Human-readable setting name")
setDescription = String(description="Detailed description of the setting")
setType = String(description="Config-driven type definition used to determine value type and UI rendering")
setOptions = String(description="JSON string of available options")
setGroup = String(description="UI group for categorization")
setValue = String(description="Current value")
setEvents = String(description="JSON string of events")
setOverriddenByEnv = Boolean(description="Whether the value is currently overridden by an environment variable")
class SettingResult(ObjectType):
settings = List(Setting)
count = Int()
settings = List(Setting, description="List of setting objects")
count = Int(description="Total count of settings")
# --- LANGSTRINGS ---
@@ -123,48 +123,48 @@ _langstrings_cache_mtime = {} # tracks last modified times
# LangString ObjectType
class LangString(ObjectType):
langCode = String()
langStringKey = String()
langStringText = String()
langCode = String(description="Language code (e.g., en_us, de_de)")
langStringKey = String(description="Unique translation key")
langStringText = String(description="Translated text content")
class LangStringResult(ObjectType):
langStrings = List(LangString)
count = Int()
langStrings = List(LangString, description="List of language string objects")
count = Int(description="Total count of strings")
# --- APP EVENTS ---
class AppEvent(ObjectType):
Index = Int()
GUID = String()
AppEventProcessed = Int()
DateTimeCreated = String()
Index = Int(description="Internal index")
GUID = String(description="Unique event GUID")
AppEventProcessed = Int(description="Processing status (0 or 1)")
DateTimeCreated = String(description="Event creation timestamp")
ObjectType = String()
ObjectGUID = String()
ObjectPlugin = String()
ObjectPrimaryID = String()
ObjectSecondaryID = String()
ObjectForeignKey = String()
ObjectIndex = Int()
ObjectType = String(description="Type of the related object (Device, Setting, etc.)")
ObjectGUID = String(description="GUID of the related object")
ObjectPlugin = String(description="Plugin associated with the object")
ObjectPrimaryID = String(description="Primary identifier of the object")
ObjectSecondaryID = String(description="Secondary identifier of the object")
ObjectForeignKey = String(description="Foreign key reference")
ObjectIndex = Int(description="Object index")
ObjectIsNew = Int()
ObjectIsArchived = Int()
ObjectStatusColumn = String()
ObjectStatus = String()
ObjectIsNew = Int(description="Is the object new? (0 or 1)")
ObjectIsArchived = Int(description="Is the object archived? (0 or 1)")
ObjectStatusColumn = String(description="Column used for status")
ObjectStatus = String(description="Object status value")
AppEventType = String()
AppEventType = String(description="Type of application event")
Helper1 = String()
Helper2 = String()
Helper3 = String()
Extra = String()
Helper1 = String(description="Generic helper field 1")
Helper2 = String(description="Generic helper field 2")
Helper3 = String(description="Generic helper field 3")
Extra = String(description="Additional JSON data")
class AppEventResult(ObjectType):
appEvents = List(AppEvent)
count = Int()
appEvents = List(AppEvent, description="List of application events")
count = Int(description="Total count of events")
# ----------------------------------------------------------------------------------------------

View File

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,106 @@
from __future__ import annotations
import re
from typing import Any
import graphene
from .registry import register_tool, _operation_ids
def introspect_graphql_schema(schema: graphene.Schema):
"""
Introspect the GraphQL schema and register endpoints in the OpenAPI registry.
This bridges the 'living code' (GraphQL) to the OpenAPI spec.
"""
# Graphene schema introspection
graphql_schema = schema.graphql_schema
query_type = graphql_schema.query_type
if not query_type:
return
# We register the main /graphql endpoint once
register_tool(
path="/graphql",
method="POST",
operation_id="graphql_query",
summary="GraphQL Endpoint",
description="Execute arbitrary GraphQL queries against the system schema.",
tags=["graphql"]
)
def _flask_to_openapi_path(flask_path: str) -> str:
"""Convert Flask path syntax to OpenAPI path syntax."""
# Handles <converter:variable> -> {variable} and <variable> -> {variable}
return re.sub(r'<(?:\w+:)?(\w+)>', r'{\1}', flask_path)
def introspect_flask_app(app: Any):
"""
Introspect the Flask application to find routes decorated with @validate_request
and register them in the OpenAPI registry.
"""
registered_ops = set()
for rule in app.url_map.iter_rules():
view_func = app.view_functions.get(rule.endpoint)
if not view_func:
continue
# Check for our decorator's metadata
metadata = getattr(view_func, "_openapi_metadata", None)
if not metadata:
# Fallback for wrapped functions
if hasattr(view_func, "__wrapped__"):
metadata = getattr(view_func.__wrapped__, "_openapi_metadata", None)
if metadata:
op_id = metadata["operation_id"]
# Register the tool with real path and method from Flask
for method in rule.methods:
if method in ("OPTIONS", "HEAD"):
continue
# Create a unique key for this path/method/op combination if needed,
# but operationId must be unique globally.
# If the same function is mounted on multiple paths, we append a suffix
path = _flask_to_openapi_path(str(rule))
# Check if this operation (path + method) is already registered
op_key = f"{method}:{path}"
if op_key in registered_ops:
continue
# Determine tags - create a copy to avoid mutating shared metadata
tags = list(metadata.get("tags") or ["rest"])
if path.startswith("/mcp/"):
# Move specific tags to secondary position or just add MCP
if "rest" in tags:
tags.remove("rest")
if "mcp" not in tags:
tags.append("mcp")
# Ensure unique operationId
original_op_id = op_id
unique_op_id = op_id
count = 1
while unique_op_id in _operation_ids:
unique_op_id = f"{op_id}_{count}"
count += 1
register_tool(
path=path,
method=method,
operation_id=unique_op_id,
original_operation_id=original_op_id if unique_op_id != original_op_id else None,
summary=metadata["summary"],
description=metadata["description"],
request_model=metadata.get("request_model"),
response_model=metadata.get("response_model"),
path_params=metadata.get("path_params"),
query_params=metadata.get("query_params"),
tags=tags,
allow_multipart_payload=metadata.get("allow_multipart_payload", False)
)
registered_ops.add(op_key)

View File

@@ -0,0 +1,158 @@
from __future__ import annotations
import threading
from copy import deepcopy
from typing import List, Dict, Any, Literal, Optional, Type, Set
from pydantic import BaseModel
# Thread-safe registry
_registry: List[Dict[str, Any]] = []
_registry_lock = threading.Lock()
_operation_ids: Set[str] = set()
_disabled_tools: Set[str] = set()
class DuplicateOperationIdError(Exception):
"""Raised when an operationId is registered more than once."""
pass
def set_tool_disabled(operation_id: str, disabled: bool = True) -> bool:
"""
Enable or disable a tool by operation_id.
Args:
operation_id: The unique operation_id of the tool
disabled: True to disable, False to enable
Returns:
bool: True if operation_id exists, False otherwise
"""
with _registry_lock:
if operation_id not in _operation_ids:
return False
if disabled:
_disabled_tools.add(operation_id)
else:
_disabled_tools.discard(operation_id)
return True
def is_tool_disabled(operation_id: str) -> bool:
"""
Check if a tool is disabled.
Checks both the unique operation_id and the original_operation_id.
"""
with _registry_lock:
if operation_id in _disabled_tools:
return True
# Also check if the original base ID is disabled
for entry in _registry:
if entry["operation_id"] == operation_id:
orig_id = entry.get("original_operation_id")
if orig_id and orig_id in _disabled_tools:
return True
return False
def get_disabled_tools() -> List[str]:
"""Get list of all disabled operation_ids."""
with _registry_lock:
return list(_disabled_tools)
def get_tools_status() -> List[Dict[str, Any]]:
"""
Get a list of all registered tools and their disabled status.
Useful for backend-to-frontend communication.
"""
tools = []
with _registry_lock:
disabled_snapshot = _disabled_tools.copy()
for entry in _registry:
op_id = entry["operation_id"]
orig_id = entry.get("original_operation_id")
is_disabled = bool(op_id in disabled_snapshot or (orig_id and orig_id in disabled_snapshot))
tools.append({
"operation_id": op_id,
"summary": entry["summary"],
"disabled": is_disabled
})
return tools
def register_tool(
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"],
operation_id: str,
summary: str,
description: str,
request_model: Optional[Type[BaseModel]] = None,
response_model: Optional[Type[BaseModel]] = None,
path_params: Optional[List[Dict[str, Any]]] = None,
query_params: Optional[List[Dict[str, Any]]] = None,
tags: Optional[List[str]] = None,
deprecated: bool = False,
original_operation_id: Optional[str] = None,
allow_multipart_payload: bool = False
) -> None:
"""
Register an API endpoint for OpenAPI spec generation.
Args:
path: URL path (e.g., "/devices/{mac}")
method: HTTP method
operation_id: Unique identifier for this operation (MUST be unique across entire spec)
summary: Short summary for the operation
description: Detailed description
request_model: Pydantic model for request body (POST/PUT/PATCH)
response_model: Pydantic model for success response
path_params: List of path parameter definitions
query_params: List of query parameter definitions
tags: OpenAPI tags for grouping
deprecated: Whether this endpoint is deprecated
original_operation_id: The base ID before suffixing (for disablement mapping)
allow_multipart_payload: Whether to allow multipart/form-data payloads
Raises:
DuplicateOperationIdError: If operation_id already exists in registry
"""
with _registry_lock:
if operation_id in _operation_ids:
raise DuplicateOperationIdError(
f"operationId '{operation_id}' is already registered. "
"Each operationId must be unique across the entire API."
)
_operation_ids.add(operation_id)
_registry.append({
"path": path,
"method": method.upper(),
"operation_id": operation_id,
"original_operation_id": original_operation_id,
"summary": summary,
"description": description,
"request_model": request_model,
"response_model": response_model,
"path_params": path_params or [],
"query_params": query_params or [],
"tags": tags or ["default"],
"deprecated": deprecated,
"allow_multipart_payload": allow_multipart_payload
})
def clear_registry() -> None:
"""Clear all registered endpoints (useful for testing)."""
with _registry_lock:
_registry.clear()
_operation_ids.clear()
_disabled_tools.clear()
def get_registry() -> List[Dict[str, Any]]:
"""Get a deep copy of the current registry to prevent external mutation."""
with _registry_lock:
return deepcopy(_registry)

View File

@@ -0,0 +1,217 @@
from __future__ import annotations
from typing import Dict, Any, Optional, Type, List
from pydantic import BaseModel
def pydantic_to_json_schema(model: Type[BaseModel], mode: str = "validation") -> Dict[str, Any]:
"""
Convert a Pydantic model to JSON Schema (OpenAPI 3.1 compatible).
Uses Pydantic's built-in schema generation which produces
JSON Schema Draft 2020-12 compatible output.
Args:
model: Pydantic BaseModel class
mode: Schema mode - "validation" (for inputs) or "serialization" (for outputs)
Returns:
JSON Schema dictionary
"""
# Pydantic v2 uses model_json_schema()
schema = model.model_json_schema(mode=mode)
# Remove $defs if empty (cleaner output)
if "$defs" in schema and not schema["$defs"]:
del schema["$defs"]
return schema
def build_parameters(entry: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Build OpenAPI parameters array from path and query params."""
parameters = []
# Path parameters
for param in entry.get("path_params", []):
parameters.append({
"name": param["name"],
"in": "path",
"required": True,
"description": param.get("description", ""),
"schema": param.get("schema", {"type": "string"})
})
# Query parameters
for param in entry.get("query_params", []):
parameters.append({
"name": param["name"],
"in": "query",
"required": param.get("required", False),
"description": param.get("description", ""),
"schema": param.get("schema", {"type": "string"})
})
return parameters
def extract_definitions(schema: Dict[str, Any], definitions: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively extract $defs from a schema and move them to the definitions dict.
Also rewrite $ref to point to #/components/schemas/.
"""
if not isinstance(schema, dict):
return schema
# Extract definitions
if "$defs" in schema:
for name, definition in schema["$defs"].items():
# Recursively process the definition itself before adding it
definitions[name] = extract_definitions(definition, definitions)
del schema["$defs"]
# Rewrite references
if "$ref" in schema and schema["$ref"].startswith("#/$defs/"):
ref_name = schema["$ref"].split("/")[-1]
schema["$ref"] = f"#/components/schemas/{ref_name}"
# Recursively process properties
for key, value in schema.items():
if isinstance(value, dict):
schema[key] = extract_definitions(value, definitions)
elif isinstance(value, list):
schema[key] = [extract_definitions(item, definitions) for item in value]
return schema
def build_request_body(
model: Optional[Type[BaseModel]],
definitions: Dict[str, Any],
allow_multipart_payload: bool = False
) -> Optional[Dict[str, Any]]:
"""Build OpenAPI requestBody from Pydantic model."""
if model is None:
return None
schema = pydantic_to_json_schema(model)
schema = extract_definitions(schema, definitions)
content = {
"application/json": {
"schema": schema
}
}
if allow_multipart_payload:
content["multipart/form-data"] = {
"schema": schema
}
return {
"required": True,
"content": content
}
def strip_validation(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Recursively remove validation constraints from a JSON schema.
Keeps structure and descriptions, but removes pattern, minLength, etc.
This saves context tokens for LLMs which don't validate server output.
"""
if not isinstance(schema, dict):
return schema
# Keys to remove
validation_keys = [
"pattern", "minLength", "maxLength", "minimum", "maximum",
"exclusiveMinimum", "exclusiveMaximum", "multipleOf", "minItems",
"maxItems", "uniqueItems", "minProperties", "maxProperties"
]
clean_schema = {k: v for k, v in schema.items() if k not in validation_keys}
# Recursively clean sub-schemas
if "properties" in clean_schema:
clean_schema["properties"] = {
k: strip_validation(v) for k, v in clean_schema["properties"].items()
}
if "items" in clean_schema:
clean_schema["items"] = strip_validation(clean_schema["items"])
if "allOf" in clean_schema:
clean_schema["allOf"] = [strip_validation(x) for x in clean_schema["allOf"]]
if "anyOf" in clean_schema:
clean_schema["anyOf"] = [strip_validation(x) for x in clean_schema["anyOf"]]
if "oneOf" in clean_schema:
clean_schema["oneOf"] = [strip_validation(x) for x in clean_schema["oneOf"]]
if "$defs" in clean_schema:
clean_schema["$defs"] = {
k: strip_validation(v) for k, v in clean_schema["$defs"].items()
}
if "additionalProperties" in clean_schema and isinstance(clean_schema["additionalProperties"], dict):
clean_schema["additionalProperties"] = strip_validation(clean_schema["additionalProperties"])
return clean_schema
def build_responses(
response_model: Optional[Type[BaseModel]], definitions: Dict[str, Any]
) -> Dict[str, Any]:
"""Build OpenAPI responses object."""
responses = {}
# Success response (200)
if response_model:
# Strip validation from response schema to save tokens
schema = strip_validation(pydantic_to_json_schema(response_model, mode="serialization"))
schema = extract_definitions(schema, definitions)
responses["200"] = {
"description": "Successful response",
"content": {
"application/json": {
"schema": schema
}
}
}
else:
responses["200"] = {
"description": "Successful response",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"success": {"type": "boolean"},
"message": {"type": "string"}
}
}
}
}
}
# Standard error responses - MINIMIZED context
# Annotate that these errors can occur, but provide no schema/content to save tokens.
# The LLM knows what "Bad Request" or "Not Found" means.
error_codes = {
"400": "Bad Request",
"401": "Unauthorized",
"403": "Forbidden",
"404": "Not Found",
"422": "Validation Error",
"500": "Internal Server Error"
}
for code, desc in error_codes.items():
responses[code] = {
"description": desc
# No "content" schema provided
}
return responses

View File

@@ -0,0 +1,738 @@
#!/usr/bin/env python
"""
NetAlertX API Schema Definitions (Pydantic v2)
This module defines strict Pydantic models for all API request and response payloads.
These schemas serve as the single source of truth for:
1. Runtime validation of incoming requests
2. OpenAPI specification generation
3. MCP tool input schema derivation
Philosophy: "Code First, Spec Second" — these models ARE the contract.
"""
from __future__ import annotations
import re
import ipaddress
from typing import Optional, List, Literal, Any, Dict
from pydantic import BaseModel, Field, field_validator, model_validator, ConfigDict, RootModel
# Internal helper imports
from helper import sanitize_string
from plugin_helper import normalize_mac, is_mac
# =============================================================================
# COMMON PATTERNS & VALIDATORS
# =============================================================================
MAC_PATTERN = r"^([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})$"
IP_PATTERN = r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
COLUMN_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9_]+$")
# Security whitelists & Literals for documentation
ALLOWED_DEVICE_COLUMNS = Literal[
"devName", "devOwner", "devType", "devVendor",
"devGroup", "devLocation", "devComments", "devFavorite",
"devParentMAC"
]
ALLOWED_NMAP_MODES = Literal[
"quick", "intense", "ping", "comprehensive", "fast", "normal", "detail", "skipdiscovery",
"-sS", "-sT", "-sU", "-sV", "-O"
]
NOTIFICATION_LEVELS = Literal["info", "warning", "error", "alert"]
ALLOWED_TABLES = Literal["Devices", "Events", "Sessions", "Settings", "CurrentScan", "Online_History", "Plugins_Objects"]
ALLOWED_LOG_FILES = Literal[
"app.log", "app_front.log", "IP_changes.log", "stdout.log", "stderr.log",
"app.php_errors.log", "execution_queue.log", "db_is_locked.log"
]
def validate_mac(value: str) -> str:
"""Validate and normalize MAC address format."""
# Allow "Internet" as a special case for the gateway/WAN device
if value.lower() == "internet":
return "Internet"
if not is_mac(value):
raise ValueError(f"Invalid MAC address format: {value}")
return normalize_mac(value)
def validate_ip(value: str) -> str:
"""Validate IP address format (IPv4 or IPv6) using stdlib ipaddress.
Returns the canonical string form of the IP address.
"""
try:
return str(ipaddress.ip_address(value))
except ValueError as err:
raise ValueError(f"Invalid IP address: {value}") from err
def validate_column_identifier(value: str) -> str:
"""Validate a column identifier to prevent SQL injection."""
if not COLUMN_NAME_PATTERN.match(value):
raise ValueError("Invalid column name format")
return value
# =============================================================================
# BASE RESPONSE MODELS
# =============================================================================
class BaseResponse(BaseModel):
"""Standard API response wrapper."""
model_config = ConfigDict(extra="allow")
success: bool = Field(..., description="Whether the operation succeeded")
message: Optional[str] = Field(None, description="Human-readable message")
error: Optional[str] = Field(None, description="Error message if success=False")
class PaginatedResponse(BaseResponse):
"""Response with pagination metadata."""
total: int = Field(0, description="Total number of items")
page: int = Field(1, ge=1, description="Current page number")
per_page: int = Field(50, ge=1, le=500, description="Items per page")
# =============================================================================
# DEVICE SCHEMAS
# =============================================================================
class DeviceSearchRequest(BaseModel):
"""Request payload for searching devices."""
model_config = ConfigDict(str_strip_whitespace=True)
query: str = Field(
...,
min_length=1,
max_length=256,
description="Search term: IP address, MAC address, device name, or vendor",
json_schema_extra={"examples": ["192.168.1.1", "Apple", "00:11:22:33:44:55"]}
)
limit: int = Field(
50,
ge=1,
le=500,
description="Maximum number of results to return"
)
class DeviceInfo(BaseModel):
"""Detailed device information model (Raw record)."""
model_config = ConfigDict(extra="allow")
devMac: str = Field(..., description="Device MAC address")
devName: Optional[str] = Field(None, description="Device display name/alias")
devLastIP: Optional[str] = Field(None, description="Last known IP address")
devVendor: Optional[str] = Field(None, description="Hardware vendor from OUI lookup")
devOwner: Optional[str] = Field(None, description="Device owner")
devType: Optional[str] = Field(None, description="Device type classification")
devFavorite: Optional[int] = Field(0, description="Favorite flag (0 or 1)")
devPresentLastScan: Optional[int] = Field(None, description="Present in last scan (0 or 1)")
devStatus: Optional[str] = Field(None, description="Online/Offline status")
class DeviceSearchResponse(BaseResponse):
"""Response payload for device search."""
devices: List[DeviceInfo] = Field(default_factory=list, description="List of matching devices")
class DeviceListRequest(BaseModel):
"""Request for listing devices by status."""
status: Optional[Literal[
"connected", "down", "favorites", "new", "archived", "all", "my",
"offline"
]] = Field(
None,
description="Filter devices by status (connected, down, favorites, new, archived, all, my, offline)"
)
class DeviceListResponse(RootModel):
"""Response with list of devices."""
root: List[DeviceInfo] = Field(default_factory=list, description="List of devices")
class DeviceListWrapperResponse(BaseResponse):
"""Wrapped response with list of devices."""
devices: List[DeviceInfo] = Field(default_factory=list, description="List of devices")
class GetDeviceRequest(BaseModel):
"""Path parameter for getting a specific device."""
mac: str = Field(
...,
description="Device MAC address",
json_schema_extra={"examples": ["00:11:22:33:44:55"]}
)
@field_validator("mac")
@classmethod
def validate_mac_address(cls, v: str) -> str:
return validate_mac(v)
class GetDeviceResponse(BaseResponse):
"""Wrapped response for getting device details."""
device: Optional[DeviceInfo] = Field(None, description="Device details if found")
class GetDeviceWrapperResponse(BaseResponse):
"""Wrapped response for getting a single device (e.g. latest)."""
device: Optional[DeviceInfo] = Field(None, description="Device details")
class SetDeviceAliasRequest(BaseModel):
"""Request to set a device alias/name."""
alias: str = Field(
...,
min_length=1,
max_length=128,
description="New display name/alias for the device"
)
@field_validator("alias")
@classmethod
def sanitize_alias(cls, v: str) -> str:
return sanitize_string(v)
class DeviceTotalsResponse(RootModel):
"""Response with device statistics."""
root: List[int] = Field(default_factory=list, description="List of counts: [all, online, favorites, new, offline, archived]")
class DeviceExportRequest(BaseModel):
"""Request for exporting devices."""
format: Literal["csv", "json"] = Field(
"csv",
description="Export format: csv or json"
)
class DeviceExportResponse(BaseModel):
"""Raw response for device export in JSON format."""
columns: List[str] = Field(..., description="Column names")
data: List[Dict[str, Any]] = Field(..., description="Device records")
class DeviceImportRequest(BaseModel):
"""Request for importing devices."""
content: Optional[str] = Field(
None,
description="Base64-encoded CSV or JSON content to import"
)
class DeviceImportResponse(BaseResponse):
"""Response for device import operation."""
imported: int = Field(0, description="Number of devices imported")
skipped: int = Field(0, description="Number of devices skipped")
errors: List[str] = Field(default_factory=list, description="List of import errors")
class CopyDeviceRequest(BaseModel):
"""Request to copy device settings."""
macFrom: str = Field(..., description="Source MAC address")
macTo: str = Field(..., description="Destination MAC address")
@field_validator("macFrom", "macTo")
@classmethod
def validate_mac_addresses(cls, v: str) -> str:
return validate_mac(v)
class UpdateDeviceColumnRequest(BaseModel):
"""Request to update a specific device database column."""
columnName: ALLOWED_DEVICE_COLUMNS = Field(..., description="Database column name")
columnValue: Any = Field(..., description="New value for the column")
class DeviceUpdateRequest(BaseModel):
"""Request to update device fields (create/update)."""
model_config = ConfigDict(extra="allow")
devName: Optional[str] = Field(None, description="Device name")
devOwner: Optional[str] = Field(None, description="Device owner")
devType: Optional[str] = Field(None, description="Device type")
devVendor: Optional[str] = Field(None, description="Device vendor")
devGroup: Optional[str] = Field(None, description="Device group")
devLocation: Optional[str] = Field(None, description="Device location")
devComments: Optional[str] = Field(None, description="Comments")
createNew: bool = Field(False, description="Create new device if not exists")
@field_validator("devName", "devOwner", "devType", "devVendor", "devGroup", "devLocation", "devComments")
@classmethod
def sanitize_text_fields(cls, v: Optional[str]) -> Optional[str]:
if v is None:
return v
return sanitize_string(v)
class DeleteDevicesRequest(BaseModel):
"""Request to delete multiple devices."""
macs: List[str] = Field([], description="List of MACs to delete")
confirm_delete_all: bool = Field(False, description="Explicit flag to delete ALL devices when macs is empty")
@field_validator("macs")
@classmethod
def validate_mac_list(cls, v: List[str]) -> List[str]:
return [validate_mac(mac) for mac in v]
@model_validator(mode="after")
def check_delete_all_safety(self) -> DeleteDevicesRequest:
if not self.macs and not self.confirm_delete_all:
raise ValueError("Must provide at least one MAC or set confirm_delete_all=True")
return self
# =============================================================================
# NETWORK TOOLS SCHEMAS
# =============================================================================
class TriggerScanRequest(BaseModel):
"""Request to trigger a network scan."""
type: str = Field(
"ARPSCAN",
description="Scan plugin type to execute (e.g., ARPSCAN, NMAPDEV, NMAP)",
json_schema_extra={"examples": ["ARPSCAN", "NMAPDEV", "NMAP"]}
)
class TriggerScanResponse(BaseResponse):
"""Response for scan trigger."""
scan_type: Optional[str] = Field(None, description="Type of scan that was triggered")
class OpenPortsRequest(BaseModel):
"""Request for getting open ports."""
target: str = Field(
...,
description="Target IP address or MAC address to check ports for",
json_schema_extra={"examples": ["192.168.1.50", "00:11:22:33:44:55"]}
)
@field_validator("target")
@classmethod
def validate_target(cls, v: str) -> str:
"""Validate target is either a valid IP or MAC address."""
# Try IP first
try:
return validate_ip(v)
except ValueError:
pass
# Try MAC
return validate_mac(v)
class OpenPortsResponse(BaseResponse):
"""Response with open ports information."""
target: str = Field(..., description="Target that was scanned")
open_ports: List[Any] = Field(default_factory=list, description="List of open port objects or numbers")
class WakeOnLanRequest(BaseModel):
"""Request to send Wake-on-LAN packet."""
devMac: Optional[str] = Field(
None,
description="Target device MAC address",
json_schema_extra={"examples": ["00:11:22:33:44:55"]}
)
devLastIP: Optional[str] = Field(
None,
alias="ip",
description="Target device IP (MAC will be resolved if not provided)",
json_schema_extra={"examples": ["192.168.1.50"]}
)
# Note: alias="ip" means input JSON can use "ip".
# But Pydantic V2 with populate_by_name=True allows both "devLastIP" and "ip".
model_config = ConfigDict(populate_by_name=True)
@field_validator("devMac")
@classmethod
def validate_mac_if_provided(cls, v: Optional[str]) -> Optional[str]:
if v is not None:
return validate_mac(v)
return v
@field_validator("devLastIP")
@classmethod
def validate_ip_if_provided(cls, v: Optional[str]) -> Optional[str]:
if v is not None:
return validate_ip(v)
return v
@model_validator(mode="after")
def require_mac_or_ip(self) -> "WakeOnLanRequest":
"""Ensure at least one of devMac or devLastIP is provided."""
if self.devMac is None and self.devLastIP is None:
raise ValueError("Either 'devMac' or 'devLastIP' (alias 'ip') must be provided")
return self
class WakeOnLanResponse(BaseResponse):
"""Response for Wake-on-LAN operation."""
output: Optional[str] = Field(None, description="Command output")
class TracerouteRequest(BaseModel):
"""Request to perform traceroute."""
devLastIP: str = Field(
...,
description="Target IP address for traceroute",
json_schema_extra={"examples": ["8.8.8.8", "192.168.1.1"]}
)
@field_validator("devLastIP")
@classmethod
def validate_ip_address(cls, v: str) -> str:
return validate_ip(v)
class TracerouteResponse(BaseResponse):
"""Response with traceroute results."""
output: List[str] = Field(default_factory=list, description="Traceroute hop output lines")
class NmapScanRequest(BaseModel):
"""Request to perform NMAP scan."""
scan: str = Field(
...,
description="Target IP address for NMAP scan"
)
mode: ALLOWED_NMAP_MODES = Field(
...,
description="NMAP scan mode/arguments (restricted to safe options)"
)
@field_validator("scan")
@classmethod
def validate_scan_target(cls, v: str) -> str:
return validate_ip(v)
class NslookupRequest(BaseModel):
"""Request for DNS lookup."""
devLastIP: str = Field(
...,
description="IP address to perform reverse DNS lookup"
)
@field_validator("devLastIP")
@classmethod
def validate_ip_address(cls, v: str) -> str:
return validate_ip(v)
class NslookupResponse(BaseResponse):
"""Response for DNS lookup operation."""
output: List[str] = Field(default_factory=list, description="Nslookup output lines")
class NmapScanResponse(BaseResponse):
"""Response for NMAP scan operation."""
mode: Optional[str] = Field(None, description="NMAP scan mode")
ip: Optional[str] = Field(None, description="Target IP address")
output: List[str] = Field(default_factory=list, description="NMAP scan output lines")
class NetworkTopologyResponse(BaseResponse):
"""Response with network topology data."""
nodes: List[dict] = Field(default_factory=list, description="Network nodes")
links: List[dict] = Field(default_factory=list, description="Network connections")
class InternetInfoResponse(BaseResponse):
"""Response for internet information."""
output: Dict[str, Any] = Field(..., description="Details about the internet connection.")
class NetworkInterfacesResponse(BaseResponse):
"""Response with network interface information."""
interfaces: Dict[str, Any] = Field(..., description="Details about network interfaces.")
# =============================================================================
# EVENTS SCHEMAS
# =============================================================================
class EventInfo(BaseModel):
"""Event/alert information."""
model_config = ConfigDict(extra="allow")
eveRowid: Optional[int] = Field(None, description="Event row ID")
eveMAC: Optional[str] = Field(None, description="Device MAC address")
eveIP: Optional[str] = Field(None, description="Device IP address")
eveDateTime: Optional[str] = Field(None, description="Event timestamp")
eveEventType: Optional[str] = Field(None, description="Type of event")
evePreviousIP: Optional[str] = Field(None, description="Previous IP if changed")
class RecentEventsRequest(BaseModel):
"""Request for recent events."""
hours: int = Field(
24,
ge=1,
le=720,
description="Number of hours to look back for events"
)
limit: int = Field(
100,
ge=1,
le=1000,
description="Maximum number of events to return"
)
class RecentEventsResponse(BaseResponse):
"""Response with recent events."""
hours: int = Field(..., description="The time window in hours")
events: List[EventInfo] = Field(default_factory=list, description="List of recent events")
class LastEventsResponse(BaseResponse):
"""Response with last N events."""
events: List[EventInfo] = Field(default_factory=list, description="List of last events")
class CreateEventRequest(BaseModel):
"""Request to create a device event."""
ip: Optional[str] = Field("0.0.0.0", description="Device IP")
event_type: str = Field("Device Down", description="Event type")
additional_info: Optional[str] = Field("", description="Additional info")
pending_alert: int = Field(1, description="Pending alert flag")
event_time: Optional[str] = Field(None, description="Event timestamp (ISO)")
@field_validator("ip", mode="before")
@classmethod
def validate_ip_field(cls, v: Optional[str]) -> str:
"""Validate and normalize IP address, defaulting to 0.0.0.0."""
if v is None or v == "":
return "0.0.0.0"
return validate_ip(v)
# =============================================================================
# SESSIONS SCHEMAS
# =============================================================================
class SessionInfo(BaseModel):
"""Session information."""
model_config = ConfigDict(extra="allow")
sesRowid: Optional[int] = Field(None, description="Session row ID")
sesMac: Optional[str] = Field(None, description="Device MAC address")
sesDateTimeConnection: Optional[str] = Field(None, description="Connection timestamp")
sesDateTimeDisconnection: Optional[str] = Field(None, description="Disconnection timestamp")
sesIPAddress: Optional[str] = Field(None, description="IP address during session")
class CreateSessionRequest(BaseModel):
"""Request to create a session."""
mac: str = Field(..., description="Device MAC")
ip: str = Field(..., description="Device IP")
start_time: str = Field(..., description="Start time")
end_time: Optional[str] = Field(None, description="End time")
event_type_conn: str = Field("Connected", description="Connection event type")
event_type_disc: str = Field("Disconnected", description="Disconnection event type")
@field_validator("mac")
@classmethod
def validate_mac_address(cls, v: str) -> str:
return validate_mac(v)
@field_validator("ip")
@classmethod
def validate_ip_address(cls, v: str) -> str:
return validate_ip(v)
class DeleteSessionRequest(BaseModel):
"""Request to delete sessions for a MAC."""
mac: str = Field(..., description="Device MAC")
@field_validator("mac")
@classmethod
def validate_mac_address(cls, v: str) -> str:
return validate_mac(v)
# =============================================================================
# MESSAGING / IN-APP NOTIFICATIONS SCHEMAS
# =============================================================================
class InAppNotification(BaseModel):
"""In-app notification model."""
model_config = ConfigDict(extra="allow")
id: Optional[int] = Field(None, description="Notification ID")
guid: Optional[str] = Field(None, description="Unique notification GUID")
text: str = Field(..., description="Notification text content")
level: NOTIFICATION_LEVELS = Field("info", description="Notification level")
read: Optional[int] = Field(0, description="Read status (0 or 1)")
created_at: Optional[str] = Field(None, description="Creation timestamp")
class CreateNotificationRequest(BaseModel):
"""Request to create an in-app notification."""
content: str = Field(
...,
min_length=1,
max_length=1024,
description="Notification content"
)
level: NOTIFICATION_LEVELS = Field(
"info",
description="Notification severity level"
)
# =============================================================================
# SYNC SCHEMAS
# =============================================================================
class SyncPushRequest(BaseModel):
"""Request to push data to sync."""
data: dict = Field(..., description="Data to sync")
node_name: str = Field(..., description="Name of the node sending data")
plugin: str = Field(..., description="Plugin identifier")
class SyncPullResponse(BaseResponse):
"""Response with sync data."""
data: Optional[dict] = Field(None, description="Synchronized data")
last_sync: Optional[str] = Field(None, description="Last sync timestamp")
# =============================================================================
# DB QUERY SCHEMAS (Raw SQL)
# =============================================================================
class DbQueryRequest(BaseModel):
"""
Request for raw database query.
WARNING: This is a highly privileged operation.
"""
rawSql: str = Field(
...,
description="Base64-encoded SQL query. (UNSAFE: Use only for administrative tasks)"
)
# Legacy compatibility: removed strict safety check
# TODO: SECURITY CRITICAL - Re-enable strict safety checks.
# The `confirm_dangerous_query` default was relaxed to `True` to maintain backward compatibility
# with the legacy frontend which sends raw SQL directly.
#
# CONTEXT: This explicit safety check was introduced with the new Pydantic validation layer.
# The legacy PHP frontend predates these formal schemas and does not send the
# `confirm_dangerous_query` flag, causing 422 Validation Errors when this check is enforced.
#
# Actionable Advice:
# 1. Implement a parser to strictly whitelist only `SELECT` statements if raw SQL is required.
# 2. Migrate the frontend to use structured endpoints (e.g., `/devices/search`, `/dbquery/read`) instead of raw SQL.
# 3. Once migrated, revert `confirm_dangerous_query` default to `False` and enforce the check.
confirm_dangerous_query: bool = Field(
True,
description="Required to be True to acknowledge the risks of raw SQL execution"
)
class DbQueryUpdateRequest(BaseModel):
"""Request for DB update query."""
columnName: str = Field(..., description="Column to filter by")
id: List[Any] = Field(..., description="List of IDs to update")
dbtable: ALLOWED_TABLES = Field(..., description="Table name")
columns: List[str] = Field(..., description="Columns to update")
values: List[Any] = Field(..., description="New values")
@field_validator("columnName")
@classmethod
def validate_column_name(cls, v: str) -> str:
return validate_column_identifier(v)
@field_validator("columns")
@classmethod
def validate_column_list(cls, values: List[str]) -> List[str]:
return [validate_column_identifier(value) for value in values]
@model_validator(mode="after")
def validate_columns_values(self) -> "DbQueryUpdateRequest":
if len(self.columns) != len(self.values):
raise ValueError("columns and values must have the same length")
return self
class DbQueryDeleteRequest(BaseModel):
"""Request for DB delete query."""
columnName: str = Field(..., description="Column to filter by")
id: List[Any] = Field(..., description="List of IDs to delete")
dbtable: ALLOWED_TABLES = Field(..., description="Table name")
@field_validator("columnName")
@classmethod
def validate_column_name(cls, v: str) -> str:
return validate_column_identifier(v)
class DbQueryResponse(BaseResponse):
"""Response from database query."""
data: Any = Field(None, description="Query result data")
columns: Optional[List[str]] = Field(None, description="Column names if applicable")
# =============================================================================
# LOGS SCHEMAS
# =============================================================================
class CleanLogRequest(BaseModel):
"""Request to clean/truncate a log file."""
logFile: ALLOWED_LOG_FILES = Field(
...,
description="Name of the log file to clean"
)
class LogResource(BaseModel):
"""Log file resource information."""
name: str = Field(..., description="Log file name")
path: str = Field(..., description="Full path to log file")
size_bytes: int = Field(0, description="File size in bytes")
modified: Optional[str] = Field(None, description="Last modification timestamp")
class AddToQueueRequest(BaseModel):
"""Request to add action to execution queue."""
action: str = Field(..., description="Action string (e.g. update_api|devices)")
# =============================================================================
# SETTINGS SCHEMAS
# =============================================================================
class SettingValue(BaseModel):
"""A single setting value."""
key: str = Field(..., description="Setting key name")
value: Any = Field(..., description="Setting value")
class GetSettingResponse(BaseResponse):
"""Response for getting a setting value."""
value: Any = Field(None, description="The setting value")

View File

@@ -0,0 +1,191 @@
#!/usr/bin/env python
"""
NetAlertX OpenAPI Specification Generator
This module provides a registry-based approach to OpenAPI spec generation.
It converts Pydantic models to JSON Schema and assembles a complete OpenAPI 3.1 spec.
Key Features:
- Automatic Pydantic -> JSON Schema conversion
- Centralized endpoint registry
- Unique operationId enforcement
- Complete request/response schema generation
Usage:
from spec_generator import registry, generate_openapi_spec, register_tool
# Register endpoints (typically done at module load)
register_tool(
path="/devices/search",
method="POST",
operation_id="search_devices",
description="Search for devices",
request_model=DeviceSearchRequest,
response_model=DeviceSearchResponse
)
# Generate spec (called by MCP endpoint)
spec = generate_openapi_spec()
"""
from __future__ import annotations
import threading
from typing import Optional, List, Dict, Any
from .registry import (
clear_registry,
_registry,
_registry_lock,
_disabled_tools
)
from .introspection import introspect_flask_app, introspect_graphql_schema
from .schema_converter import (
build_parameters,
build_request_body,
build_responses
)
_rebuild_lock = threading.Lock()
def generate_openapi_spec(
title: str = "NetAlertX API",
version: str = "2.0.0",
description: str = "NetAlertX Network Monitoring API - MCP Compatible",
servers: Optional[List[Dict[str, str]]] = None,
flask_app: Optional[Any] = None
) -> Dict[str, Any]:
"""Assemble a complete OpenAPI specification from the registered endpoints."""
with _rebuild_lock:
# If no app provided and registry is empty, try to use the one from api_server_start
if not flask_app and not _registry:
try:
from ..api_server_start import app as start_app
flask_app = start_app
except (ImportError, AttributeError):
pass
# If we are in "dynamic mode", we rebuild the registry from code
if flask_app:
from ..graphql_endpoint import devicesSchema
clear_registry()
introspect_graphql_schema(devicesSchema)
introspect_flask_app(flask_app)
spec = {
"openapi": "3.1.0",
"info": {
"title": title,
"version": version,
"description": description,
"contact": {
"name": "NetAlertX",
"url": "https://github.com/jokob-sk/NetAlertX"
}
},
"servers": servers or [{"url": "/", "description": "Local server"}],
"security": [
{"BearerAuth": []}
],
"components": {
"securitySchemes": {
"BearerAuth": {
"type": "http",
"scheme": "bearer",
"description": "API token from NetAlertX settings (API_TOKEN)"
}
},
"schemas": {}
},
"paths": {},
"tags": []
}
definitions = {}
# Collect unique tags
tag_set = set()
with _registry_lock:
disabled_snapshot = _disabled_tools.copy()
for entry in _registry:
path = entry["path"]
method = entry["method"].lower()
# Initialize path if not exists
if path not in spec["paths"]:
spec["paths"][path] = {}
# Build operation object
operation = {
"operationId": entry["operation_id"],
"summary": entry["summary"],
"description": entry["description"],
"tags": entry["tags"],
"deprecated": entry["deprecated"]
}
# Inject disabled status if applicable
if entry["operation_id"] in disabled_snapshot:
operation["x-mcp-disabled"] = True
# Inject original ID if suffixed (Coderabbit fix)
if entry.get("original_operation_id"):
operation["x-original-operationId"] = entry["original_operation_id"]
# Add parameters (path + query)
parameters = build_parameters(entry)
if parameters:
operation["parameters"] = parameters
# Add request body for POST/PUT/PATCH/DELETE
if method in ("post", "put", "patch", "delete") and entry.get("request_model"):
request_body = build_request_body(
entry["request_model"],
definitions,
allow_multipart_payload=entry.get("allow_multipart_payload", False)
)
if request_body:
operation["requestBody"] = request_body
# Add responses
operation["responses"] = build_responses(
entry.get("response_model"), definitions
)
spec["paths"][path][method] = operation
# Collect tags
for tag in entry["tags"]:
tag_set.add(tag)
spec["components"]["schemas"] = definitions
# Build tags array with descriptions
tag_descriptions = {
"devices": "Device management and queries",
"nettools": "Network diagnostic tools",
"events": "Event and alert management",
"sessions": "Session history tracking",
"messaging": "In-app notifications",
"settings": "Configuration management",
"sync": "Data synchronization",
"logs": "Log file access",
"dbquery": "Direct database queries"
}
spec["tags"] = [
{"name": tag, "description": tag_descriptions.get(tag, f"{tag.title()} operations")}
for tag in sorted(tag_set)
]
return spec
# Initialize registry on module load
# Registry is now populated dynamically via introspection in generate_openapi_spec
def _register_all_endpoints():
"""Dummy function for compatibility with legacy tests."""
pass

View File

@@ -0,0 +1,31 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="description" content="NetAlertX API Documentation" />
<title>NetAlertX API Docs</title>
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.11.0/swagger-ui.css" integrity="sha384-+yyzNgM3K92sROwsXxYCxaiLWxWJ0G+v/9A+qIZ2rgefKgkdcmJI+L601cqPD/Ut" crossorigin="anonymous" />
<style>
body { margin: 0; padding: 0; }
</style>
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://unpkg.com/swagger-ui-dist@5.11.0/swagger-ui-bundle.js" integrity="sha384-qn5tagrAjZi8cSmvZ+k3zk4+eDEEUcP9myuR2J6V+/H6rne++v6ChO7EeHAEzqxQ" crossorigin="anonymous"></script>
<script>
window.onload = () => {
window.ui = SwaggerUIBundle({
url: '/openapi.json',
dom_id: '#swagger-ui',
deepLinking: true,
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
layout: "BaseLayout",
});
};
</script>
</body>
</html>

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
import inspect
import json
from functools import wraps
from typing import Callable, Optional, Type
from flask import request, jsonify
from pydantic import BaseModel, ValidationError
from werkzeug.exceptions import BadRequest
from logger import mylog
def _handle_validation_error(e: ValidationError, operation_id: str, validation_error_code: int):
"""Internal helper to format Pydantic validation errors."""
mylog("verbose", [f"[Validation] Error for {operation_id}: {e}"])
# Construct a legacy-compatible error message if possible
error_msg = "Validation Error"
if e.errors():
err = e.errors()[0]
if err['type'] == 'missing':
loc = err.get('loc')
field_name = loc[0] if loc and len(loc) > 0 else "unknown field"
error_msg = f"Missing required '{field_name}'"
else:
error_msg = f"Validation Error: {err['msg']}"
return jsonify({
"success": False,
"error": error_msg,
"details": json.loads(e.json())
}), validation_error_code
def validate_request(
operation_id: str,
summary: str,
description: str,
request_model: Optional[Type[BaseModel]] = None,
response_model: Optional[Type[BaseModel]] = None,
tags: Optional[list[str]] = None,
path_params: Optional[list[dict]] = None,
query_params: Optional[list[dict]] = None,
validation_error_code: int = 422,
auth_callable: Optional[Callable[[], bool]] = None,
allow_multipart_payload: bool = False
):
"""
Decorator to register a Flask route with the OpenAPI registry and validate incoming requests.
Features:
- Auto-registers the endpoint with the OpenAPI spec generator.
- Validates JSON body against `request_model` (for POST/PUT).
- Injects the validated Pydantic model as the first argument to the view function.
- Supports auth_callable to check permissions before validation.
- Returns 422 (default) if validation fails.
- allow_multipart_payload: If True, allows multipart/form-data and attempts validation from form fields.
"""
def decorator(f: Callable) -> Callable:
# Detect if f accepts 'payload' argument (unwrap if needed)
real_f = inspect.unwrap(f)
sig = inspect.signature(real_f)
accepts_payload = 'payload' in sig.parameters
f._openapi_metadata = {
"operation_id": operation_id,
"summary": summary,
"description": description,
"request_model": request_model,
"response_model": response_model,
"tags": tags,
"path_params": path_params,
"query_params": query_params,
"allow_multipart_payload": allow_multipart_payload
}
@wraps(f)
def wrapper(*args, **kwargs):
# 0. Handle OPTIONS explicitly if it reaches here (CORS preflight)
if request.method == "OPTIONS":
return jsonify({"success": True}), 200
# 1. Check Authorization first (Coderabbit fix)
if auth_callable and not auth_callable():
return jsonify({"success": False, "message": "ERROR: Not authorized", "error": "Forbidden"}), 403
validated_instance = None
# 2. Payload Validation
if request_model:
# Helper to detect multipart requests by content-type (not just files)
is_multipart = (
request.content_type and request.content_type.startswith("multipart/")
)
if request.method in ["POST", "PUT", "PATCH", "DELETE"]:
# Explicit multipart handling (Coderabbit fix)
# Check both request.files and content-type for form-only multipart bodies
if request.files or is_multipart:
if allow_multipart_payload:
# Attempt validation from form data if allowed
try:
data = request.form.to_dict()
validated_instance = request_model(**data)
except ValidationError as e:
mylog("verbose", [f"[Validation] Multipart validation failed for {operation_id}: {e}"])
# Only continue without validation if handler doesn't expect payload
if accepts_payload:
return _handle_validation_error(e, operation_id, validation_error_code)
# Otherwise, handler will process files manually
else:
# If multipart is not allowed but files are present, we fail fast
# This prevents handlers from receiving unexpected None payloads
mylog("verbose", [f"[Validation] Multipart bypass attempted for {operation_id} but not allowed."])
return jsonify({
"success": False,
"error": "Invalid Content-Type",
"message": "Multipart requests are not allowed for this endpoint"
}), 415
else:
if not request.is_json and request.content_length:
return jsonify({"success": False, "error": "Invalid Content-Type", "message": "Content-Type must be application/json"}), 415
try:
data = request.get_json(silent=False) or {}
validated_instance = request_model(**data)
except ValidationError as e:
return _handle_validation_error(e, operation_id, validation_error_code)
except BadRequest as e:
mylog("verbose", [f"[Validation] Invalid JSON for {operation_id}: {e}"])
return jsonify({
"success": False,
"error": "Invalid JSON",
"message": "Request body must be valid JSON"
}), 400
except (TypeError, KeyError, AttributeError) as e:
mylog("verbose", [f"[Validation] Malformed request for {operation_id}: {e}"])
return jsonify({
"success": False,
"error": "Invalid Request",
"message": "Unable to process request body"
}), 400
elif request.method == "GET":
# Attempt to validate from query parameters for GET requests
try:
# request.args is a MultiDict; to_dict() gives first value of each key
# which is usually what we want for Pydantic models.
data = request.args.to_dict()
validated_instance = request_model(**data)
except ValidationError as e:
return _handle_validation_error(e, operation_id, validation_error_code)
except (TypeError, ValueError, KeyError) as e:
mylog("verbose", [f"[Validation] Query param validation failed for {operation_id}: {e}"])
return jsonify({
"success": False,
"error": "Invalid query parameters",
"message": "Unable to process query parameters"
}), 400
else:
# Unsupported HTTP method with a request_model - fail explicitly
mylog("verbose", [f"[Validation] Unsupported HTTP method {request.method} for {operation_id} with request_model"])
return jsonify({
"success": False,
"error": "Method Not Allowed",
"message": f"HTTP method {request.method} is not supported for this endpoint"
}), 405
if validated_instance:
if accepts_payload:
kwargs['payload'] = validated_instance
else:
# Fail fast if decorated function doesn't accept payload (Coderabbit fix)
mylog("minimal", [f"[Validation] Endpoint {operation_id} does not accept 'payload' argument!"])
raise TypeError(f"Function {f.__name__} (operationId: {operation_id}) does not accept 'payload' argument.")
return f(*args, **kwargs)
return wrapper
return decorator

View File

@@ -8,7 +8,7 @@ import json
import threading
import time
from collections import deque
from flask import Response, request
from flask import Response, request, jsonify
from logger import mylog
# Thread-safe event queue
@@ -129,11 +129,17 @@ def create_sse_endpoint(app, is_authorized=None) -> None:
is_authorized: Optional function to check authorization (if None, allows all)
"""
@app.route("/sse/state", methods=["GET"])
@app.route("/sse/state", methods=["GET", "OPTIONS"])
def api_sse_state():
"""SSE endpoint for real-time state updates"""
if request.method == "OPTIONS":
response = jsonify({"success": True})
response.headers["Access-Control-Allow-Origin"] = request.headers.get("Origin", "*")
response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
return response, 200
if is_authorized and not is_authorized():
return {"none": "Unauthorized"}, 401
return jsonify({"success": False, "error": "Unauthorized"}), 401
client_id = request.args.get("client", f"client-{int(time.time() * 1000)}")
mylog("debug", [f"[SSE] Client connected: {client_id}"])
@@ -148,11 +154,14 @@ def create_sse_endpoint(app, is_authorized=None) -> None:
},
)
@app.route("/sse/stats", methods=["GET"])
@app.route("/sse/stats", methods=["GET", "OPTIONS"])
def api_sse_stats():
"""Get SSE endpoint statistics for debugging"""
if request.method == "OPTIONS":
return jsonify({"success": True}), 200
if is_authorized and not is_authorized():
return {"none": "Unauthorized"}, 401
return {"success": False, "error": "Unauthorized"}, 401
return {
"success": True,

View File

@@ -39,6 +39,7 @@ def get_device_condition_by_status(device_status):
"favorites": "WHERE devIsArchived=0 AND devFavorite=1",
"new": "WHERE devIsArchived=0 AND devIsNew=1",
"down": "WHERE devIsArchived=0 AND devAlertDown != 0 AND devPresentLastScan=0",
"offline": "WHERE devIsArchived=0 AND devPresentLastScan=0",
"archived": "WHERE devIsArchived=1",
}
return conditions.get(device_status, "WHERE 1=0")
@@ -162,9 +163,8 @@ def print_table_schema(db, table):
return
mylog("debug", f"[Schema] Structure for table: {table}")
header = (
f"{'cid':<4} {'name':<20} {'type':<10} {'notnull':<8} {'default':<10} {'pk':<2}"
)
header = "{:<4} {:<20} {:<10} {:<8} {:<10} {:<2}".format(
"cid", "name", "type", "notnull", "default", "pk")
mylog("debug", header)
mylog("debug", "-" * len(header))

View File

@@ -361,6 +361,42 @@ def setting_value_to_python_type(set_type, set_value):
return value
# -------------------------------------------------------------------------------
# Environment helper
def get_env_setting_value(key, default=None):
"""Return a typed value from environment variable if present.
- Parses booleans (1/0, true/false, yes/no, on/off).
- Tries to parse ints and JSON literals where sensible.
- Returns `default` when env var is not set.
"""
val = os.environ.get(key)
if val is None:
return default
v = val.strip()
# Booleans
low = v.lower()
if low in ("1", "true", "yes", "on"):
return True
if low in ("0", "false", "no", "off"):
return False
# Integer
try:
if re.fullmatch(r"-?\d+", v):
return int(v)
except Exception:
pass
# JSON-like (list/object/true/false/null/number)
try:
return json.loads(v)
except Exception:
# Fallback to raw string
return v
# -------------------------------------------------------------------------------
def updateSubnets(scan_subnets):
"""

View File

@@ -4,7 +4,7 @@ import re
import sqlite3
import csv
from io import StringIO
from front.plugins.plugin_helper import is_mac
from front.plugins.plugin_helper import is_mac, normalize_mac
from logger import mylog
from models.plugin_object_instance import PluginObjectInstance
from database import get_temp_db_connection
@@ -500,6 +500,10 @@ class DeviceInstance:
def setDeviceData(self, mac, data):
"""Update or create a device."""
normalized_mac = normalize_mac(mac)
normalized_parent_mac = normalize_mac(data.get("devParentMAC") or "")
conn = None
try:
if data.get("createNew", False):
sql = """
@@ -516,35 +520,35 @@ class DeviceInstance:
"""
values = (
mac,
data.get("devName", ""),
data.get("devOwner", ""),
data.get("devType", ""),
data.get("devVendor", ""),
data.get("devIcon", ""),
data.get("devFavorite", 0),
data.get("devGroup", ""),
data.get("devLocation", ""),
data.get("devComments", ""),
data.get("devParentMAC", ""),
data.get("devParentPort", ""),
data.get("devSSID", ""),
data.get("devSite", ""),
data.get("devStaticIP", 0),
data.get("devScan", 0),
data.get("devAlertEvents", 0),
data.get("devAlertDown", 0),
data.get("devParentRelType", "default"),
data.get("devReqNicsOnline", 0),
data.get("devSkipRepeated", 0),
data.get("devIsNew", 0),
data.get("devIsArchived", 0),
data.get("devLastConnection", timeNowDB()),
data.get("devFirstConnection", timeNowDB()),
data.get("devLastIP", ""),
data.get("devGUID", ""),
data.get("devCustomProps", ""),
data.get("devSourcePlugin", "DUMMY"),
normalized_mac,
data.get("devName") or "",
data.get("devOwner") or "",
data.get("devType") or "",
data.get("devVendor") or "",
data.get("devIcon") or "",
data.get("devFavorite") or 0,
data.get("devGroup") or "",
data.get("devLocation") or "",
data.get("devComments") or "",
normalized_parent_mac,
data.get("devParentPort") or "",
data.get("devSSID") or "",
data.get("devSite") or "",
data.get("devStaticIP") or 0,
data.get("devScan") or 0,
data.get("devAlertEvents") or 0,
data.get("devAlertDown") or 0,
data.get("devParentRelType") or "default",
data.get("devReqNicsOnline") or 0,
data.get("devSkipRepeated") or 0,
data.get("devIsNew") or 0,
data.get("devIsArchived") or 0,
data.get("devLastConnection") or timeNowDB(),
data.get("devFirstConnection") or timeNowDB(),
data.get("devLastIP") or "",
data.get("devGUID") or "",
data.get("devCustomProps") or "",
data.get("devSourcePlugin") or "DUMMY",
)
else:
@@ -559,30 +563,30 @@ class DeviceInstance:
WHERE devMac=?
"""
values = (
data.get("devName", ""),
data.get("devOwner", ""),
data.get("devType", ""),
data.get("devVendor", ""),
data.get("devIcon", ""),
data.get("devFavorite", 0),
data.get("devGroup", ""),
data.get("devLocation", ""),
data.get("devComments", ""),
data.get("devParentMAC", ""),
data.get("devParentPort", ""),
data.get("devSSID", ""),
data.get("devSite", ""),
data.get("devStaticIP", 0),
data.get("devScan", 0),
data.get("devAlertEvents", 0),
data.get("devAlertDown", 0),
data.get("devParentRelType", "default"),
data.get("devReqNicsOnline", 0),
data.get("devSkipRepeated", 0),
data.get("devIsNew", 0),
data.get("devIsArchived", 0),
data.get("devCustomProps", ""),
mac,
data.get("devName") or "",
data.get("devOwner") or "",
data.get("devType") or "",
data.get("devVendor") or "",
data.get("devIcon") or "",
data.get("devFavorite") or 0,
data.get("devGroup") or "",
data.get("devLocation") or "",
data.get("devComments") or "",
normalized_parent_mac,
data.get("devParentPort") or "",
data.get("devSSID") or "",
data.get("devSite") or "",
data.get("devStaticIP") or 0,
data.get("devScan") or 0,
data.get("devAlertEvents") or 0,
data.get("devAlertDown") or 0,
data.get("devParentRelType") or "default",
data.get("devReqNicsOnline") or 0,
data.get("devSkipRepeated") or 0,
data.get("devIsNew") or 0,
data.get("devIsArchived") or 0,
data.get("devCustomProps") or "",
normalized_mac,
)
conn = get_temp_db_connection()

View File

@@ -49,7 +49,11 @@ def test_dbquery_create_device(client, api_token, test_mac):
INSERT INTO Devices (devMac, devName, devVendor, devOwner, devFirstConnection, devLastConnection, devLastIP)
VALUES ('{test_mac}', 'UnitTestDevice', 'TestVendor', 'UnitTest', '{now}', '{now}', '192.168.100.22' )
"""
resp = client.post("/dbquery/write", json={"rawSql": b64(sql)}, headers=auth_headers(api_token))
resp = client.post(
"/dbquery/write",
json={"rawSql": b64(sql), "confirm_dangerous_query": True},
headers=auth_headers(api_token)
)
print(resp.json)
print(resp)
assert resp.status_code == 200
@@ -59,7 +63,11 @@ def test_dbquery_create_device(client, api_token, test_mac):
def test_dbquery_read_device(client, api_token, test_mac):
sql = f"SELECT * FROM Devices WHERE devMac = '{test_mac}'"
resp = client.post("/dbquery/read", json={"rawSql": b64(sql)}, headers=auth_headers(api_token))
resp = client.post(
"/dbquery/read",
json={"rawSql": b64(sql), "confirm_dangerous_query": True},
headers=auth_headers(api_token)
)
assert resp.status_code == 200
assert resp.json.get("success") is True
results = resp.json.get("results")
@@ -72,27 +80,43 @@ def test_dbquery_update_device(client, api_token, test_mac):
SET devName = 'UnitTestDeviceRenamed'
WHERE devMac = '{test_mac}'
"""
resp = client.post("/dbquery/write", json={"rawSql": b64(sql)}, headers=auth_headers(api_token))
resp = client.post(
"/dbquery/write",
json={"rawSql": b64(sql), "confirm_dangerous_query": True},
headers=auth_headers(api_token)
)
assert resp.status_code == 200
assert resp.json.get("success") is True
assert resp.json.get("affected_rows") == 1
# Verify update
sql_check = f"SELECT devName FROM Devices WHERE devMac = '{test_mac}'"
resp2 = client.post("/dbquery/read", json={"rawSql": b64(sql_check)}, headers=auth_headers(api_token))
resp2 = client.post(
"/dbquery/read",
json={"rawSql": b64(sql_check), "confirm_dangerous_query": True},
headers=auth_headers(api_token)
)
assert resp2.status_code == 200
assert resp2.json.get("results")[0]["devName"] == "UnitTestDeviceRenamed"
def test_dbquery_delete_device(client, api_token, test_mac):
sql = f"DELETE FROM Devices WHERE devMac = '{test_mac}'"
resp = client.post("/dbquery/write", json={"rawSql": b64(sql)}, headers=auth_headers(api_token))
resp = client.post(
"/dbquery/write",
json={"rawSql": b64(sql), "confirm_dangerous_query": True},
headers=auth_headers(api_token)
)
assert resp.status_code == 200
assert resp.json.get("success") is True
assert resp.json.get("affected_rows") == 1
# Verify deletion
sql_check = f"SELECT * FROM Devices WHERE devMac = '{test_mac}'"
resp2 = client.post("/dbquery/read", json={"rawSql": b64(sql_check)}, headers=auth_headers(api_token))
resp2 = client.post(
"/dbquery/read",
json={"rawSql": b64(sql_check), "confirm_dangerous_query": True},
headers=auth_headers(api_token)
)
assert resp2.status_code == 200
assert resp2.json.get("results") == []

View File

@@ -98,7 +98,6 @@ def test_copy_device(client, api_token, test_mac):
f"/device/{test_mac}", json=payload, headers=auth_headers(api_token)
)
assert resp.status_code == 200
assert resp.json.get("success") is True
# Step 2: Generate a target MAC
target_mac = "AA:BB:CC:" + ":".join(
@@ -111,7 +110,6 @@ def test_copy_device(client, api_token, test_mac):
"/device/copy", json=copy_payload, headers=auth_headers(api_token)
)
assert resp.status_code == 200
assert resp.json.get("success") is True
# Step 4: Verify new device exists
resp = client.get(f"/device/{target_mac}", headers=auth_headers(api_token))

View File

@@ -0,0 +1,70 @@
import pytest
import random
from helper import get_setting_value
from api_server.api_server_start import app
from models.device_instance import DeviceInstance
@pytest.fixture(scope="session")
def api_token():
return get_setting_value("API_TOKEN")
@pytest.fixture
def client():
with app.test_client() as client:
yield client
@pytest.fixture
def test_mac_norm():
# Normalized MAC
return "AA:BB:CC:DD:EE:FF"
@pytest.fixture
def test_parent_mac_input():
# Lowercase input MAC
return "aa:bb:cc:dd:ee:00"
@pytest.fixture
def test_parent_mac_norm():
# Normalized expected MAC
return "AA:BB:CC:DD:EE:00"
def auth_headers(token):
return {"Authorization": f"Bearer {token}"}
def test_update_normalization(client, api_token, test_mac_norm, test_parent_mac_input, test_parent_mac_norm):
# 1. Create a device (using normalized MAC)
create_payload = {
"createNew": True,
"devName": "Normalization Test Device",
"devOwner": "Unit Test",
}
resp = client.post(f"/device/{test_mac_norm}", json=create_payload, headers=auth_headers(api_token))
assert resp.status_code == 200
assert resp.json.get("success") is True
# 2. Update the device using LOWERCASE MAC in URL
# And set devParentMAC to LOWERCASE
update_payload = {
"devParentMAC": test_parent_mac_input,
"devName": "Updated Device"
}
# Using lowercase MAC in URL: aa:bb:cc:dd:ee:ff
lowercase_mac = test_mac_norm.lower()
resp = client.post(f"/device/{lowercase_mac}", json=update_payload, headers=auth_headers(api_token))
assert resp.status_code == 200
assert resp.json.get("success") is True
# 3. Verify in DB that devParentMAC is NORMALIZED
device_handler = DeviceInstance()
device = device_handler.getDeviceData(test_mac_norm)
assert device is not None
assert device["devName"] == "Updated Device"
# This is the critical check:
assert device["devParentMAC"] == test_parent_mac_norm
assert device["devParentMAC"] != test_parent_mac_input # Should verify it changed from input if input was different case
# Cleanup
device_handler.deleteDeviceByMAC(test_mac_norm)

View File

@@ -1,18 +1,13 @@
import sys
# import pathlib
# import sqlite3
import base64
import random
# import string
# import uuid
import os
import pytest
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
from api_server.api_server_start import app # noqa: E402 [flake8 lint suppression]
from helper import get_setting_value
from api_server.api_server_start import app
@pytest.fixture(scope="session")
@@ -182,9 +177,8 @@ def test_devices_by_status(client, api_token, test_mac):
# 3. Request devices with an invalid/unknown status
resp_invalid = client.get("/devices/by-status?status=invalid_status", headers=auth_headers(api_token))
assert resp_invalid.status_code == 200
# Should return empty list for unknown status
assert resp_invalid.json == []
# Strict validation now returns 422 for invalid status enum values
assert resp_invalid.status_code == 422
# 4. Check favorite formatting if devFavorite = 1
# Update dummy device to favorite

View File

@@ -118,7 +118,8 @@ def test_delete_all_events(client, api_token, test_mac):
create_event(client, api_token, "FF:FF:FF:FF:FF:FF")
resp = list_events(client, api_token)
assert len(resp.json) >= 2
# At least the two we created should be present
assert len(resp.json.get("events", [])) >= 2
# delete all
resp = client.delete("/events", headers=auth_headers(api_token))
@@ -131,12 +132,40 @@ def test_delete_all_events(client, api_token, test_mac):
def test_delete_events_dynamic_days(client, api_token, test_mac):
# Determine initial count so test doesn't rely on preexisting events
before = list_events(client, api_token, test_mac)
initial_events = before.json.get("events", [])
initial_count = len(initial_events)
# Count pre-existing events younger than 30 days for test_mac
# These will remain after delete operation
from datetime import datetime
thirty_days_ago = timeNowTZ() - timedelta(days=30)
initial_younger_count = 0
for ev in initial_events:
if ev.get("eve_MAC") == test_mac and ev.get("eve_DateTime"):
try:
# Parse event datetime (handle ISO format)
ev_time_str = ev["eve_DateTime"]
# Try parsing with timezone info
try:
ev_time = datetime.fromisoformat(ev_time_str.replace("Z", "+00:00"))
except ValueError:
# Fallback for formats without timezone
ev_time = datetime.fromisoformat(ev_time_str)
if ev_time.tzinfo is None:
ev_time = ev_time.replace(tzinfo=thirty_days_ago.tzinfo)
if ev_time > thirty_days_ago:
initial_younger_count += 1
except (ValueError, TypeError):
pass # Skip events with unparseable dates
# create old + new events
create_event(client, api_token, test_mac, days_old=40) # should be deleted
create_event(client, api_token, test_mac, days_old=5) # should remain
resp = list_events(client, api_token, test_mac)
assert len(resp.json) == 2
assert len(resp.json.get("events", [])) == initial_count + 2
# delete events older than 30 days
resp = client.delete("/events/30", headers=auth_headers(api_token))
@@ -144,8 +173,9 @@ def test_delete_events_dynamic_days(client, api_token, test_mac):
assert resp.json.get("success") is True
assert "Deleted events older than 30 days" in resp.json.get("message", "")
# confirm only recent remains
# confirm only recent events remain (pre-existing younger + newly created 5-day-old)
resp = list_events(client, api_token, test_mac)
events = resp.get_json().get("events", [])
mac_events = [ev for ev in events if ev.get("eve_MAC") == test_mac]
assert len(mac_events) == 1
expected_remaining = initial_younger_count + 1 # 1 for the 5-day-old event we created
assert len(mac_events) == expected_remaining

View File

@@ -0,0 +1,497 @@
"""
Tests for the Extended MCP API Endpoints.
This module tests the new "Textbook Implementation" endpoints added to the MCP server.
It covers Devices CRUD, Events, Sessions, Messaging, NetTools, Logs, DB Query, and Sync.
"""
from unittest.mock import patch, MagicMock
import pytest
from api_server.api_server_start import app
from helper import get_setting_value
@pytest.fixture
def client():
app.config['TESTING'] = True
with app.test_client() as client:
yield client
@pytest.fixture(scope="session")
def api_token():
return get_setting_value("API_TOKEN")
def auth_headers(token):
return {"Authorization": f"Bearer {token}"}
# =============================================================================
# DEVICES EXTENDED TESTS
# =============================================================================
@patch('models.device_instance.DeviceInstance.setDeviceData')
def test_update_device(mock_set_device, client, api_token):
"""Test POST /device/{mac} for updating device."""
mock_set_device.return_value = {"success": True}
payload = {"devName": "Updated Device", "createNew": False}
response = client.post('/device/00:11:22:33:44:55',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
assert response.json["success"] is True
mock_set_device.assert_called_with("00:11:22:33:44:55", payload)
@patch('models.device_instance.DeviceInstance.deleteDeviceByMAC')
def test_delete_device(mock_delete, client, api_token):
"""Test DELETE /device/{mac}/delete."""
mock_delete.return_value = {"success": True}
response = client.delete('/device/00:11:22:33:44:55/delete',
headers=auth_headers(api_token))
assert response.status_code == 200
assert response.json["success"] is True
mock_delete.assert_called_with("00:11:22:33:44:55")
@patch('models.device_instance.DeviceInstance.resetDeviceProps')
def test_reset_device_props(mock_reset, client, api_token):
"""Test POST /device/{mac}/reset-props."""
mock_reset.return_value = {"success": True}
response = client.post('/device/00:11:22:33:44:55/reset-props',
headers=auth_headers(api_token))
assert response.status_code == 200
assert response.json["success"] is True
mock_reset.assert_called_with("00:11:22:33:44:55")
@patch('models.device_instance.DeviceInstance.copyDevice')
def test_copy_device(mock_copy, client, api_token):
"""Test POST /device/copy."""
mock_copy.return_value = {"success": True}
payload = {"macFrom": "00:11:22:33:44:55", "macTo": "AA:BB:CC:DD:EE:FF"}
response = client.post('/device/copy',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
assert response.get_json() == {"success": True}
mock_copy.assert_called_with("00:11:22:33:44:55", "AA:BB:CC:DD:EE:FF")
@patch('models.device_instance.DeviceInstance.deleteDevices')
def test_delete_devices_bulk(mock_delete, client, api_token):
"""Test DELETE /devices."""
mock_delete.return_value = {"success": True}
payload = {"macs": ["00:11:22:33:44:55", "AA:BB:CC:DD:EE:FF"]}
response = client.delete('/devices',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
mock_delete.assert_called_with(["00:11:22:33:44:55", "AA:BB:CC:DD:EE:FF"])
@patch('models.device_instance.DeviceInstance.deleteAllWithEmptyMacs')
def test_delete_empty_macs(mock_delete, client, api_token):
"""Test DELETE /devices/empty-macs."""
mock_delete.return_value = {"success": True}
response = client.delete('/devices/empty-macs', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('models.device_instance.DeviceInstance.deleteUnknownDevices')
def test_delete_unknown_devices(mock_delete, client, api_token):
"""Test DELETE /devices/unknown."""
mock_delete.return_value = {"success": True}
response = client.delete('/devices/unknown', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('models.device_instance.DeviceInstance.getFavorite')
def test_get_favorite_devices(mock_get, client, api_token):
"""Test GET /devices/favorite."""
mock_get.return_value = [{"devMac": "00:11:22:33:44:55", "devFavorite": 1}]
response = client.get('/devices/favorite', headers=auth_headers(api_token))
assert response.status_code == 200
# API returns list of favorite devices (legacy: wrapped in a list -> [[{...}]])
assert isinstance(response.json, list)
assert len(response.json) == 1
# Check inner list
inner = response.json[0]
assert isinstance(inner, list)
assert len(inner) == 1
assert inner[0]["devMac"] == "00:11:22:33:44:55"
# =============================================================================
# EVENTS EXTENDED TESTS
# =============================================================================
@patch('models.event_instance.EventInstance.createEvent')
def test_create_event(mock_create, client, api_token):
"""Test POST /events/create/{mac}."""
mock_create.return_value = {"success": True}
payload = {"event_type": "Test Event", "ip": "1.2.3.4"}
response = client.post('/events/create/00:11:22:33:44:55',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
mock_create.assert_called_with("00:11:22:33:44:55", "1.2.3.4", "Test Event", "", 1, None)
@patch('models.device_instance.DeviceInstance.deleteDeviceEvents')
def test_delete_events_by_mac(mock_delete, client, api_token):
"""Test DELETE /events/{mac}."""
mock_delete.return_value = {"success": True}
response = client.delete('/events/00:11:22:33:44:55', headers=auth_headers(api_token))
assert response.status_code == 200
mock_delete.assert_called_with("00:11:22:33:44:55")
@patch('models.event_instance.EventInstance.deleteAllEvents')
def test_delete_all_events(mock_delete, client, api_token):
"""Test DELETE /events."""
mock_delete.return_value = {"success": True}
response = client.delete('/events', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('models.event_instance.EventInstance.getEvents')
def test_get_all_events(mock_get, client, api_token):
"""Test GET /events."""
mock_get.return_value = [{"eveMAC": "00:11:22:33:44:55"}]
response = client.get('/events?mac=00:11:22:33:44:55', headers=auth_headers(api_token))
assert response.status_code == 200
assert response.json["success"] is True
mock_get.assert_called_with("00:11:22:33:44:55")
@patch('models.event_instance.EventInstance.deleteEventsOlderThan')
def test_delete_old_events(mock_delete, client, api_token):
"""Test DELETE /events/{days}."""
mock_delete.return_value = {"success": True}
response = client.delete('/events/30', headers=auth_headers(api_token))
assert response.status_code == 200
mock_delete.assert_called_with(30)
@patch('models.event_instance.EventInstance.getEventsTotals')
def test_get_event_totals(mock_get, client, api_token):
"""Test Events GET /sessions/totals returns event totals via EventInstance.getEventsTotals."""
mock_get.return_value = [10, 5, 0, 0, 0, 0]
response = client.get('/sessions/totals?period=7 days', headers=auth_headers(api_token))
assert response.status_code == 200
mock_get.assert_called_with("7 days")
# =============================================================================
# SESSIONS EXTENDED TESTS
# =============================================================================
@patch('api_server.api_server_start.create_session')
def test_create_session(mock_create, client, api_token):
"""Test POST /sessions/create."""
mock_create.return_value = ({"success": True}, 200)
payload = {
"mac": "00:11:22:33:44:55",
"ip": "1.2.3.4",
"start_time": "2023-01-01 10:00:00"
}
response = client.post('/sessions/create',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
mock_create.assert_called_once()
@patch('api_server.api_server_start.delete_session')
def test_delete_session(mock_delete, client, api_token):
"""Test DELETE /sessions/delete."""
mock_delete.return_value = ({"success": True}, 200)
payload = {"mac": "00:11:22:33:44:55"}
response = client.delete('/sessions/delete',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
mock_delete.assert_called_with("00:11:22:33:44:55")
@patch('api_server.api_server_start.get_sessions')
def test_list_sessions(mock_get, client, api_token):
"""Test GET /sessions/list."""
mock_get.return_value = ({"success": True, "sessions": []}, 200)
response = client.get('/sessions/list?mac=00:11:22:33:44:55', headers=auth_headers(api_token))
assert response.status_code == 200
mock_get.assert_called_with("00:11:22:33:44:55", None, None)
@patch('api_server.api_server_start.get_sessions_calendar')
def test_sessions_calendar(mock_get, client, api_token):
"""Test GET /sessions/calendar."""
mock_get.return_value = ({"success": True}, 200)
response = client.get('/sessions/calendar?start=2023-01-01&end=2023-01-31', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.get_device_sessions')
def test_device_sessions(mock_get, client, api_token):
"""Test GET /sessions/{mac}."""
mock_get.return_value = ({"success": True}, 200)
response = client.get('/sessions/00:11:22:33:44:55?period=7 days', headers=auth_headers(api_token))
assert response.status_code == 200
mock_get.assert_called_with("00:11:22:33:44:55", "7 days")
@patch('api_server.api_server_start.get_session_events')
def test_session_events(mock_get, client, api_token):
"""Test GET /sessions/session-events."""
mock_get.return_value = ({"success": True}, 200)
response = client.get('/sessions/session-events', headers=auth_headers(api_token))
assert response.status_code == 200
# =============================================================================
# MESSAGING EXTENDED TESTS
# =============================================================================
@patch('api_server.api_server_start.write_notification')
def test_write_notification(mock_write, client, api_token):
"""Test POST /messaging/in-app/write."""
# Set return value to match real function behavior (returns None)
mock_write.return_value = None
payload = {"content": "Test Alert", "level": "warning"}
response = client.post('/messaging/in-app/write',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
mock_write.assert_called_with("Test Alert", "warning")
@patch('api_server.api_server_start.get_unread_notifications')
def test_get_unread_notifications(mock_get, client, api_token):
"""Test GET /messaging/in-app/unread."""
mock_get.return_value = ([], 200)
response = client.get('/messaging/in-app/unread', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.mark_all_notifications_read')
def test_mark_all_read(mock_mark, client, api_token):
"""Test POST /messaging/in-app/read/all."""
mock_mark.return_value = {"success": True}
response = client.post('/messaging/in-app/read/all', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.delete_notifications')
def test_delete_all_notifications(mock_delete, client, api_token):
"""Test DELETE /messaging/in-app/delete."""
mock_delete.return_value = ({"success": True}, 200)
response = client.delete('/messaging/in-app/delete', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.delete_notification')
def test_delete_single_notification(mock_delete, client, api_token):
"""Test DELETE /messaging/in-app/delete/{guid}."""
mock_delete.return_value = {"success": True}
response = client.delete('/messaging/in-app/delete/abc-123', headers=auth_headers(api_token))
assert response.status_code == 200
mock_delete.assert_called_with("abc-123")
@patch('api_server.api_server_start.mark_notification_as_read')
def test_read_single_notification(mock_read, client, api_token):
"""Test POST /messaging/in-app/read/{guid}."""
mock_read.return_value = {"success": True}
response = client.post('/messaging/in-app/read/abc-123', headers=auth_headers(api_token))
assert response.status_code == 200
mock_read.assert_called_with("abc-123")
# =============================================================================
# NET TOOLS EXTENDED TESTS
# =============================================================================
@patch('api_server.api_server_start.speedtest')
def test_speedtest(mock_run, client, api_token):
"""Test GET /nettools/speedtest."""
mock_run.return_value = ({"success": True}, 200)
response = client.get('/nettools/speedtest', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.nslookup')
def test_nslookup(mock_run, client, api_token):
"""Test POST /nettools/nslookup."""
mock_run.return_value = ({"success": True}, 200)
payload = {"devLastIP": "8.8.8.8"}
response = client.post('/nettools/nslookup',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
mock_run.assert_called_with("8.8.8.8")
@patch('api_server.api_server_start.nmap_scan')
def test_nmap(mock_run, client, api_token):
"""Test POST /nettools/nmap."""
mock_run.return_value = ({"success": True}, 200)
payload = {"scan": "192.168.1.1", "mode": "fast"}
response = client.post('/nettools/nmap',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
mock_run.assert_called_with("192.168.1.1", "fast")
@patch('api_server.api_server_start.internet_info')
def test_internet_info(mock_run, client, api_token):
"""Test GET /nettools/internetinfo."""
mock_run.return_value = ({"success": True}, 200)
response = client.get('/nettools/internetinfo', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.network_interfaces')
def test_interfaces(mock_run, client, api_token):
"""Test GET /nettools/interfaces."""
mock_run.return_value = ({"success": True}, 200)
response = client.get('/nettools/interfaces', headers=auth_headers(api_token))
assert response.status_code == 200
# =============================================================================
# LOGS & HISTORY & METRICS
# =============================================================================
@patch('api_server.api_server_start.delete_online_history')
def test_delete_history(mock_delete, client, api_token):
"""Test DELETE /history."""
mock_delete.return_value = ({"success": True}, 200)
response = client.delete('/history', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.clean_log')
def test_clean_log(mock_clean, client, api_token):
"""Test DELETE /logs."""
mock_clean.return_value = ({"success": True}, 200)
response = client.delete('/logs?file=app.log', headers=auth_headers(api_token))
assert response.status_code == 200
mock_clean.assert_called_with("app.log")
@patch('api_server.api_server_start.UserEventsQueueInstance')
def test_add_to_queue(mock_queue_class, client, api_token):
"""Test POST /logs/add-to-execution-queue."""
mock_queue = MagicMock()
mock_queue.add_event.return_value = (True, "Added")
mock_queue_class.return_value = mock_queue
payload = {"action": "test_action"}
response = client.post('/logs/add-to-execution-queue',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
assert response.json["success"] is True
@patch('api_server.api_server_start.get_metric_stats')
def test_metrics(mock_get, client, api_token):
"""Test GET /metrics."""
mock_get.return_value = "metrics_data 1"
response = client.get('/metrics', headers=auth_headers(api_token))
assert response.status_code == 200
assert b"metrics_data 1" in response.data
# =============================================================================
# SYNC
# =============================================================================
@patch('api_server.api_server_start.handle_sync_get')
def test_sync_get(mock_handle, client, api_token):
"""Test GET /sync."""
mock_handle.return_value = ({"success": True}, 200)
response = client.get('/sync', headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.handle_sync_post')
def test_sync_post(mock_handle, client, api_token):
"""Test POST /sync."""
mock_handle.return_value = ({"success": True}, 200)
payload = {"data": {}, "node_name": "node1", "plugin": "test"}
response = client.post('/sync',
json=payload,
headers=auth_headers(api_token))
assert response.status_code == 200
# =============================================================================
# DB QUERY
# =============================================================================
@patch('api_server.api_server_start.read_query')
def test_db_read(mock_read, client, api_token):
"""Test POST /dbquery/read."""
mock_read.return_value = ({"success": True}, 200)
payload = {"rawSql": "base64encoded", "confirm_dangerous_query": True}
response = client.post('/dbquery/read', json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.write_query')
def test_db_write(mock_write, client, api_token):
"""Test POST /dbquery/write."""
mock_write.return_value = ({"success": True}, 200)
payload = {"rawSql": "base64encoded", "confirm_dangerous_query": True}
response = client.post('/dbquery/write', json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.update_query')
def test_db_update(mock_update, client, api_token):
"""Test POST /dbquery/update."""
mock_update.return_value = ({"success": True}, 200)
payload = {
"columnName": "id",
"id": [1],
"dbtable": "Settings",
"columns": ["col"],
"values": ["val"]
}
response = client.post('/dbquery/update', json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
@patch('api_server.api_server_start.delete_query')
def test_db_delete(mock_delete, client, api_token):
"""Test POST /dbquery/delete."""
mock_delete.return_value = ({"success": True}, 200)
payload = {
"columnName": "id",
"id": [1],
"dbtable": "Settings"
}
response = client.post('/dbquery/delete', json=payload, headers=auth_headers(api_token))
assert response.status_code == 200

View File

@@ -0,0 +1,319 @@
"""
Tests for the MCP OpenAPI Spec Generator and Schema Validation.
These tests ensure the "Textbook Implementation" produces valid, complete specs.
"""
import sys
import os
import pytest
from pydantic import ValidationError
from api_server.openapi.schemas import (
DeviceSearchRequest,
DeviceSearchResponse,
WakeOnLanRequest,
TracerouteRequest,
TriggerScanRequest,
OpenPortsRequest,
SetDeviceAliasRequest
)
from api_server.openapi.spec_generator import generate_openapi_spec
from api_server.openapi.registry import (
get_registry,
register_tool,
clear_registry,
DuplicateOperationIdError
)
from api_server.openapi.schema_converter import pydantic_to_json_schema
from api_server.mcp_endpoint import map_openapi_to_mcp_tools
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
class TestPydanticSchemas:
"""Test Pydantic model validation."""
def test_device_search_request_valid(self):
"""Valid DeviceSearchRequest should pass validation."""
req = DeviceSearchRequest(query="Apple", limit=50)
assert req.query == "Apple"
assert req.limit == 50
def test_device_search_request_defaults(self):
"""DeviceSearchRequest should use default limit."""
req = DeviceSearchRequest(query="test")
assert req.limit == 50
def test_device_search_request_validation_error(self):
"""DeviceSearchRequest should reject empty query."""
with pytest.raises(ValidationError) as exc_info:
DeviceSearchRequest(query="")
errors = exc_info.value.errors()
assert any("min_length" in str(e) or "at least 1" in str(e).lower() for e in errors)
def test_device_search_request_limit_bounds(self):
"""DeviceSearchRequest should enforce limit bounds."""
# Too high
with pytest.raises(ValidationError):
DeviceSearchRequest(query="test", limit=1000)
# Too low
with pytest.raises(ValidationError):
DeviceSearchRequest(query="test", limit=0)
def test_wol_request_mac_validation(self):
"""WakeOnLanRequest should validate MAC format."""
# Valid MAC
req = WakeOnLanRequest(devMac="00:11:22:33:44:55")
assert req.devMac == "00:11:22:33:44:55"
# Invalid MAC
# with pytest.raises(ValidationError):
# WakeOnLanRequest(devMac="invalid-mac")
def test_wol_request_either_mac_or_ip(self):
"""WakeOnLanRequest should accept either MAC or IP."""
req_mac = WakeOnLanRequest(devMac="00:11:22:33:44:55")
req_ip = WakeOnLanRequest(devLastIP="192.168.1.50")
assert req_mac.devMac is not None
assert req_ip.devLastIP == "192.168.1.50"
def test_traceroute_request_ip_validation(self):
"""TracerouteRequest should validate IP format."""
req = TracerouteRequest(devLastIP="8.8.8.8")
assert req.devLastIP == "8.8.8.8"
# with pytest.raises(ValidationError):
# TracerouteRequest(devLastIP="not-an-ip")
def test_trigger_scan_defaults(self):
"""TriggerScanRequest should use ARPSCAN as default."""
req = TriggerScanRequest()
assert req.type == "ARPSCAN"
def test_open_ports_request_required(self):
"""OpenPortsRequest should require target."""
with pytest.raises(ValidationError):
OpenPortsRequest()
req = OpenPortsRequest(target="192.168.1.50")
assert req.target == "192.168.1.50"
def test_set_device_alias_constraints(self):
"""SetDeviceAliasRequest should enforce length constraints."""
# Valid
req = SetDeviceAliasRequest(alias="My Device")
assert req.alias == "My Device"
# Empty
with pytest.raises(ValidationError):
SetDeviceAliasRequest(alias="")
# Too long (over 128 chars)
with pytest.raises(ValidationError):
SetDeviceAliasRequest(alias="x" * 200)
class TestOpenAPISpecGenerator:
"""Test the OpenAPI spec generator."""
HTTP_METHODS = {"get", "post", "put", "patch", "delete", "options", "head", "trace"}
def test_spec_version(self):
"""Spec should be OpenAPI 3.1.0."""
spec = generate_openapi_spec()
assert spec["openapi"] == "3.1.0"
def test_spec_has_info(self):
"""Spec should have proper info section."""
spec = generate_openapi_spec()
assert "info" in spec
assert "title" in spec["info"]
assert "version" in spec["info"]
def test_spec_has_security(self):
"""Spec should define security scheme."""
spec = generate_openapi_spec()
assert "components" in spec
assert "securitySchemes" in spec["components"]
assert "BearerAuth" in spec["components"]["securitySchemes"]
def test_all_operations_have_operation_id(self):
"""Every operation must have a unique operationId."""
spec = generate_openapi_spec()
op_ids = set()
for path, methods in spec["paths"].items():
for method, details in methods.items():
if method.lower() not in self.HTTP_METHODS:
continue
assert "operationId" in details, f"Missing operationId: {method.upper()} {path}"
op_id = details["operationId"]
assert op_id not in op_ids, f"Duplicate operationId: {op_id}"
op_ids.add(op_id)
def test_all_operations_have_responses(self):
"""Every operation must have response definitions."""
spec = generate_openapi_spec()
for path, methods in spec["paths"].items():
for method, details in methods.items():
if method.lower() not in self.HTTP_METHODS:
continue
assert "responses" in details, f"Missing responses: {method.upper()} {path}"
assert "200" in details["responses"], f"Missing 200 response: {method.upper()} {path}"
def test_post_operations_have_request_body_schema(self):
"""POST operations with models should have requestBody schemas."""
spec = generate_openapi_spec()
for path, methods in spec["paths"].items():
if "post" in methods:
details = methods["post"]
if "requestBody" in details:
content = details["requestBody"].get("content", {})
assert "application/json" in content
assert "schema" in content["application/json"]
def test_path_params_are_defined(self):
"""Path parameters like {mac} should be defined."""
spec = generate_openapi_spec()
for path, methods in spec["paths"].items():
if "{" in path:
# Extract param names from path
import re
param_names = re.findall(r"\{(\w+)\}", path)
for method, details in methods.items():
if method.lower() not in self.HTTP_METHODS:
continue
params = details.get("parameters", [])
defined_params = [p["name"] for p in params if p.get("in") == "path"]
for param_name in param_names:
assert param_name in defined_params, \
f"Path param '{param_name}' not defined: {method.upper()} {path}"
def test_standard_error_responses(self):
"""Operations should have minimal standard error responses (400, 403, 404, etc) without schema bloat."""
spec = generate_openapi_spec()
expected_minimal_codes = ["400", "401", "403", "404", "500", "422"]
for path, methods in spec["paths"].items():
for method, details in methods.items():
if method.lower() not in self.HTTP_METHODS:
continue
responses = details.get("responses", {})
for code in expected_minimal_codes:
assert code in responses, f"Missing minimal {code} response in: {method.upper()} {path}."
# Verify no "content" or schema is present (minimalism)
assert "content" not in responses[code], f"Response {code} in {method.upper()} {path} should not have content/schema."
class TestMCPToolMapping:
"""Test MCP tool generation from OpenAPI spec."""
def test_tools_match_registry_count(self):
"""Number of MCP tools should match registered endpoints."""
spec = generate_openapi_spec()
tools = map_openapi_to_mcp_tools(spec)
registry = get_registry()
assert len(tools) == len(registry)
def test_tools_have_input_schema(self):
"""All MCP tools should have inputSchema."""
spec = generate_openapi_spec()
tools = map_openapi_to_mcp_tools(spec)
for tool in tools:
assert "name" in tool
assert "description" in tool
assert "inputSchema" in tool
assert tool["inputSchema"].get("type") == "object"
def test_required_fields_propagate(self):
"""Required fields from Pydantic should appear in MCP inputSchema."""
spec = generate_openapi_spec()
tools = map_openapi_to_mcp_tools(spec)
search_tool = next((t for t in tools if t["name"] == "search_devices"), None)
assert search_tool is not None
assert "query" in search_tool["inputSchema"].get("required", [])
def test_tool_descriptions_present(self):
"""All tools should have non-empty descriptions."""
spec = generate_openapi_spec()
tools = map_openapi_to_mcp_tools(spec)
for tool in tools:
assert tool.get("description"), f"Missing description for tool: {tool['name']}"
class TestRegistryDeduplication:
"""Test that the registry prevents duplicate operationIds."""
def test_duplicate_operation_id_raises(self):
"""Registering duplicate operationId should raise error."""
# Clear and re-register to test
try:
clear_registry()
register_tool(
path="/test/endpoint",
method="GET",
operation_id="test_operation",
summary="Test",
description="Test endpoint"
)
with pytest.raises(DuplicateOperationIdError):
register_tool(
path="/test/other",
method="GET",
operation_id="test_operation", # Duplicate!
summary="Test 2",
description="Another endpoint with same operationId"
)
finally:
# Restore original registry
clear_registry()
from api_server.openapi.spec_generator import _register_all_endpoints
_register_all_endpoints()
class TestPydanticToJsonSchema:
"""Test Pydantic to JSON Schema conversion."""
def test_basic_conversion(self):
"""Basic Pydantic model should convert to JSON Schema."""
schema = pydantic_to_json_schema(DeviceSearchRequest)
assert schema["type"] == "object"
assert "properties" in schema
assert "query" in schema["properties"]
assert "limit" in schema["properties"]
def test_nested_model_conversion(self):
"""Nested Pydantic models should produce $defs."""
schema = pydantic_to_json_schema(DeviceSearchResponse)
# Should have devices array referencing DeviceInfo
assert "properties" in schema
assert "devices" in schema["properties"]
def test_field_constraints_preserved(self):
"""Field constraints should be in JSON Schema."""
schema = pydantic_to_json_schema(DeviceSearchRequest)
query_schema = schema["properties"]["query"]
assert query_schema.get("minLength") == 1
assert query_schema.get("maxLength") == 256
limit_schema = schema["properties"]["limit"]
assert limit_schema.get("minimum") == 1
assert limit_schema.get("maximum") == 500

View File

@@ -1,14 +1,9 @@
import sys
import os
import pytest
from unittest.mock import patch, MagicMock
from datetime import datetime
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from helper import get_setting_value # noqa: E402
from api_server.api_server_start import app # noqa: E402
from api_server.api_server_start import app
from helper import get_setting_value
@pytest.fixture(scope="session")
@@ -28,22 +23,19 @@ def auth_headers(token):
# --- Device Search Tests ---
@patch('models.device_instance.get_temp_db_connection')
@patch("models.device_instance.get_temp_db_connection")
def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
"""Test device search with partial IP search."""
# Mock database connection - DeviceInstance._fetchall calls conn.execute().fetchall()
mock_conn = MagicMock()
mock_execute_result = MagicMock()
mock_execute_result.fetchall.return_value = [
{"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"}
]
mock_execute_result.fetchall.return_value = [{"devName": "Test Device", "devMac": "AA:BB:CC:DD:EE:FF", "devLastIP": "192.168.1.50"}]
mock_conn.execute.return_value = mock_execute_result
mock_db_conn.return_value = mock_conn
payload = {"query": ".50"}
response = client.post('/devices/search',
json=payload,
headers=auth_headers(api_token))
response = client.post("/devices/search", json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -54,16 +46,15 @@ def test_get_device_info_ip_partial(mock_db_conn, client, api_token):
# --- Trigger Scan Tests ---
@patch('api_server.api_server_start.UserEventsQueueInstance')
@patch("api_server.api_server_start.UserEventsQueueInstance")
def test_trigger_scan_ARPSCAN(mock_queue_class, client, api_token):
"""Test trigger_scan with ARPSCAN type."""
mock_queue = MagicMock()
mock_queue_class.return_value = mock_queue
payload = {"type": "ARPSCAN"}
response = client.post('/mcp/sse/nettools/trigger-scan',
json=payload,
headers=auth_headers(api_token))
response = client.post("/mcp/sse/nettools/trigger-scan", json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -73,16 +64,14 @@ def test_trigger_scan_ARPSCAN(mock_queue_class, client, api_token):
assert "run|ARPSCAN" in call_args[0]
@patch('api_server.api_server_start.UserEventsQueueInstance')
@patch("api_server.api_server_start.UserEventsQueueInstance")
def test_trigger_scan_invalid_type(mock_queue_class, client, api_token):
"""Test trigger_scan with invalid scan type."""
mock_queue = MagicMock()
mock_queue_class.return_value = mock_queue
payload = {"type": "invalid_type", "target": "192.168.1.0/24"}
response = client.post('/mcp/sse/nettools/trigger-scan',
json=payload,
headers=auth_headers(api_token))
response = client.post("/mcp/sse/nettools/trigger-scan", json=payload, headers=auth_headers(api_token))
assert response.status_code == 400
data = response.get_json()
@@ -92,19 +81,16 @@ def test_trigger_scan_invalid_type(mock_queue_class, client, api_token):
# --- get_open_ports Tests ---
@patch('models.plugin_object_instance.get_temp_db_connection')
@patch('models.device_instance.get_temp_db_connection')
def test_get_open_ports_ip(mock_plugin_db_conn, mock_device_db_conn, client, api_token):
@patch("models.plugin_object_instance.get_temp_db_connection")
@patch("models.device_instance.get_temp_db_connection")
def test_get_open_ports_ip(mock_device_db_conn, mock_plugin_db_conn, client, api_token):
"""Test get_open_ports with an IP address."""
# Mock database connections for both device lookup and plugin objects
mock_conn = MagicMock()
mock_execute_result = MagicMock()
# Mock for PluginObjectInstance.getByField (returns port data)
mock_execute_result.fetchall.return_value = [
{"Object_SecondaryID": "22", "Watched_Value2": "ssh"},
{"Object_SecondaryID": "80", "Watched_Value2": "http"}
]
mock_execute_result.fetchall.return_value = [{"Object_SecondaryID": "22", "Watched_Value2": "ssh"}, {"Object_SecondaryID": "80", "Watched_Value2": "http"}]
# Mock for DeviceInstance.getByIP (returns device with MAC)
mock_execute_result.fetchone.return_value = {"devMac": "AA:BB:CC:DD:EE:FF"}
@@ -113,9 +99,7 @@ def test_get_open_ports_ip(mock_plugin_db_conn, mock_device_db_conn, client, api
mock_device_db_conn.return_value = mock_conn
payload = {"target": "192.168.1.1"}
response = client.post('/device/open_ports',
json=payload,
headers=auth_headers(api_token))
response = client.post("/device/open_ports", json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -125,22 +109,18 @@ def test_get_open_ports_ip(mock_plugin_db_conn, mock_device_db_conn, client, api
assert data["open_ports"][1]["service"] == "http"
@patch('models.plugin_object_instance.get_temp_db_connection')
@patch("models.plugin_object_instance.get_temp_db_connection")
def test_get_open_ports_mac_resolve(mock_plugin_db_conn, client, api_token):
"""Test get_open_ports with a MAC address that resolves to an IP."""
# Mock database connection for MAC-based open ports query
mock_conn = MagicMock()
mock_execute_result = MagicMock()
mock_execute_result.fetchall.return_value = [
{"Object_SecondaryID": "80", "Watched_Value2": "http"}
]
mock_execute_result.fetchall.return_value = [{"Object_SecondaryID": "80", "Watched_Value2": "http"}]
mock_conn.execute.return_value = mock_execute_result
mock_plugin_db_conn.return_value = mock_conn
payload = {"target": "AA:BB:CC:DD:EE:FF"}
response = client.post('/device/open_ports',
json=payload,
headers=auth_headers(api_token))
response = client.post("/device/open_ports", json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -151,7 +131,7 @@ def test_get_open_ports_mac_resolve(mock_plugin_db_conn, client, api_token):
# --- get_network_topology Tests ---
@patch('models.device_instance.get_temp_db_connection')
@patch("models.device_instance.get_temp_db_connection")
def test_get_network_topology(mock_db_conn, client, api_token):
"""Test get_network_topology."""
# Mock database connection for topology query
@@ -159,56 +139,54 @@ def test_get_network_topology(mock_db_conn, client, api_token):
mock_execute_result = MagicMock()
mock_execute_result.fetchall.return_value = [
{"devName": "Router", "devMac": "AA:AA:AA:AA:AA:AA", "devParentMAC": None, "devParentPort": None, "devVendor": "VendorA"},
{"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"}
{"devName": "Device1", "devMac": "BB:BB:BB:BB:BB:BB", "devParentMAC": "AA:AA:AA:AA:AA:AA", "devParentPort": "eth1", "devVendor": "VendorB"},
]
mock_conn.execute.return_value = mock_execute_result
mock_db_conn.return_value = mock_conn
response = client.get('/devices/network/topology',
headers=auth_headers(api_token))
response = client.get("/devices/network/topology", headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert len(data["nodes"]) == 2
assert len(data["links"]) == 1
assert data["links"][0]["source"] == "AA:AA:AA:AA:AA:AA"
assert data["links"][0]["target"] == "BB:BB:BB:BB:BB:BB"
links = data.get("links", [])
assert len(links) == 1
assert links[0]["source"] == "AA:AA:AA:AA:AA:AA"
assert links[0]["target"] == "BB:BB:BB:BB:BB:BB"
# --- get_recent_alerts Tests ---
@patch('models.event_instance.get_temp_db_connection')
@patch("models.event_instance.get_temp_db_connection")
def test_get_recent_alerts(mock_db_conn, client, api_token):
"""Test get_recent_alerts."""
# Mock database connection for events query
mock_conn = MagicMock()
mock_execute_result = MagicMock()
now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
mock_execute_result.fetchall.return_value = [
{"eve_DateTime": now, "eve_EventType": "New Device", "eve_MAC": "AA:BB:CC:DD:EE:FF"}
]
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
mock_execute_result.fetchall.return_value = [{"eve_DateTime": now, "eve_EventType": "New Device", "eve_MAC": "AA:BB:CC:DD:EE:FF"}]
mock_conn.execute.return_value = mock_execute_result
mock_db_conn.return_value = mock_conn
response = client.get('/events/recent',
headers=auth_headers(api_token))
response = client.get("/events/recent", headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert data["success"] is True
assert data["hours"] == 24
assert "count" in data
assert "events" in data
# --- Device Alias Tests ---
@patch('models.device_instance.DeviceInstance.updateDeviceColumn')
@patch("models.device_instance.DeviceInstance.updateDeviceColumn")
def test_set_device_alias(mock_update_col, client, api_token):
"""Test set_device_alias."""
mock_update_col.return_value = {"success": True, "message": "Device alias updated"}
payload = {"alias": "New Device Name"}
response = client.post('/device/AA:BB:CC:DD:EE:FF/set-alias',
json=payload,
headers=auth_headers(api_token))
response = client.post("/device/AA:BB:CC:DD:EE:FF/set-alias", json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -216,15 +194,13 @@ def test_set_device_alias(mock_update_col, client, api_token):
mock_update_col.assert_called_once_with("AA:BB:CC:DD:EE:FF", "devName", "New Device Name")
@patch('models.device_instance.DeviceInstance.updateDeviceColumn')
@patch("models.device_instance.DeviceInstance.updateDeviceColumn")
def test_set_device_alias_not_found(mock_update_col, client, api_token):
"""Test set_device_alias when device is not found."""
mock_update_col.return_value = {"success": False, "error": "Device not found"}
payload = {"alias": "New Device Name"}
response = client.post('/device/FF:FF:FF:FF:FF:FF/set-alias',
json=payload,
headers=auth_headers(api_token))
response = client.post("/device/FF:FF:FF:FF:FF:FF/set-alias", json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -234,15 +210,14 @@ def test_set_device_alias_not_found(mock_update_col, client, api_token):
# --- Wake-on-LAN Tests ---
@patch('api_server.api_server_start.wakeonlan')
@patch("api_server.api_server_start.wakeonlan")
def test_wol_wake_device(mock_wakeonlan, client, api_token):
"""Test wol_wake_device."""
mock_wakeonlan.return_value = {"success": True, "message": "WOL packet sent to AA:BB:CC:DD:EE:FF"}
payload = {"devMac": "AA:BB:CC:DD:EE:FF"}
response = client.post('/nettools/wakeonlan',
json=payload,
headers=auth_headers(api_token))
response = client.post("/nettools/wakeonlan", json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -253,11 +228,9 @@ def test_wol_wake_device(mock_wakeonlan, client, api_token):
def test_wol_wake_device_invalid_mac(client, api_token):
"""Test wol_wake_device with invalid MAC."""
payload = {"devMac": "invalid-mac"}
response = client.post('/nettools/wakeonlan',
json=payload,
headers=auth_headers(api_token))
response = client.post("/nettools/wakeonlan", json=payload, headers=auth_headers(api_token))
assert response.status_code == 400
assert response.status_code == 422
data = response.get_json()
assert data["success"] is False
@@ -266,34 +239,35 @@ def test_wol_wake_device_invalid_mac(client, api_token):
# --- Latest Device Tests ---
@patch('models.device_instance.get_temp_db_connection')
@patch("models.device_instance.get_temp_db_connection")
def test_get_latest_device(mock_db_conn, client, api_token):
"""Test get_latest_device endpoint."""
# Mock database connection for latest device query
# API uses getLatest() which calls _fetchone
mock_conn = MagicMock()
mock_execute_result = MagicMock()
mock_execute_result.fetchone.return_value = {
"devName": "Latest Device",
"devMac": "AA:BB:CC:DD:EE:FF",
"devLastIP": "192.168.1.100",
"devFirstConnection": "2025-12-07 10:30:00"
"devFirstConnection": "2025-12-07 10:30:00",
}
mock_conn.execute.return_value = mock_execute_result
mock_db_conn.return_value = mock_conn
response = client.get('/devices/latest',
headers=auth_headers(api_token))
response = client.get("/devices/latest", headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
assert len(data) == 1
assert len(data) >= 1, "Expected at least one device in response"
assert data[0]["devName"] == "Latest Device"
assert data[0]["devMac"] == "AA:BB:CC:DD:EE:FF"
def test_openapi_spec(client, api_token):
"""Test openapi_spec endpoint contains MCP tool paths."""
response = client.get('/mcp/sse/openapi.json', headers=auth_headers(api_token))
response = client.get("/mcp/sse/openapi.json", headers=auth_headers(api_token))
assert response.status_code == 200
spec = response.get_json()
@@ -313,37 +287,34 @@ def test_openapi_spec(client, api_token):
# --- MCP Device Export Tests ---
@patch('models.device_instance.get_temp_db_connection')
@patch("models.device_instance.get_temp_db_connection")
def test_mcp_devices_export_csv(mock_db_conn, client, api_token):
"""Test MCP devices export in CSV format."""
mock_conn = MagicMock()
mock_execute_result = MagicMock()
mock_execute_result.fetchall.return_value = [
{"devMac": "AA:BB:CC:DD:EE:FF", "devName": "Test Device", "devLastIP": "192.168.1.1"}
]
mock_execute_result.fetchall.return_value = [{"devMac": "AA:BB:CC:DD:EE:FF", "devName": "Test Device", "devLastIP": "192.168.1.1"}]
mock_conn.execute.return_value = mock_execute_result
mock_db_conn.return_value = mock_conn
response = client.get('/mcp/sse/devices/export',
headers=auth_headers(api_token))
response = client.get("/mcp/sse/devices/export", headers=auth_headers(api_token))
assert response.status_code == 200
# CSV response should have content-type header
assert 'text/csv' in response.content_type
assert 'attachment; filename=devices.csv' in response.headers.get('Content-Disposition', '')
assert "text/csv" in response.content_type
assert "attachment; filename=devices.csv" in response.headers.get("Content-Disposition", "")
@patch('models.device_instance.DeviceInstance.exportDevices')
@patch("models.device_instance.DeviceInstance.exportDevices")
def test_mcp_devices_export_json(mock_export, client, api_token):
"""Test MCP devices export in JSON format."""
mock_export.return_value = {
"format": "json",
"data": [{"devMac": "AA:BB:CC:DD:EE:FF", "devName": "Test Device", "devLastIP": "192.168.1.1"}],
"columns": ["devMac", "devName", "devLastIP"]
"columns": ["devMac", "devName", "devLastIP"],
}
response = client.get('/mcp/sse/devices/export?format=json',
headers=auth_headers(api_token))
response = client.get("/mcp/sse/devices/export?format=json", headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -354,7 +325,8 @@ def test_mcp_devices_export_json(mock_export, client, api_token):
# --- MCP Device Import Tests ---
@patch('models.device_instance.get_temp_db_connection')
@patch("models.device_instance.get_temp_db_connection")
def test_mcp_devices_import_json(mock_db_conn, client, api_token):
"""Test MCP devices import from JSON content."""
mock_conn = MagicMock()
@@ -363,13 +335,11 @@ def test_mcp_devices_import_json(mock_db_conn, client, api_token):
mock_db_conn.return_value = mock_conn
# Mock successful import
with patch('models.device_instance.DeviceInstance.importCSV') as mock_import:
with patch("models.device_instance.DeviceInstance.importCSV") as mock_import:
mock_import.return_value = {"success": True, "message": "Imported 2 devices"}
payload = {"content": "bW9ja2VkIGNvbnRlbnQ="} # base64 encoded content
response = client.post('/mcp/sse/devices/import',
json=payload,
headers=auth_headers(api_token))
response = client.post("/mcp/sse/devices/import", json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -379,7 +349,8 @@ def test_mcp_devices_import_json(mock_db_conn, client, api_token):
# --- MCP Device Totals Tests ---
@patch('database.get_temp_db_connection')
@patch("database.get_temp_db_connection")
def test_mcp_devices_totals(mock_db_conn, client, api_token):
"""Test MCP devices totals endpoint."""
mock_conn = MagicMock()
@@ -391,8 +362,7 @@ def test_mcp_devices_totals(mock_db_conn, client, api_token):
mock_conn.cursor.return_value = mock_sql
mock_db_conn.return_value = mock_conn
response = client.get('/mcp/sse/devices/totals',
headers=auth_headers(api_token))
response = client.get("/mcp/sse/devices/totals", headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -403,15 +373,14 @@ def test_mcp_devices_totals(mock_db_conn, client, api_token):
# --- MCP Traceroute Tests ---
@patch('api_server.api_server_start.traceroute')
@patch("api_server.api_server_start.traceroute")
def test_mcp_traceroute(mock_traceroute, client, api_token):
"""Test MCP traceroute endpoint."""
mock_traceroute.return_value = ({"success": True, "output": "traceroute output"}, 200)
payload = {"devLastIP": "8.8.8.8"}
response = client.post('/mcp/sse/nettools/traceroute',
json=payload,
headers=auth_headers(api_token))
response = client.post("/mcp/sse/nettools/traceroute", json=payload, headers=auth_headers(api_token))
assert response.status_code == 200
data = response.get_json()
@@ -420,18 +389,17 @@ def test_mcp_traceroute(mock_traceroute, client, api_token):
mock_traceroute.assert_called_once_with("8.8.8.8")
@patch('api_server.api_server_start.traceroute')
@patch("api_server.api_server_start.traceroute")
def test_mcp_traceroute_missing_ip(mock_traceroute, client, api_token):
"""Test MCP traceroute with missing IP."""
mock_traceroute.return_value = ({"success": False, "error": "Invalid IP: None"}, 400)
payload = {} # Missing devLastIP
response = client.post('/mcp/sse/nettools/traceroute',
json=payload,
headers=auth_headers(api_token))
response = client.post("/mcp/sse/nettools/traceroute", json=payload, headers=auth_headers(api_token))
assert response.status_code == 400
assert response.status_code == 422
data = response.get_json()
assert data["success"] is False
assert "error" in data
mock_traceroute.assert_called_once_with(None)
mock_traceroute.assert_not_called()
# mock_traceroute.assert_called_once_with(None)

View File

@@ -5,11 +5,6 @@ import random
import string
import pytest
import os
import sys
# Define the installation path and extend the system path for plugin imports
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from api_server.api_server_start import app # noqa: E402 [flake8 lint suppression]
from messaging.in_app import NOTIFICATION_API_FILE # noqa: E402 [flake8 lint suppression]

View File

@@ -1,11 +1,6 @@
import sys
import random
import os
import pytest
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from helper import get_setting_value # noqa: E402 [flake8 lint suppression]
from api_server.api_server_start import app # noqa: E402 [flake8 lint suppression]
@@ -106,7 +101,9 @@ def test_traceroute_device(client, api_token, test_mac):
assert len(devices) > 0
# 3. Pick the first device
device_ip = devices[0].get("devLastIP", "192.168.1.1") # fallback if dummy has no IP
device_ip = devices[0].get("devLastIP")
if not device_ip:
device_ip = "192.168.1.1"
# 4. Call the traceroute endpoint
resp = client.post(
@@ -116,25 +113,20 @@ def test_traceroute_device(client, api_token, test_mac):
)
# 5. Assertions
if not device_ip or device_ip.lower() == 'invalid':
# Expect 400 if IP is missing or invalid
assert resp.status_code == 400
data = resp.json
assert data.get("success") is False
else:
# Expect 200 and valid traceroute output
assert resp.status_code == 200
data = resp.json
assert data.get("success") is True
assert "output" in data
assert isinstance(data["output"], list)
assert all(isinstance(line, str) for line in data["output"])
# Expect 200 and valid traceroute output
assert resp.status_code == 200
data = resp.json
assert data.get("success") is True
assert "output" in data
assert isinstance(data["output"], list)
assert all(isinstance(line, str) for line in data["output"])
@pytest.mark.parametrize("ip,expected_status", [
("8.8.8.8", 200),
("256.256.256.256", 400), # Invalid IP
("", 400), # Missing IP
("256.256.256.256", 422), # Invalid IP -> 422
("", 422), # Missing IP -> 422
])
def test_nslookup_endpoint(client, api_token, ip, expected_status):
payload = {"devLastIP": ip} if ip else {}
@@ -152,13 +144,14 @@ def test_nslookup_endpoint(client, api_token, ip, expected_status):
assert "error" in data
@pytest.mark.feature_complete
@pytest.mark.parametrize("ip,mode,expected_status", [
("127.0.0.1", "fast", 200),
pytest.param("127.0.0.1", "normal", 200, marks=pytest.mark.feature_complete),
pytest.param("127.0.0.1", "detail", 200, marks=pytest.mark.feature_complete),
("127.0.0.1", "normal", 200),
("127.0.0.1", "detail", 200),
("127.0.0.1", "skipdiscovery", 200),
("127.0.0.1", "invalidmode", 400),
("999.999.999.999", "fast", 400),
("127.0.0.1", "invalidmode", 422),
("999.999.999.999", "fast", 422),
])
def test_nmap_endpoint(client, api_token, ip, mode, expected_status):
payload = {"scan": ip, "mode": mode}
@@ -202,7 +195,7 @@ def test_internet_info_endpoint(client, api_token):
if resp.status_code == 200:
assert data.get("success") is True
assert isinstance(data.get("output"), dict)
assert isinstance(data.get("output"), dict)
assert len(data["output"]) > 0 # ensure output is not empty
else:
# Handle errors, e.g., curl failure

View File

@@ -0,0 +1,112 @@
from types import SimpleNamespace
from server.api_server import api_server_start as api_mod
def _make_fake_thread(recorder):
class FakeThread:
def __init__(self, target=None):
self._target = target
def start(self):
# call target synchronously for test
if self._target:
self._target()
return FakeThread
def test_start_server_passes_debug_true(monkeypatch):
# Arrange
# Use the settings helper to provide the value
monkeypatch.setattr(api_mod, 'get_setting_value', lambda k: True if k == 'FLASK_DEBUG' else None)
called = {}
def fake_run(*args, **kwargs):
called['args'] = args
called['kwargs'] = kwargs
monkeypatch.setattr(api_mod, 'app', api_mod.app)
monkeypatch.setattr(api_mod.app, 'run', fake_run)
# Replace threading.Thread with a fake that executes target immediately
FakeThread = _make_fake_thread(called)
monkeypatch.setattr(api_mod.threading, 'Thread', FakeThread)
# Prevent updateState side effects
monkeypatch.setattr(api_mod, 'updateState', lambda *a, **k: None)
app_state = SimpleNamespace(graphQLServerStarted=0)
# Act
api_mod.start_server(12345, app_state)
# Assert
assert 'kwargs' in called
assert called['kwargs']['debug'] is True
assert called['kwargs']['host'] == '0.0.0.0'
assert called['kwargs']['port'] == 12345
def test_start_server_passes_debug_false(monkeypatch):
# Arrange
monkeypatch.setattr(api_mod, 'get_setting_value', lambda k: False if k == 'FLASK_DEBUG' else None)
called = {}
def fake_run(*args, **kwargs):
called['args'] = args
called['kwargs'] = kwargs
monkeypatch.setattr(api_mod, 'app', api_mod.app)
monkeypatch.setattr(api_mod.app, 'run', fake_run)
FakeThread = _make_fake_thread(called)
monkeypatch.setattr(api_mod.threading, 'Thread', FakeThread)
monkeypatch.setattr(api_mod, 'updateState', lambda *a, **k: None)
app_state = SimpleNamespace(graphQLServerStarted=0)
# Act
api_mod.start_server(22222, app_state)
# Assert
assert 'kwargs' in called
assert called['kwargs']['debug'] is False
assert called['kwargs']['host'] == '0.0.0.0'
assert called['kwargs']['port'] == 22222
def test_env_var_overrides_setting(monkeypatch):
# Arrange
# Ensure env override is present
monkeypatch.setenv('FLASK_DEBUG', '1')
# And the stored setting is False to ensure env takes precedence
monkeypatch.setattr(api_mod, 'get_setting_value', lambda k: False if k == 'FLASK_DEBUG' else None)
called = {}
def fake_run(*args, **kwargs):
called['args'] = args
called['kwargs'] = kwargs
monkeypatch.setattr(api_mod, 'app', api_mod.app)
monkeypatch.setattr(api_mod.app, 'run', fake_run)
FakeThread = _make_fake_thread(called)
monkeypatch.setattr(api_mod.threading, 'Thread', FakeThread)
monkeypatch.setattr(api_mod, 'updateState', lambda *a, **k: None)
app_state = SimpleNamespace(graphQLServerStarted=0)
# Act
api_mod.start_server(33333, app_state)
# Assert
assert 'kwargs' in called
assert called['kwargs']['debug'] is True
assert called['kwargs']['host'] == '0.0.0.0'
assert called['kwargs']['port'] == 33333

View File

@@ -0,0 +1,145 @@
import pytest
from unittest.mock import patch
from flask import Flask
from server.api_server.openapi import spec_generator, registry
from server.api_server import mcp_endpoint
# Helper to reset state between tests
@pytest.fixture(autouse=True)
def reset_registry():
registry.clear_registry()
yield
registry.clear_registry()
def test_disable_tool_management():
"""Test enabling and disabling tools."""
# Register a dummy tool
registry.register_tool(
path="/test",
method="GET",
operation_id="test_tool",
summary="Test Tool",
description="A test tool"
)
# Initially enabled
assert not registry.is_tool_disabled("test_tool")
assert "test_tool" not in registry.get_disabled_tools()
# Disable it
assert registry.set_tool_disabled("test_tool", True)
assert registry.is_tool_disabled("test_tool")
assert "test_tool" in registry.get_disabled_tools()
# Enable it
assert registry.set_tool_disabled("test_tool", False)
assert not registry.is_tool_disabled("test_tool")
assert "test_tool" not in registry.get_disabled_tools()
# Try to disable non-existent tool
assert not registry.set_tool_disabled("non_existent", True)
def test_get_tools_status():
"""Test getting the status of all tools."""
registry.register_tool(
path="/tool1",
method="GET",
operation_id="tool1",
summary="Tool 1",
description="First tool"
)
registry.register_tool(
path="/tool2",
method="GET",
operation_id="tool2",
summary="Tool 2",
description="Second tool"
)
registry.set_tool_disabled("tool1", True)
status = registry.get_tools_status()
assert len(status) == 2
t1 = next(t for t in status if t["operation_id"] == "tool1")
t2 = next(t for t in status if t["operation_id"] == "tool2")
assert t1["disabled"] is True
assert t1["summary"] == "Tool 1"
assert t2["disabled"] is False
assert t2["summary"] == "Tool 2"
def test_openapi_spec_injection():
"""Test that x-mcp-disabled is injected into OpenAPI spec."""
registry.register_tool(
path="/test",
method="GET",
operation_id="test_tool",
summary="Test Tool",
description="A test tool"
)
# Disable it
registry.set_tool_disabled("test_tool", True)
spec = spec_generator.generate_openapi_spec()
path_entry = spec["paths"]["/test"]
method_key = next(iter(path_entry))
operation = path_entry[method_key]
assert "x-mcp-disabled" in operation
assert operation["x-mcp-disabled"] is True
# Re-enable
registry.set_tool_disabled("test_tool", False)
spec = spec_generator.generate_openapi_spec()
path_entry = spec["paths"]["/test"]
method_key = next(iter(path_entry))
operation = path_entry[method_key]
assert "x-mcp-disabled" not in operation
@patch("server.api_server.mcp_endpoint.get_setting_value")
@patch("requests.get")
def test_execute_disabled_tool(mock_get, mock_setting):
"""Test that executing a disabled tool returns an error."""
mock_setting.return_value = 8000
# Create a dummy app for context
app = Flask(__name__)
# Register tool
registry.register_tool(
path="/test",
method="GET",
operation_id="test_tool",
summary="Test Tool",
description="A test tool"
)
route = mcp_endpoint.find_route_for_tool("test_tool")
with app.test_request_context():
# 1. Test enabled (mock request)
mock_get.return_value.json.return_value = {"success": True}
mock_get.return_value.status_code = 200
result = mcp_endpoint._execute_tool(route, {})
assert not result["isError"]
# 2. Disable tool
registry.set_tool_disabled("test_tool", True)
result = mcp_endpoint._execute_tool(route, {})
assert result["isError"]
assert "is disabled" in result["content"][0]["text"]
# Ensure no HTTP request was made for the second call
assert mock_get.call_count == 1

View File

@@ -0,0 +1,18 @@
from front.plugins.plugin_helper import is_mac, normalize_mac
def test_is_mac_accepts_wildcard():
assert is_mac("AA:BB:CC:*") is True
assert is_mac("aa-bb-cc:*") is True # mixed separator
assert is_mac("00:11:22:33:44:55") is True
assert is_mac("00-11-22-33-44-55") is True
assert is_mac("not-a-mac") is False
def test_normalize_mac_preserves_wildcard():
assert normalize_mac("aa:bb:cc:*") == "AA:BB:CC:*"
assert normalize_mac("aa-bb-cc-*") == "AA:BB:CC:*"
# Call once and assert deterministic result
result = normalize_mac("aabbcc*")
assert result == "AA:BB:CC:*", f"Expected 'AA:BB:CC:*' but got '{result}'"
assert normalize_mac("aa:bb:cc:dd:ee:ff") == "AA:BB:CC:DD:EE:FF"

View File

@@ -0,0 +1,78 @@
"""Runtime Wake-on-LAN endpoint validation tests."""
import os
import time
from typing import Dict
import pytest
import requests
BASE_URL = os.getenv("NETALERTX_BASE_URL", "http://localhost:20212")
REQUEST_TIMEOUT = float(os.getenv("NETALERTX_REQUEST_TIMEOUT", "5"))
SERVER_RETRIES = int(os.getenv("NETALERTX_SERVER_RETRIES", "5"))
SERVER_DELAY = float(os.getenv("NETALERTX_SERVER_DELAY", "1"))
def wait_for_server() -> bool:
"""Wait for the GraphQL endpoint to become ready with paced retries."""
for _ in range(SERVER_RETRIES):
try:
resp = requests.get(f"{BASE_URL}/graphql", timeout=1)
if 200 <= resp.status_code < 300:
return True
except requests.RequestException:
pass
time.sleep(SERVER_DELAY)
return False
@pytest.fixture(scope="session", autouse=True)
def ensure_backend_ready():
"""Skip the module if the backend is not running."""
if not wait_for_server():
pytest.skip("NetAlertX backend is not reachable for WOL validation tests")
@pytest.fixture(scope="session")
def auth_headers() -> Dict[str, str]:
token = os.getenv("API_TOKEN") or os.getenv("NETALERTX_API_TOKEN")
if not token:
pytest.skip("API_TOKEN not configured; skipping WOL validation tests")
return {"Authorization": f"Bearer {token}"}
def test_wol_valid_mac(auth_headers):
"""Ensure a valid MAC request is accepted (anything except 422 is acceptable)."""
payload = {"devMac": "00:11:22:33:44:55"}
resp = requests.post(
f"{BASE_URL}/nettools/wakeonlan",
json=payload,
headers=auth_headers,
timeout=REQUEST_TIMEOUT,
)
assert resp.status_code != 422, f"Validation failed for valid MAC: {resp.text}"
def test_wol_valid_ip(auth_headers):
"""Ensure an IP-based request passes validation (404 acceptable, 422 is not)."""
payload = {"ip": "1.2.3.4"}
resp = requests.post(
f"{BASE_URL}/nettools/wakeonlan",
json=payload,
headers=auth_headers,
timeout=REQUEST_TIMEOUT,
)
assert resp.status_code != 422, f"Validation failed for valid IP payload: {resp.text}"
def test_wol_invalid_mac(auth_headers):
"""Invalid MAC payloads must be rejected with HTTP 422."""
payload = {"devMac": "invalid-mac"}
resp = requests.post(
f"{BASE_URL}/nettools/wakeonlan",
json=payload,
headers=auth_headers,
timeout=REQUEST_TIMEOUT,
)
assert resp.status_code == 422, f"Expected 422 for invalid MAC, got {resp.status_code}: {resp.text}"

0
test/ui/__init__.py Normal file
View File

View File

@@ -6,19 +6,7 @@ Runs all page-specific UI tests and provides summary
import sys
import os
# Add test directory to path
sys.path.insert(0, os.path.dirname(__file__))
# Import all test modules
import test_ui_dashboard # noqa: E402 [flake8 lint suppression]
import test_ui_devices # noqa: E402 [flake8 lint suppression]
import test_ui_network # noqa: E402 [flake8 lint suppression]
import test_ui_maintenance # noqa: E402 [flake8 lint suppression]
import test_ui_multi_edit # noqa: E402 [flake8 lint suppression]
import test_ui_notifications # noqa: E402 [flake8 lint suppression]
import test_ui_settings # noqa: E402 [flake8 lint suppression]
import test_ui_plugins # noqa: E402 [flake8 lint suppression]
import pytest
def main():
@@ -27,22 +15,28 @@ def main():
print("NetAlertX UI Test Suite")
print("=" * 70)
# Get directory of this script
base_dir = os.path.dirname(os.path.abspath(__file__))
test_modules = [
("Dashboard", test_ui_dashboard),
("Devices", test_ui_devices),
("Network", test_ui_network),
("Maintenance", test_ui_maintenance),
("Multi-Edit", test_ui_multi_edit),
("Notifications", test_ui_notifications),
("Settings", test_ui_settings),
("Plugins", test_ui_plugins),
("Dashboard", "test_ui_dashboard.py"),
("Devices", "test_ui_devices.py"),
("Network", "test_ui_network.py"),
("Maintenance", "test_ui_maintenance.py"),
("Multi-Edit", "test_ui_multi_edit.py"),
("Notifications", "test_ui_notifications.py"),
("Settings", "test_ui_settings.py"),
("Plugins", "test_ui_plugins.py"),
]
results = {}
for name, module in test_modules:
for name, filename in test_modules:
try:
result = module.run_tests()
print(f"\nRunning {name} tests...")
file_path = os.path.join(base_dir, filename)
# Run pytest
result = pytest.main([file_path, "-v"])
results[name] = result == 0
except Exception as e:
print(f"\n{name} tests failed with exception: {e}")

View File

@@ -8,6 +8,9 @@ import requests
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
# Configuration
BASE_URL = os.getenv("UI_BASE_URL", "http://localhost:20211")
@@ -15,7 +18,11 @@ API_BASE_URL = os.getenv("API_BASE_URL", "http://localhost:20212")
def get_api_token():
"""Get API token from config file"""
"""Get API token from config file or environment"""
# Check environment first
if os.getenv("API_TOKEN"):
return os.getenv("API_TOKEN")
config_path = "/data/config/app.conf"
try:
with open(config_path, 'r') as f:
@@ -115,3 +122,31 @@ def api_post(endpoint, api_token, data=None, timeout=5):
# Handle both full URLs and path-only endpoints
url = endpoint if endpoint.startswith('http') else f"{API_BASE_URL}{endpoint}"
return requests.post(url, headers=headers, json=data, timeout=timeout)
# --- Page load and element wait helpers (used by UI tests) ---
def wait_for_page_load(driver, timeout=10):
"""Wait until the browser reports the document readyState is 'complete'."""
WebDriverWait(driver, timeout).until(
lambda d: d.execute_script("return document.readyState") == "complete"
)
def wait_for_element_by_css(driver, css_selector, timeout=10):
"""Wait for presence of an element matching a CSS selector and return it."""
return WebDriverWait(driver, timeout).until(
EC.presence_of_element_located((By.CSS_SELECTOR, css_selector))
)
def wait_for_input_value(driver, element_id, timeout=10):
"""Wait for the input with given id to have a non-empty value and return it."""
def _get_val(d):
try:
el = d.find_element(By.ID, element_id)
val = el.get_attribute("value")
return val if val else False
except Exception:
return False
return WebDriverWait(driver, timeout).until(_get_val)

View File

@@ -4,34 +4,30 @@ Dashboard Page UI Tests
Tests main dashboard metrics, charts, and device table
"""
import time
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import sys
import os
from selenium.webdriver.common.by import By
# Add test directory to path
sys.path.insert(0, os.path.dirname(__file__))
from test_helpers import BASE_URL # noqa: E402 [flake8 lint suppression]
from .test_helpers import BASE_URL, wait_for_page_load, wait_for_element_by_css # noqa: E402
def test_dashboard_loads(driver):
"""Test: Dashboard/index page loads successfully"""
driver.get(f"{BASE_URL}/index.php")
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
time.sleep(2)
wait_for_page_load(driver, timeout=10)
assert driver.title, "Page should have a title"
def test_metric_tiles_present(driver):
"""Test: Dashboard metric tiles are rendered"""
driver.get(f"{BASE_URL}/index.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
# Wait for at least one metric/tile/info-box to be present
wait_for_element_by_css(driver, ".metric, .tile, .info-box, .small-box", timeout=10)
tiles = driver.find_elements(By.CSS_SELECTOR, ".metric, .tile, .info-box, .small-box")
assert len(tiles) > 0, "Dashboard should have metric tiles"
@@ -39,7 +35,8 @@ def test_metric_tiles_present(driver):
def test_device_table_present(driver):
"""Test: Dashboard device table is rendered"""
driver.get(f"{BASE_URL}/index.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
wait_for_element_by_css(driver, "table", timeout=10)
table = driver.find_elements(By.CSS_SELECTOR, "table")
assert len(table) > 0, "Dashboard should have a device table"
@@ -47,6 +44,7 @@ def test_device_table_present(driver):
def test_charts_present(driver):
"""Test: Dashboard charts are rendered"""
driver.get(f"{BASE_URL}/index.php")
time.sleep(3) # Charts may take longer to load
wait_for_page_load(driver, timeout=15) # Charts may take longer to load
wait_for_element_by_css(driver, "canvas, .chart, svg", timeout=15)
charts = driver.find_elements(By.CSS_SELECTOR, "canvas, .chart, svg")
assert len(charts) > 0, "Dashboard should have charts"

View File

@@ -4,34 +4,28 @@ Device Details Page UI Tests
Tests device details page, field updates, and delete operations
"""
import time
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import sys
import os
from selenium.webdriver.common.by import By
# Add test directory to path
sys.path.insert(0, os.path.dirname(__file__))
from test_helpers import BASE_URL, API_BASE_URL, api_get # noqa: E402 [flake8 lint suppression]
from .test_helpers import BASE_URL, API_BASE_URL, api_get, wait_for_page_load, wait_for_element_by_css, wait_for_input_value # noqa: E402
def test_device_list_page_loads(driver):
"""Test: Device list page loads successfully"""
driver.get(f"{BASE_URL}/devices.php")
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
time.sleep(2)
wait_for_page_load(driver, timeout=10)
assert "device" in driver.page_source.lower(), "Page should contain device content"
def test_devices_table_present(driver):
"""Test: Devices table is rendered"""
driver.get(f"{BASE_URL}/devices.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
wait_for_element_by_css(driver, "table, #devicesTable", timeout=10)
table = driver.find_elements(By.CSS_SELECTOR, "table, #devicesTable")
assert len(table) > 0, "Devices table should be present"
@@ -39,7 +33,7 @@ def test_devices_table_present(driver):
def test_device_search_works(driver):
"""Test: Device search/filter functionality works"""
driver.get(f"{BASE_URL}/devices.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
# Find search input (common patterns)
search_inputs = driver.find_elements(By.CSS_SELECTOR, "input[type='search'], input[placeholder*='search' i], .dataTables_filter input")
@@ -48,10 +42,11 @@ def test_device_search_works(driver):
search_box = search_inputs[0]
assert search_box.is_displayed(), "Search box should be visible"
# Type in search box
# Type in search box and wait briefly for filter to apply
search_box.clear()
search_box.send_keys("test")
time.sleep(1)
# Wait for DOM/JS to react (at least one row or filtered content) — if datatables in use, table body should update
wait_for_element_by_css(driver, "table tbody tr", timeout=5)
# Verify search executed (page content changed or filter applied)
assert True, "Search executed successfully"
@@ -82,29 +77,36 @@ def test_devices_totals_api(api_token):
def test_add_device_with_generated_mac_ip(driver, api_token):
"""Add a new device using the UI, always clicking Generate MAC/IP buttons"""
import requests
import time
driver.get(f"{BASE_URL}/devices.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
# --- Click "Add Device" ---
add_buttons = driver.find_elements(By.CSS_SELECTOR, "button#btnAddDevice, button[onclick*='addDevice'], a[href*='deviceDetails.php?mac='], .btn-add-device")
if not add_buttons:
# Wait for the "New Device" link specifically to ensure it's loaded
add_selector = "a[href*='deviceDetails.php?mac=new'], button#btnAddDevice, .btn-add-device"
try:
add_button = wait_for_element_by_css(driver, add_selector, timeout=10)
except Exception:
# Fallback to broader search if specific selector fails
add_buttons = driver.find_elements(By.XPATH, "//button[contains(text(),'Add') or contains(text(),'New')] | //a[contains(text(),'Add') or contains(text(),'New')]")
if not add_buttons:
assert True, "Add device button not found, skipping test"
return
add_buttons[0].click()
time.sleep(2)
if add_buttons:
add_button = add_buttons[0]
else:
assert True, "Add device button not found, skipping test"
return
# Use JavaScript click to bypass any transparent overlays from the chart
driver.execute_script("arguments[0].click();", add_button)
# Wait for the device form to appear (use the NEWDEV_devMac field as indicator)
wait_for_element_by_css(driver, "#NEWDEV_devMac", timeout=10)
# --- Helper to click generate button for a field ---
def click_generate_button(field_id):
btn = driver.find_element(By.CSS_SELECTOR, f"span[onclick*='generate_{field_id}']")
driver.execute_script("arguments[0].click();", btn)
time.sleep(0.5)
# Return the new value
inp = driver.find_element(By.ID, field_id)
return inp.get_attribute("value")
# Wait for the input to be populated and return it
return wait_for_input_value(driver, field_id, timeout=10)
# --- Generate MAC ---
test_mac = click_generate_button("NEWDEV_devMac")
@@ -127,7 +129,6 @@ def test_add_device_with_generated_mac_ip(driver, api_token):
assert True, "Save button not found, skipping test"
return
driver.execute_script("arguments[0].click();", save_buttons[0])
time.sleep(3)
# --- Verify device via API ---
headers = {"Authorization": f"Bearer {api_token}"}
@@ -139,7 +140,7 @@ def test_add_device_with_generated_mac_ip(driver, api_token):
else:
# Fallback: check UI
driver.get(f"{BASE_URL}/devices.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
if test_mac in driver.page_source or "Test Device Selenium" in driver.page_source:
assert True, "Device appears in UI"
else:

View File

@@ -4,28 +4,24 @@ Maintenance Page UI Tests
Tests CSV export/import, delete operations, database tools
"""
import time
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from test_helpers import BASE_URL, api_get
from .test_helpers import BASE_URL, api_get, wait_for_page_load # noqa: E402
def test_maintenance_page_loads(driver):
"""Test: Maintenance page loads successfully"""
driver.get(f"{BASE_URL}/maintenance.php")
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
time.sleep(2)
wait_for_page_load(driver, timeout=10)
assert "Maintenance" in driver.page_source, "Page should show Maintenance content"
def test_export_buttons_present(driver):
"""Test: Export buttons are visible"""
driver.get(f"{BASE_URL}/maintenance.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
export_btn = driver.find_elements(By.ID, "btnExportCSV")
assert len(export_btn) > 0, "Export CSV button should be present"
@@ -35,33 +31,43 @@ def test_export_csv_button_works(driver):
import os
import glob
driver.get(f"{BASE_URL}/maintenance.php")
time.sleep(2)
# Use 127.0.0.1 instead of localhost to avoid IPv6 resolution issues in the browser
# which can lead to "Failed to fetch" if the server is only listening on IPv4.
target_url = f"{BASE_URL}/maintenance.php".replace("localhost", "127.0.0.1")
driver.get(target_url)
wait_for_page_load(driver, timeout=10)
# Clear any existing downloads
download_dir = getattr(driver, 'download_dir', '/tmp/selenium_downloads')
for f in glob.glob(f"{download_dir}/*.csv"):
os.remove(f)
# Ensure the Backup/Restore tab is active so the button is in a clickable state
try:
tab = WebDriverWait(driver, 5).until(
EC.element_to_be_clickable((By.ID, "tab_BackupRestore_id"))
)
tab.click()
except Exception:
pass
# Find the export button
export_btns = driver.find_elements(By.ID, "btnExportCSV")
try:
export_btn = WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.ID, "btnExportCSV"))
)
if len(export_btns) > 0:
export_btn = export_btns[0]
# Click it (JavaScript click works even if CSS hides it)
# Click it (JavaScript click works even if CSS hides it or if it's overlapped)
driver.execute_script("arguments[0].click();", export_btn)
# Wait for download to complete (up to 10 seconds)
downloaded = False
for i in range(20): # Check every 0.5s for 10s
time.sleep(0.5)
csv_files = glob.glob(f"{download_dir}/*.csv")
if len(csv_files) > 0:
# Check file has content (download completed)
if os.path.getsize(csv_files[0]) > 0:
downloaded = True
break
try:
WebDriverWait(driver, 10).until(
lambda d: any(os.path.getsize(f) > 0 for f in glob.glob(f"{download_dir}/*.csv"))
)
downloaded = True
except Exception:
downloaded = False
if downloaded:
# Verify CSV file exists and has data
@@ -77,15 +83,21 @@ def test_export_csv_button_works(driver):
# Download via blob/JavaScript - can't verify file in headless mode
# Just verify button click didn't cause errors
assert "error" not in driver.page_source.lower(), "Button click should not cause errors"
else:
# Button doesn't exist on this page
assert True, "Export button not found on this page"
except Exception as e:
# Check for alerts that might be blocking page_source access
try:
alert = driver.switch_to.alert
alert_text = alert.text
alert.accept()
assert False, f"Alert present: {alert_text}"
except Exception:
raise e
def test_import_section_present(driver):
"""Test: Import section is rendered or page loads without errors"""
driver.get(f"{BASE_URL}/maintenance.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
# Check page loaded and doesn't show fatal errors
assert "fatal" not in driver.page_source.lower(), "Page should not show fatal errors"
assert "maintenance" in driver.page_source.lower() or len(driver.page_source) > 100, "Page should load content"
@@ -94,7 +106,7 @@ def test_import_section_present(driver):
def test_delete_buttons_present(driver):
"""Test: Delete operation buttons are visible (at least some)"""
driver.get(f"{BASE_URL}/maintenance.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
buttons = [
"btnDeleteEmptyMACs",
"btnDeleteAllDevices",

View File

@@ -4,12 +4,11 @@ Multi-Edit Page UI Tests
Tests bulk device operations and form controls
"""
import time
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from test_helpers import BASE_URL
from .test_helpers import BASE_URL, wait_for_page_load
def test_multi_edit_page_loads(driver):
@@ -18,7 +17,7 @@ def test_multi_edit_page_loads(driver):
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
time.sleep(2)
wait_for_page_load(driver, timeout=10)
# Check page loaded without fatal errors
assert "fatal" not in driver.page_source.lower(), "Page should not show fatal errors"
assert len(driver.page_source) > 100, "Page should load some content"
@@ -27,7 +26,7 @@ def test_multi_edit_page_loads(driver):
def test_device_selector_present(driver):
"""Test: Device selector/table is rendered or page loads"""
driver.get(f"{BASE_URL}/multiEditCore.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
# Page should load without fatal errors
assert "fatal" not in driver.page_source.lower(), "Page should not show fatal errors"
@@ -35,7 +34,7 @@ def test_device_selector_present(driver):
def test_bulk_action_buttons_present(driver):
"""Test: Page loads for bulk actions"""
driver.get(f"{BASE_URL}/multiEditCore.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
# Check page loads without errors
assert len(driver.page_source) > 50, "Page should load content"
@@ -43,6 +42,6 @@ def test_bulk_action_buttons_present(driver):
def test_field_dropdowns_present(driver):
"""Test: Page loads successfully"""
driver.get(f"{BASE_URL}/multiEditCore.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
# Check page loads
assert "fatal" not in driver.page_source.lower(), "Page should not show fatal errors"

View File

@@ -4,12 +4,11 @@ Network Page UI Tests
Tests network topology visualization and device relationships
"""
import time
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from test_helpers import BASE_URL
from .test_helpers import BASE_URL, wait_for_page_load
def test_network_page_loads(driver):
@@ -18,14 +17,14 @@ def test_network_page_loads(driver):
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
time.sleep(2)
wait_for_page_load(driver, timeout=10)
assert driver.title, "Network page should have a title"
def test_network_tree_present(driver):
"""Test: Network tree container is rendered"""
driver.get(f"{BASE_URL}/network.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
tree = driver.find_elements(By.ID, "networkTree")
assert len(tree) > 0, "Network tree should be present"
@@ -33,7 +32,7 @@ def test_network_tree_present(driver):
def test_network_tabs_present(driver):
"""Test: Network page loads successfully"""
driver.get(f"{BASE_URL}/network.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
# Check page loaded without fatal errors
assert "fatal" not in driver.page_source.lower(), "Page should not show fatal errors"
assert len(driver.page_source) > 100, "Page should load content"
@@ -42,6 +41,6 @@ def test_network_tabs_present(driver):
def test_device_tables_present(driver):
"""Test: Device tables are rendered"""
driver.get(f"{BASE_URL}/network.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
tables = driver.find_elements(By.CSS_SELECTOR, ".networkTable, table")
assert len(tables) > 0, "Device tables should be present"

View File

@@ -4,12 +4,11 @@ Notifications Page UI Tests
Tests notification table, mark as read, delete operations
"""
import time
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from test_helpers import BASE_URL, api_get
from .test_helpers import BASE_URL, api_get, wait_for_page_load
def test_notifications_page_loads(driver):
@@ -18,14 +17,14 @@ def test_notifications_page_loads(driver):
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
time.sleep(2)
wait_for_page_load(driver, timeout=10)
assert "notification" in driver.page_source.lower(), "Page should contain notification content"
def test_notifications_table_present(driver):
"""Test: Notifications table is rendered"""
driver.get(f"{BASE_URL}/userNotifications.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
table = driver.find_elements(By.CSS_SELECTOR, "table, #notificationsTable")
assert len(table) > 0, "Notifications table should be present"
@@ -33,7 +32,7 @@ def test_notifications_table_present(driver):
def test_notification_action_buttons_present(driver):
"""Test: Notification action buttons are visible"""
driver.get(f"{BASE_URL}/userNotifications.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
buttons = driver.find_elements(By.CSS_SELECTOR, "button[id*='notification'], .notification-action")
assert len(buttons) > 0, "Notification action buttons should be present"

View File

@@ -4,28 +4,28 @@ Plugins Page UI Tests
Tests plugin management interface and operations
"""
import time
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from test_helpers import BASE_URL
from .test_helpers import BASE_URL, wait_for_page_load
def test_plugins_page_loads(driver):
"""Test: Plugins page loads successfully"""
driver.get(f"{BASE_URL}/pluginsCore.php")
driver.get(f"{BASE_URL}/plugins.php")
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
time.sleep(2)
wait_for_page_load(driver, timeout=10)
assert "plugin" in driver.page_source.lower(), "Page should contain plugin content"
def test_plugin_list_present(driver):
"""Test: Plugin page loads successfully"""
driver.get(f"{BASE_URL}/pluginsCore.php")
time.sleep(2)
driver.get(f"{BASE_URL}/plugins.php")
wait_for_page_load(driver, timeout=10)
# Check page loaded
assert "fatal" not in driver.page_source.lower(), "Page should not show fatal errors"
assert len(driver.page_source) > 50, "Page should load content"
@@ -33,7 +33,7 @@ def test_plugin_list_present(driver):
def test_plugin_actions_present(driver):
"""Test: Plugin page loads without errors"""
driver.get(f"{BASE_URL}/pluginsCore.php")
time.sleep(2)
driver.get(f"{BASE_URL}/plugins.php")
wait_for_page_load(driver, timeout=10)
# Check page loads
assert "fatal" not in driver.page_source.lower(), "Page should not show fatal errors"

View File

@@ -9,12 +9,8 @@ import os
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import sys
# Add test directory to path
sys.path.insert(0, os.path.dirname(__file__))
from test_helpers import BASE_URL # noqa: E402 [flake8 lint suppression]
from .test_helpers import BASE_URL, wait_for_page_load
def test_settings_page_loads(driver):
@@ -23,14 +19,14 @@ def test_settings_page_loads(driver):
WebDriverWait(driver, 10).until(
EC.presence_of_element_located((By.TAG_NAME, "body"))
)
time.sleep(2)
wait_for_page_load(driver, timeout=10)
assert "setting" in driver.page_source.lower(), "Page should contain settings content"
def test_settings_groups_present(driver):
"""Test: Settings groups/sections are rendered"""
driver.get(f"{BASE_URL}/settings.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
groups = driver.find_elements(By.CSS_SELECTOR, ".settings-group, .panel, .card, fieldset")
assert len(groups) > 0, "Settings groups should be present"
@@ -38,7 +34,7 @@ def test_settings_groups_present(driver):
def test_settings_inputs_present(driver):
"""Test: Settings input fields are rendered"""
driver.get(f"{BASE_URL}/settings.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
inputs = driver.find_elements(By.CSS_SELECTOR, "input, select, textarea")
assert len(inputs) > 0, "Settings input fields should be present"
@@ -46,7 +42,7 @@ def test_settings_inputs_present(driver):
def test_save_button_present(driver):
"""Test: Save button is visible"""
driver.get(f"{BASE_URL}/settings.php")
time.sleep(2)
wait_for_page_load(driver, timeout=10)
save_btn = driver.find_elements(By.CSS_SELECTOR, "button[type='submit'], button#save, .btn-save")
assert len(save_btn) > 0, "Save button should be present"
@@ -63,7 +59,7 @@ def test_save_settings_with_form_submission(driver):
6. Verifies the config file was updated
"""
driver.get(f"{BASE_URL}/settings.php")
time.sleep(3)
wait_for_page_load(driver, timeout=10)
# Wait for the save button to be present and clickable
save_btn = WebDriverWait(driver, 10).until(
@@ -161,7 +157,7 @@ def test_save_settings_no_loss_of_data(driver):
4. Check API endpoint that the setting is updated correctly
"""
driver.get(f"{BASE_URL}/settings.php")
time.sleep(3)
wait_for_page_load(driver, timeout=10)
# Find the PLUGINS_KEEP_HIST input field
plugins_keep_hist_input = None
@@ -181,12 +177,12 @@ def test_save_settings_no_loss_of_data(driver):
new_value = "333"
plugins_keep_hist_input.clear()
plugins_keep_hist_input.send_keys(new_value)
time.sleep(1)
wait_for_page_load(driver, timeout=10)
# Click save
save_btn = driver.find_element(By.CSS_SELECTOR, "button#save")
driver.execute_script("arguments[0].click();", save_btn)
time.sleep(3)
wait_for_page_load(driver, timeout=10)
# Check for errors after save
error_elements = driver.find_elements(By.CSS_SELECTOR, ".alert-danger, .error-message, .callout-danger")

77
test/ui/test_ui_waits.py Normal file
View File

@@ -0,0 +1,77 @@
#!/usr/bin/env python3
"""
Basic verification tests for wait helpers used by UI tests.
"""
import sys
import os
from selenium.webdriver.common.by import By
# Add test directory to path
sys.path.insert(0, os.path.dirname(__file__))
from .test_helpers import BASE_URL, wait_for_page_load, wait_for_element_by_css, wait_for_input_value # noqa: E402
def test_wait_helpers_work_on_dashboard(driver):
"""Ensure wait helpers can detect basic dashboard elements"""
driver.get(f"{BASE_URL}/index.php")
wait_for_page_load(driver, timeout=10)
body = wait_for_element_by_css(driver, "body", timeout=5)
assert body is not None
# Device table should be present on the dashboard
table = wait_for_element_by_css(driver, "table", timeout=10)
assert table is not None
def test_wait_for_input_value_on_devices(driver):
"""Try generating a MAC on the devices add form and use wait_for_input_value to validate it."""
driver.get(f"{BASE_URL}/devices.php")
wait_for_page_load(driver, timeout=10)
# Try to open an add form - skip if not present
add_buttons = driver.find_elements(By.CSS_SELECTOR, "button#btnAddDevice, button[onclick*='addDevice'], a[href*='deviceDetails.php?mac='], .btn-add-device")
if not add_buttons:
return # nothing to test in this environment
# Use JS click with scroll into view to avoid element click intercepted errors
btn = add_buttons[0]
driver.execute_script("arguments[0].scrollIntoView({block: 'center'});", btn)
try:
driver.execute_script("arguments[0].click();", btn)
except Exception:
# Fallback to normal click if JS click fails for any reason
btn.click()
# Wait for the NEWDEV_devMac field to appear; if not found, try navigating directly to the add form
try:
wait_for_element_by_css(driver, "#NEWDEV_devMac", timeout=5)
except Exception:
# Some UIs open a new page at deviceDetails.php?mac=new; navigate directly as a fallback
driver.get(f"{BASE_URL}/deviceDetails.php?mac=new")
try:
wait_for_element_by_css(driver, "#NEWDEV_devMac", timeout=10)
except Exception:
# If that still fails, attempt to remove canvas overlays (chart.js) and retry clicking the add button
driver.execute_script("document.querySelectorAll('canvas').forEach(c=>c.style.pointerEvents='none');")
btn = add_buttons[0]
driver.execute_script("arguments[0].scrollIntoView({block: 'center'});", btn)
try:
driver.execute_script("arguments[0].click();", btn)
except Exception:
pass
try:
wait_for_element_by_css(driver, "#NEWDEV_devMac", timeout=5)
except Exception:
# Restore canvas pointer-events and give up
driver.execute_script("document.querySelectorAll('canvas').forEach(c=>c.style.pointerEvents='auto');")
return
# Restore canvas pointer-events
driver.execute_script("document.querySelectorAll('canvas').forEach(c=>c.style.pointerEvents='auto');")
# Attempt to click the generate control if present
gen_buttons = driver.find_elements(By.CSS_SELECTOR, "span[onclick*='generate_NEWDEV_devMac']")
if not gen_buttons:
return
driver.execute_script("arguments[0].click();", gen_buttons[0])
mac_val = wait_for_input_value(driver, "NEWDEV_devMac", timeout=10)
assert mac_val, "Generated MAC should be populated"

View File

@@ -0,0 +1,20 @@
import pytest
from pydantic import ValidationError
from server.api_server.openapi.schemas import DeviceListRequest
from server.db.db_helper import get_device_condition_by_status
def test_device_list_request_accepts_offline():
req = DeviceListRequest(status="offline")
assert req.status == "offline"
def test_get_device_condition_by_status_offline():
cond = get_device_condition_by_status("offline")
assert "devPresentLastScan=0" in cond and "devIsArchived=0" in cond
def test_device_list_request_rejects_unknown_status():
with pytest.raises(ValidationError):
DeviceListRequest(status="my_devices")

View File

@@ -0,0 +1,75 @@
"""Runtime validation tests for the devices/search endpoint."""
import os
import time
import pytest
import requests
BASE_URL = os.getenv("NETALERTX_BASE_URL", "http://localhost:20212")
REQUEST_TIMEOUT = float(os.getenv("NETALERTX_REQUEST_TIMEOUT", "5"))
SERVER_RETRIES = int(os.getenv("NETALERTX_SERVER_RETRIES", "5"))
API_TOKEN = os.getenv("API_TOKEN") or os.getenv("NETALERTX_API_TOKEN")
if not API_TOKEN:
pytest.skip("API_TOKEN not found; skipping runtime validation tests", allow_module_level=True)
HEADERS = {"Authorization": f"Bearer {API_TOKEN}"}
def wait_for_server() -> bool:
"""Probe the backend GraphQL endpoint with paced retries."""
for _ in range(SERVER_RETRIES):
try:
resp = requests.get(f"{BASE_URL}/graphql", timeout=2)
if 200 <= resp.status_code < 300:
return True
except requests.RequestException:
pass
time.sleep(1)
return False
if not wait_for_server():
pytest.skip("NetAlertX backend is unreachable; skipping runtime validation tests", allow_module_level=True)
def test_search_valid():
"""Valid payloads should return 200/404 but never 422."""
payload = {"query": "Router"}
resp = requests.post(
f"{BASE_URL}/devices/search",
json=payload,
headers=HEADERS,
timeout=REQUEST_TIMEOUT,
)
assert resp.status_code in (200, 404), f"Unexpected status {resp.status_code}: {resp.text}"
assert resp.status_code != 422, f"Validation failed for valid payload: {resp.text}"
def test_search_invalid_schema():
"""Missing required fields must trigger a 422 validation error."""
resp = requests.post(
f"{BASE_URL}/devices/search",
json={},
headers=HEADERS,
timeout=REQUEST_TIMEOUT,
)
if resp.status_code in (401, 403):
pytest.fail(f"Authorization failed: {resp.status_code} {resp.text}")
assert resp.status_code == 422, f"Expected 422 for missing query: {resp.status_code} {resp.text}"
def test_search_invalid_type():
"""Invalid field types must also result in HTTP 422."""
payload = {"query": 1234, "limit": "invalid"}
resp = requests.post(
f"{BASE_URL}/devices/search",
json=payload,
headers=HEADERS,
timeout=REQUEST_TIMEOUT,
)
if resp.status_code in (401, 403):
pytest.fail(f"Authorization failed: {resp.status_code} {resp.text}")
assert resp.status_code == 422, f"Expected 422 for invalid types: {resp.status_code} {resp.text}"