diff --git a/server/database.py b/server/database.py
index 8948ee1c..3bc5452a 100755
--- a/server/database.py
+++ b/server/database.py
@@ -198,12 +198,16 @@ class DB():
# # mylog('debug',[ '[Database] - get_table_as_json - returning json ', json.dumps(result) ])
# return json_obj(result, columnNames)
- def get_table_as_json(self, sqlQuery):
+ def get_table_as_json(self, sqlQuery, parameters=None):
"""
Wrapper to use the central get_table_as_json helper.
+
+ Args:
+ sqlQuery (str): The SQL query to execute.
+ parameters (dict, optional): Named parameters for the SQL query.
"""
try:
- result = get_table_json(self.sql, sqlQuery)
+ result = get_table_json(self.sql, sqlQuery, parameters)
except Exception as e:
mylog('minimal', ['[Database] - get_table_as_json ERROR:', e])
return json_obj({}, []) # return empty object on failure
diff --git a/server/db/db_helper.py b/server/db/db_helper.py
index 55f39472..6654be67 100755
--- a/server/db/db_helper.py
+++ b/server/db/db_helper.py
@@ -180,19 +180,23 @@ def list_to_where(logical_operator, column_name, condition_operator, values_list
return f'({condition})'
#-------------------------------------------------------------------------------
-def get_table_json(sql, sql_query):
+def get_table_json(sql, sql_query, parameters=None):
"""
Execute a SQL query and return the results as JSON-like dict.
Args:
sql: SQLite cursor or connection wrapper supporting execute(), description, and fetchall().
sql_query (str): The SQL query to execute.
+ parameters (dict, optional): Named parameters for the SQL query.
Returns:
dict: JSON-style object with data and column names.
"""
try:
- sql.execute(sql_query)
+ if parameters:
+ sql.execute(sql_query, parameters)
+ else:
+ sql.execute(sql_query)
rows = sql.fetchall()
if (rows):
# We only return data if we actually got some out of SQLite
diff --git a/server/db/sql_safe_builder.py b/server/db/sql_safe_builder.py
new file mode 100644
index 00000000..d3e285af
--- /dev/null
+++ b/server/db/sql_safe_builder.py
@@ -0,0 +1,365 @@
+"""
+NetAlertX SQL Safe Builder Module
+
+This module provides safe SQL condition building functionality to prevent
+SQL injection vulnerabilities. It validates inputs against whitelists,
+sanitizes data, and returns parameterized queries.
+
+Author: Security Enhancement for NetAlertX
+License: GNU GPLv3
+"""
+
+import re
+import sys
+from typing import Dict, List, Tuple, Any, Optional
+
+# Register NetAlertX directories
+INSTALL_PATH = "/app"
+sys.path.extend([f"{INSTALL_PATH}/server"])
+
+from logger import mylog
+
+
+class SafeConditionBuilder:
+ """
+ A secure SQL condition builder that validates inputs against whitelists
+ and generates parameterized SQL snippets to prevent SQL injection.
+ """
+
+ # Whitelist of allowed column names for filtering
+ ALLOWED_COLUMNS = {
+ 'eve_MAC', 'eve_DateTime', 'eve_IP', 'eve_EventType', 'devName',
+ 'devComments', 'devLastIP', 'devVendor', 'devAlertEvents',
+ 'devAlertDown', 'devIsArchived', 'devPresentLastScan', 'devFavorite',
+ 'devIsNew', 'Plugin', 'Object_PrimaryId', 'Object_SecondaryId',
+ 'DateTimeChanged', 'Watched_Value1', 'Watched_Value2', 'Watched_Value3',
+ 'Watched_Value4', 'Status'
+ }
+
+ # Whitelist of allowed comparison operators
+ ALLOWED_OPERATORS = {
+ '=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'NOT LIKE',
+ 'IN', 'NOT IN', 'IS NULL', 'IS NOT NULL'
+ }
+
+ # Whitelist of allowed logical operators
+ ALLOWED_LOGICAL_OPERATORS = {'AND', 'OR'}
+
+ # Whitelist of allowed event types
+ ALLOWED_EVENT_TYPES = {
+ 'New Device', 'Connected', 'Disconnected', 'Device Down',
+ 'Down Reconnected', 'IP Changed'
+ }
+
+ def __init__(self):
+ """Initialize the SafeConditionBuilder."""
+ self.parameters = {}
+ self.param_counter = 0
+
+ def _generate_param_name(self, prefix: str = 'param') -> str:
+ """Generate a unique parameter name for SQL binding."""
+ self.param_counter += 1
+ return f"{prefix}_{self.param_counter}"
+
+ def _sanitize_string(self, value: str) -> str:
+ """
+ Sanitize string input by removing potentially dangerous characters.
+
+ Args:
+ value: String to sanitize
+
+ Returns:
+ Sanitized string
+ """
+ if not isinstance(value, str):
+ return str(value)
+
+ # Replace {s-quote} placeholder with single quote (maintaining compatibility)
+ value = value.replace('{s-quote}', "'")
+
+ # Remove any null bytes, control characters, and excessive whitespace
+ value = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x84\x86-\x9f]', '', value)
+ value = re.sub(r'\s+', ' ', value.strip())
+
+ return value
+
+ def _validate_column_name(self, column: str) -> bool:
+ """
+ Validate that a column name is in the whitelist.
+
+ Args:
+ column: Column name to validate
+
+ Returns:
+ True if valid, False otherwise
+ """
+ return column in self.ALLOWED_COLUMNS
+
+ def _validate_operator(self, operator: str) -> bool:
+ """
+ Validate that an operator is in the whitelist.
+
+ Args:
+ operator: Operator to validate
+
+ Returns:
+ True if valid, False otherwise
+ """
+ return operator.upper() in self.ALLOWED_OPERATORS
+
+ def _validate_logical_operator(self, logical_op: str) -> bool:
+ """
+ Validate that a logical operator is in the whitelist.
+
+ Args:
+ logical_op: Logical operator to validate
+
+ Returns:
+ True if valid, False otherwise
+ """
+ return logical_op.upper() in self.ALLOWED_LOGICAL_OPERATORS
+
+ def build_safe_condition(self, condition_string: str) -> Tuple[str, Dict[str, Any]]:
+ """
+ Parse and build a safe SQL condition from a user-provided string.
+ This method attempts to parse common condition patterns and convert
+ them to parameterized queries.
+
+ Args:
+ condition_string: User-provided condition string
+
+ Returns:
+ Tuple of (safe_sql_snippet, parameters_dict)
+
+ Raises:
+ ValueError: If the condition contains invalid or unsafe elements
+ """
+ if not condition_string or not condition_string.strip():
+ return "", {}
+
+ # Sanitize the input
+ condition_string = self._sanitize_string(condition_string)
+
+ # Reset parameters for this condition
+ self.parameters = {}
+ self.param_counter = 0
+
+ try:
+ return self._parse_condition(condition_string)
+ except Exception as e:
+ mylog('verbose', f'[SafeConditionBuilder] Error parsing condition: {e}')
+ raise ValueError(f"Invalid condition format: {condition_string}")
+
+ def _parse_condition(self, condition: str) -> Tuple[str, Dict[str, Any]]:
+ """
+ Parse a condition string into safe SQL with parameters.
+
+ This method handles basic patterns like:
+ - AND devName = 'value'
+ - AND devComments LIKE '%value%'
+ - AND eve_EventType IN ('type1', 'type2')
+
+ Args:
+ condition: Condition string to parse
+
+ Returns:
+ Tuple of (safe_sql_snippet, parameters_dict)
+ """
+ condition = condition.strip()
+
+ # Handle empty conditions
+ if not condition:
+ return "", {}
+
+ # Simple pattern matching for common conditions
+ # Pattern 1: AND/OR column operator value (supporting Unicode in quoted strings)
+ pattern1 = r'^\s*(AND|OR)?\s+(\w+)\s+(=|!=|<>|<|>|<=|>=|LIKE|NOT\s+LIKE)\s+\'([^\']*)\'\s*$'
+ match1 = re.match(pattern1, condition, re.IGNORECASE | re.UNICODE)
+
+ if match1:
+ logical_op, column, operator, value = match1.groups()
+ return self._build_simple_condition(logical_op, column, operator, value)
+
+ # Pattern 2: AND/OR column IN ('val1', 'val2', ...)
+ pattern2 = r'^\s*(AND|OR)?\s+(\w+)\s+(IN|NOT\s+IN)\s+\(([^)]+)\)\s*$'
+ match2 = re.match(pattern2, condition, re.IGNORECASE)
+
+ if match2:
+ logical_op, column, operator, values_str = match2.groups()
+ return self._build_in_condition(logical_op, column, operator, values_str)
+
+ # Pattern 3: AND/OR column IS NULL/IS NOT NULL
+ pattern3 = r'^\s*(AND|OR)?\s+(\w+)\s+(IS\s+NULL|IS\s+NOT\s+NULL)\s*$'
+ match3 = re.match(pattern3, condition, re.IGNORECASE)
+
+ if match3:
+ logical_op, column, operator = match3.groups()
+ return self._build_null_condition(logical_op, column, operator)
+
+ # If no patterns match, reject the condition for security
+ raise ValueError(f"Unsupported condition pattern: {condition}")
+
+ def _build_simple_condition(self, logical_op: Optional[str], column: str,
+ operator: str, value: str) -> Tuple[str, Dict[str, Any]]:
+ """Build a simple condition with parameter binding."""
+ # Validate components
+ if not self._validate_column_name(column):
+ raise ValueError(f"Invalid column name: {column}")
+
+ if not self._validate_operator(operator):
+ raise ValueError(f"Invalid operator: {operator}")
+
+ if logical_op and not self._validate_logical_operator(logical_op):
+ raise ValueError(f"Invalid logical operator: {logical_op}")
+
+ # Generate parameter name and store value
+ param_name = self._generate_param_name()
+ self.parameters[param_name] = value
+
+ # Build the SQL snippet
+ sql_parts = []
+ if logical_op:
+ sql_parts.append(logical_op.upper())
+
+ sql_parts.extend([column, operator.upper(), f":{param_name}"])
+
+ return " ".join(sql_parts), self.parameters
+
+ def _build_in_condition(self, logical_op: Optional[str], column: str,
+ operator: str, values_str: str) -> Tuple[str, Dict[str, Any]]:
+ """Build an IN condition with parameter binding."""
+ # Validate components
+ if not self._validate_column_name(column):
+ raise ValueError(f"Invalid column name: {column}")
+
+ if logical_op and not self._validate_logical_operator(logical_op):
+ raise ValueError(f"Invalid logical operator: {logical_op}")
+
+ # Parse values from the IN clause
+ values = []
+ # Simple regex to extract quoted values
+ value_pattern = r"'([^']*)'"
+ matches = re.findall(value_pattern, values_str)
+
+ if not matches:
+ raise ValueError("No valid values found in IN clause")
+
+ # Generate parameters for each value
+ param_names = []
+ for value in matches:
+ param_name = self._generate_param_name()
+ self.parameters[param_name] = value
+ param_names.append(f":{param_name}")
+
+ # Build the SQL snippet
+ sql_parts = []
+ if logical_op:
+ sql_parts.append(logical_op.upper())
+
+ sql_parts.extend([column, operator.upper(), f"({', '.join(param_names)})"])
+
+ return " ".join(sql_parts), self.parameters
+
+ def _build_null_condition(self, logical_op: Optional[str], column: str,
+ operator: str) -> Tuple[str, Dict[str, Any]]:
+ """Build a NULL check condition."""
+ # Validate components
+ if not self._validate_column_name(column):
+ raise ValueError(f"Invalid column name: {column}")
+
+ if logical_op and not self._validate_logical_operator(logical_op):
+ raise ValueError(f"Invalid logical operator: {logical_op}")
+
+ # Build the SQL snippet (no parameters needed for NULL checks)
+ sql_parts = []
+ if logical_op:
+ sql_parts.append(logical_op.upper())
+
+ sql_parts.extend([column, operator.upper()])
+
+ return " ".join(sql_parts), {}
+
+ def build_device_name_filter(self, device_name: str) -> Tuple[str, Dict[str, Any]]:
+ """
+ Build a safe device name filter condition.
+
+ Args:
+ device_name: Device name to filter for
+
+ Returns:
+ Tuple of (safe_sql_snippet, parameters_dict)
+ """
+ if not device_name:
+ return "", {}
+
+ device_name = self._sanitize_string(device_name)
+ param_name = self._generate_param_name('device_name')
+ self.parameters[param_name] = device_name
+
+ return f"AND devName = :{param_name}", self.parameters
+
+ def build_event_type_filter(self, event_types: List[str]) -> Tuple[str, Dict[str, Any]]:
+ """
+ Build a safe event type filter condition.
+
+ Args:
+ event_types: List of event types to filter for
+
+ Returns:
+ Tuple of (safe_sql_snippet, parameters_dict)
+ """
+ if not event_types:
+ return "", {}
+
+ # Validate event types against whitelist
+ valid_types = []
+ for event_type in event_types:
+ event_type = self._sanitize_string(event_type)
+ if event_type in self.ALLOWED_EVENT_TYPES:
+ valid_types.append(event_type)
+ else:
+ mylog('verbose', f'[SafeConditionBuilder] Invalid event type filtered out: {event_type}')
+
+ if not valid_types:
+ return "", {}
+
+ # Generate parameters for each valid event type
+ param_names = []
+ for event_type in valid_types:
+ param_name = self._generate_param_name('event_type')
+ self.parameters[param_name] = event_type
+ param_names.append(f":{param_name}")
+
+ sql_snippet = f"AND eve_EventType IN ({', '.join(param_names)})"
+ return sql_snippet, self.parameters
+
+ def get_safe_condition_legacy(self, condition_setting: str) -> Tuple[str, Dict[str, Any]]:
+ """
+ Convert legacy condition settings to safe parameterized queries.
+ This method provides backward compatibility for existing condition formats.
+
+ Args:
+ condition_setting: The condition string from settings
+
+ Returns:
+ Tuple of (safe_sql_snippet, parameters_dict)
+ """
+ if not condition_setting or not condition_setting.strip():
+ return "", {}
+
+ try:
+ return self.build_safe_condition(condition_setting)
+ except ValueError as e:
+ # Log the error and return empty condition for safety
+ mylog('verbose', f'[SafeConditionBuilder] Unsafe condition rejected: {condition_setting}, Error: {e}')
+ return "", {}
+
+
+def create_safe_condition_builder() -> SafeConditionBuilder:
+ """
+ Factory function to create a new SafeConditionBuilder instance.
+
+ Returns:
+ New SafeConditionBuilder instance
+ """
+ return SafeConditionBuilder()
\ No newline at end of file
diff --git a/server/messaging/reporting.py b/server/messaging/reporting.py
index 81694b29..d22bf6d0 100755
--- a/server/messaging/reporting.py
+++ b/server/messaging/reporting.py
@@ -22,6 +22,7 @@ import conf
from const import applicationPath, logPath, apiPath, confFileName
from helper import timeNowTZ, get_file_content, write_file, get_timezone_offset, get_setting_value
from logger import logResult, mylog
+from db.sql_safe_builder import create_safe_condition_builder
#===============================================================================
# REPORTING
@@ -70,18 +71,30 @@ def get_notifications (db):
if 'new_devices' in sections:
# Compose New Devices Section (no empty lines in SQL queries!)
- # Note: NTFPRCS_new_dev_condition should be validated/sanitized at the settings level
- # to prevent SQL injection. For now, we preserve existing functionality but flag the risk.
- new_dev_condition = get_setting_value('NTFPRCS_new_dev_condition').replace('{s-quote}',"'")
- sqlQuery = f"""SELECT eve_MAC as MAC, eve_DateTime as Datetime, devLastIP as IP, eve_EventType as "Event Type", devName as "Device name", devComments as Comments FROM Events_Devices
- WHERE eve_PendingAlertEmail = 1
- AND eve_EventType = 'New Device' {new_dev_condition}
- ORDER BY eve_DateTime"""
+ # Use SafeConditionBuilder to prevent SQL injection vulnerabilities
+ condition_builder = create_safe_condition_builder()
+ new_dev_condition_setting = get_setting_value('NTFPRCS_new_dev_condition')
+
+ try:
+ safe_condition, parameters = condition_builder.get_safe_condition_legacy(new_dev_condition_setting)
+ sqlQuery = """SELECT eve_MAC as MAC, eve_DateTime as Datetime, devLastIP as IP, eve_EventType as "Event Type", devName as "Device name", devComments as Comments FROM Events_Devices
+ WHERE eve_PendingAlertEmail = 1
+ AND eve_EventType = 'New Device' {}
+ ORDER BY eve_DateTime""".format(safe_condition)
+ except Exception as e:
+ mylog('verbose', ['[Notification] Error building safe condition for new devices: ', e])
+ # Fall back to safe default (no additional conditions)
+ sqlQuery = """SELECT eve_MAC as MAC, eve_DateTime as Datetime, devLastIP as IP, eve_EventType as "Event Type", devName as "Device name", devComments as Comments FROM Events_Devices
+ WHERE eve_PendingAlertEmail = 1
+ AND eve_EventType = 'New Device'
+ ORDER BY eve_DateTime"""
+ parameters = {}
mylog('debug', ['[Notification] new_devices SQL query: ', sqlQuery ])
+ mylog('debug', ['[Notification] new_devices parameters: ', parameters ])
- # Get the events as JSON
- json_obj = db.get_table_as_json(sqlQuery)
+ # Get the events as JSON using parameterized query
+ json_obj = db.get_table_as_json(sqlQuery, parameters)
json_new_devices_meta = {
"title": "🆕 New devices",
@@ -146,18 +159,30 @@ def get_notifications (db):
if 'events' in sections:
# Compose Events Section (no empty lines in SQL queries!)
- # Note: NTFPRCS_event_condition should be validated/sanitized at the settings level
- # to prevent SQL injection. For now, we preserve existing functionality but flag the risk.
- event_condition = get_setting_value('NTFPRCS_event_condition').replace('{s-quote}',"'")
- sqlQuery = f"""SELECT eve_MAC as MAC, eve_DateTime as Datetime, devLastIP as IP, eve_EventType as "Event Type", devName as "Device name", devComments as Comments FROM Events_Devices
- WHERE eve_PendingAlertEmail = 1
- AND eve_EventType IN ('Connected', 'Down Reconnected', 'Disconnected','IP Changed') {event_condition}
- ORDER BY eve_DateTime"""
+ # Use SafeConditionBuilder to prevent SQL injection vulnerabilities
+ condition_builder = create_safe_condition_builder()
+ event_condition_setting = get_setting_value('NTFPRCS_event_condition')
+
+ try:
+ safe_condition, parameters = condition_builder.get_safe_condition_legacy(event_condition_setting)
+ sqlQuery = """SELECT eve_MAC as MAC, eve_DateTime as Datetime, devLastIP as IP, eve_EventType as "Event Type", devName as "Device name", devComments as Comments FROM Events_Devices
+ WHERE eve_PendingAlertEmail = 1
+ AND eve_EventType IN ('Connected', 'Down Reconnected', 'Disconnected','IP Changed') {}
+ ORDER BY eve_DateTime""".format(safe_condition)
+ except Exception as e:
+ mylog('verbose', ['[Notification] Error building safe condition for events: ', e])
+ # Fall back to safe default (no additional conditions)
+ sqlQuery = """SELECT eve_MAC as MAC, eve_DateTime as Datetime, devLastIP as IP, eve_EventType as "Event Type", devName as "Device name", devComments as Comments FROM Events_Devices
+ WHERE eve_PendingAlertEmail = 1
+ AND eve_EventType IN ('Connected', 'Down Reconnected', 'Disconnected','IP Changed')
+ ORDER BY eve_DateTime"""
+ parameters = {}
mylog('debug', ['[Notification] events SQL query: ', sqlQuery ])
+ mylog('debug', ['[Notification] events parameters: ', parameters ])
- # Get the events as JSON
- json_obj = db.get_table_as_json(sqlQuery)
+ # Get the events as JSON using parameterized query
+ json_obj = db.get_table_as_json(sqlQuery, parameters)
json_events_meta = {
"title": "⚡ Events",
diff --git a/test/test_safe_builder_unit.py b/test/test_safe_builder_unit.py
new file mode 100644
index 00000000..356fdee1
--- /dev/null
+++ b/test/test_safe_builder_unit.py
@@ -0,0 +1,331 @@
+"""
+Unit tests for SafeConditionBuilder focusing on core security functionality.
+This test file has minimal dependencies to ensure it can run in any environment.
+"""
+
+import sys
+import unittest
+import re
+from unittest.mock import Mock, patch
+
+# Mock the logger module to avoid dependency issues
+sys.modules['logger'] = Mock()
+
+# Standalone version of SafeConditionBuilder for testing
+class TestSafeConditionBuilder:
+ """
+ Test version of SafeConditionBuilder with mock logger.
+ """
+
+ # Whitelist of allowed column names for filtering
+ ALLOWED_COLUMNS = {
+ 'eve_MAC', 'eve_DateTime', 'eve_IP', 'eve_EventType', 'devName',
+ 'devComments', 'devLastIP', 'devVendor', 'devAlertEvents',
+ 'devAlertDown', 'devIsArchived', 'devPresentLastScan', 'devFavorite',
+ 'devIsNew', 'Plugin', 'Object_PrimaryId', 'Object_SecondaryId',
+ 'DateTimeChanged', 'Watched_Value1', 'Watched_Value2', 'Watched_Value3',
+ 'Watched_Value4', 'Status'
+ }
+
+ # Whitelist of allowed comparison operators
+ ALLOWED_OPERATORS = {
+ '=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'NOT LIKE',
+ 'IN', 'NOT IN', 'IS NULL', 'IS NOT NULL'
+ }
+
+ # Whitelist of allowed logical operators
+ ALLOWED_LOGICAL_OPERATORS = {'AND', 'OR'}
+
+ # Whitelist of allowed event types
+ ALLOWED_EVENT_TYPES = {
+ 'New Device', 'Connected', 'Disconnected', 'Device Down',
+ 'Down Reconnected', 'IP Changed'
+ }
+
+ def __init__(self):
+ """Initialize the SafeConditionBuilder."""
+ self.parameters = {}
+ self.param_counter = 0
+
+ def _generate_param_name(self, prefix='param'):
+ """Generate a unique parameter name for SQL binding."""
+ self.param_counter += 1
+ return f"{prefix}_{self.param_counter}"
+
+ def _sanitize_string(self, value):
+ """Sanitize string input by removing potentially dangerous characters."""
+ if not isinstance(value, str):
+ return str(value)
+
+ # Replace {s-quote} placeholder with single quote (maintaining compatibility)
+ value = value.replace('{s-quote}', "'")
+
+ # Remove any null bytes, control characters, and excessive whitespace
+ value = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x84\x86-\x9f]', '', value)
+ value = re.sub(r'\s+', ' ', value.strip())
+
+ return value
+
+ def _validate_column_name(self, column):
+ """Validate that a column name is in the whitelist."""
+ return column in self.ALLOWED_COLUMNS
+
+ def _validate_operator(self, operator):
+ """Validate that an operator is in the whitelist."""
+ return operator.upper() in self.ALLOWED_OPERATORS
+
+ def _validate_logical_operator(self, logical_op):
+ """Validate that a logical operator is in the whitelist."""
+ return logical_op.upper() in self.ALLOWED_LOGICAL_OPERATORS
+
+ def build_safe_condition(self, condition_string):
+ """Parse and build a safe SQL condition from a user-provided string."""
+ if not condition_string or not condition_string.strip():
+ return "", {}
+
+ # Sanitize the input
+ condition_string = self._sanitize_string(condition_string)
+
+ # Reset parameters for this condition
+ self.parameters = {}
+ self.param_counter = 0
+
+ try:
+ return self._parse_condition(condition_string)
+ except Exception as e:
+ raise ValueError(f"Invalid condition format: {condition_string}")
+
+ def _parse_condition(self, condition):
+ """Parse a condition string into safe SQL with parameters."""
+ condition = condition.strip()
+
+ # Handle empty conditions
+ if not condition:
+ return "", {}
+
+ # Simple pattern matching for common conditions
+ # Pattern 1: AND/OR column operator value
+ pattern1 = r'^\s*(AND|OR)?\s+(\w+)\s+(=|!=|<>|<|>|<=|>=|LIKE|NOT\s+LIKE)\s+\'([^\']*)\'\s*$'
+ match1 = re.match(pattern1, condition, re.IGNORECASE)
+
+ if match1:
+ logical_op, column, operator, value = match1.groups()
+ return self._build_simple_condition(logical_op, column, operator, value)
+
+ # If no patterns match, reject the condition for security
+ raise ValueError(f"Unsupported condition pattern: {condition}")
+
+ def _build_simple_condition(self, logical_op, column, operator, value):
+ """Build a simple condition with parameter binding."""
+ # Validate components
+ if not self._validate_column_name(column):
+ raise ValueError(f"Invalid column name: {column}")
+
+ if not self._validate_operator(operator):
+ raise ValueError(f"Invalid operator: {operator}")
+
+ if logical_op and not self._validate_logical_operator(logical_op):
+ raise ValueError(f"Invalid logical operator: {logical_op}")
+
+ # Generate parameter name and store value
+ param_name = self._generate_param_name()
+ self.parameters[param_name] = value
+
+ # Build the SQL snippet
+ sql_parts = []
+ if logical_op:
+ sql_parts.append(logical_op.upper())
+
+ sql_parts.extend([column, operator.upper(), f":{param_name}"])
+
+ return " ".join(sql_parts), self.parameters
+
+ def get_safe_condition_legacy(self, condition_setting):
+ """Convert legacy condition settings to safe parameterized queries."""
+ if not condition_setting or not condition_setting.strip():
+ return "", {}
+
+ try:
+ return self.build_safe_condition(condition_setting)
+ except ValueError:
+ # Log the error and return empty condition for safety
+ return "", {}
+
+
+class TestSafeConditionBuilderSecurity(unittest.TestCase):
+ """Test cases for the SafeConditionBuilder security functionality."""
+
+ def setUp(self):
+ """Set up test fixtures before each test method."""
+ self.builder = TestSafeConditionBuilder()
+
+ def test_initialization(self):
+ """Test that SafeConditionBuilder initializes correctly."""
+ self.assertIsInstance(self.builder, TestSafeConditionBuilder)
+ self.assertEqual(self.builder.param_counter, 0)
+ self.assertEqual(self.builder.parameters, {})
+
+ def test_sanitize_string(self):
+ """Test string sanitization functionality."""
+ # Test normal string
+ result = self.builder._sanitize_string("normal string")
+ self.assertEqual(result, "normal string")
+
+ # Test s-quote replacement
+ result = self.builder._sanitize_string("test{s-quote}value")
+ self.assertEqual(result, "test'value")
+
+ # Test control character removal
+ result = self.builder._sanitize_string("test\x00\x01string")
+ self.assertEqual(result, "teststring")
+
+ # Test excessive whitespace
+ result = self.builder._sanitize_string(" test string ")
+ self.assertEqual(result, "test string")
+
+ def test_validate_column_name(self):
+ """Test column name validation against whitelist."""
+ # Valid columns
+ self.assertTrue(self.builder._validate_column_name('eve_MAC'))
+ self.assertTrue(self.builder._validate_column_name('devName'))
+ self.assertTrue(self.builder._validate_column_name('eve_EventType'))
+
+ # Invalid columns
+ self.assertFalse(self.builder._validate_column_name('malicious_column'))
+ self.assertFalse(self.builder._validate_column_name('drop_table'))
+ self.assertFalse(self.builder._validate_column_name('user_input'))
+
+ def test_validate_operator(self):
+ """Test operator validation against whitelist."""
+ # Valid operators
+ self.assertTrue(self.builder._validate_operator('='))
+ self.assertTrue(self.builder._validate_operator('LIKE'))
+ self.assertTrue(self.builder._validate_operator('IN'))
+
+ # Invalid operators
+ self.assertFalse(self.builder._validate_operator('UNION'))
+ self.assertFalse(self.builder._validate_operator('DROP'))
+ self.assertFalse(self.builder._validate_operator('EXEC'))
+
+ def test_build_simple_condition_valid(self):
+ """Test building valid simple conditions."""
+ sql, params = self.builder._build_simple_condition('AND', 'devName', '=', 'TestDevice')
+
+ self.assertIn('AND devName = :param_', sql)
+ self.assertEqual(len(params), 1)
+ self.assertIn('TestDevice', params.values())
+
+ def test_build_simple_condition_invalid_column(self):
+ """Test that invalid column names are rejected."""
+ with self.assertRaises(ValueError) as context:
+ self.builder._build_simple_condition('AND', 'invalid_column', '=', 'value')
+
+ self.assertIn('Invalid column name', str(context.exception))
+
+ def test_build_simple_condition_invalid_operator(self):
+ """Test that invalid operators are rejected."""
+ with self.assertRaises(ValueError) as context:
+ self.builder._build_simple_condition('AND', 'devName', 'UNION', 'value')
+
+ self.assertIn('Invalid operator', str(context.exception))
+
+ def test_sql_injection_attempts(self):
+ """Test that various SQL injection attempts are blocked."""
+ injection_attempts = [
+ "'; DROP TABLE Devices; --",
+ "' UNION SELECT * FROM Settings --",
+ "' OR 1=1 --",
+ "'; INSERT INTO Events VALUES(1,2,3); --",
+ "' AND (SELECT COUNT(*) FROM sqlite_master) > 0 --",
+ ]
+
+ for injection in injection_attempts:
+ with self.subTest(injection=injection):
+ with self.assertRaises(ValueError):
+ self.builder.build_safe_condition(f"AND devName = '{injection}'")
+
+ def test_legacy_condition_compatibility(self):
+ """Test backward compatibility with legacy condition formats."""
+ # Test simple condition
+ sql, params = self.builder.get_safe_condition_legacy("AND devName = 'TestDevice'")
+ self.assertIn('devName', sql)
+ self.assertIn('TestDevice', params.values())
+
+ # Test empty condition
+ sql, params = self.builder.get_safe_condition_legacy("")
+ self.assertEqual(sql, "")
+ self.assertEqual(params, {})
+
+ # Test invalid condition returns empty
+ sql, params = self.builder.get_safe_condition_legacy("INVALID SQL INJECTION")
+ self.assertEqual(sql, "")
+ self.assertEqual(params, {})
+
+ def test_parameter_generation(self):
+ """Test that parameters are generated correctly."""
+ # Test multiple parameters
+ sql1, params1 = self.builder.build_safe_condition("AND devName = 'Device1'")
+ sql2, params2 = self.builder.build_safe_condition("AND devName = 'Device2'")
+
+ # Each should have unique parameter names
+ self.assertNotEqual(list(params1.keys())[0], list(params2.keys())[0])
+
+ def test_xss_prevention(self):
+ """Test that XSS-like payloads in device names are handled safely."""
+ xss_payloads = [
+ "",
+ "javascript:alert(1)",
+ "
",
+ "'; DROP TABLE users; SELECT '' --"
+ ]
+
+ for payload in xss_payloads:
+ with self.subTest(payload=payload):
+ # Should either process safely or reject
+ try:
+ sql, params = self.builder.build_safe_condition(f"AND devName = '{payload}'")
+ # If processed, should be parameterized
+ self.assertIn(':', sql)
+ self.assertIn(payload, params.values())
+ except ValueError:
+ # Rejection is also acceptable for safety
+ pass
+
+ def test_unicode_handling(self):
+ """Test that Unicode characters are handled properly."""
+ unicode_strings = [
+ "Ülrich's Device",
+ "Café Network",
+ "测试设备",
+ "Устройство"
+ ]
+
+ for unicode_str in unicode_strings:
+ with self.subTest(unicode_str=unicode_str):
+ sql, params = self.builder.build_safe_condition(f"AND devName = '{unicode_str}'")
+ self.assertIn(unicode_str, params.values())
+
+ def test_edge_cases(self):
+ """Test edge cases and boundary conditions."""
+ edge_cases = [
+ "", # Empty string
+ " ", # Whitespace only
+ "AND devName = ''", # Empty value
+ "AND devName = 'a'", # Single character
+ "AND devName = '" + "x" * 1000 + "'", # Very long string
+ ]
+
+ for case in edge_cases:
+ with self.subTest(case=case):
+ try:
+ sql, params = self.builder.get_safe_condition_legacy(case)
+ # Should either return valid result or empty safe result
+ self.assertIsInstance(sql, str)
+ self.assertIsInstance(params, dict)
+ except Exception:
+ self.fail(f"Unexpected exception for edge case: {case}")
+
+
+if __name__ == '__main__':
+ # Run the test suite
+ unittest.main(verbosity=2)
\ No newline at end of file
diff --git a/test/test_sql_security.py b/test/test_sql_security.py
new file mode 100644
index 00000000..da505319
--- /dev/null
+++ b/test/test_sql_security.py
@@ -0,0 +1,381 @@
+"""
+NetAlertX SQL Security Test Suite
+
+This test suite validates the SQL injection prevention mechanisms
+implemented in the SafeConditionBuilder and reporting modules.
+
+Author: Security Enhancement for NetAlertX
+License: GNU GPLv3
+"""
+
+import sys
+import unittest
+import sqlite3
+import tempfile
+import os
+from unittest.mock import Mock, patch, MagicMock
+
+# Add the server directory to the path for imports
+INSTALL_PATH = "/app"
+sys.path.extend([f"{INSTALL_PATH}/server"])
+sys.path.append('/home/dell/coding/bash/10x-agentic-setup/netalertx-sql-fix/server')
+
+from db.sql_safe_builder import SafeConditionBuilder, create_safe_condition_builder
+from database import DB
+from messaging.reporting import get_notifications
+
+
+class TestSafeConditionBuilder(unittest.TestCase):
+ """Test cases for the SafeConditionBuilder class."""
+
+ def setUp(self):
+ """Set up test fixtures before each test method."""
+ self.builder = SafeConditionBuilder()
+
+ def test_initialization(self):
+ """Test that SafeConditionBuilder initializes correctly."""
+ self.assertIsInstance(self.builder, SafeConditionBuilder)
+ self.assertEqual(self.builder.param_counter, 0)
+ self.assertEqual(self.builder.parameters, {})
+
+ def test_sanitize_string(self):
+ """Test string sanitization functionality."""
+ # Test normal string
+ result = self.builder._sanitize_string("normal string")
+ self.assertEqual(result, "normal string")
+
+ # Test s-quote replacement
+ result = self.builder._sanitize_string("test{s-quote}value")
+ self.assertEqual(result, "test'value")
+
+ # Test control character removal
+ result = self.builder._sanitize_string("test\x00\x01string")
+ self.assertEqual(result, "teststring")
+
+ # Test excessive whitespace
+ result = self.builder._sanitize_string(" test string ")
+ self.assertEqual(result, "test string")
+
+ def test_validate_column_name(self):
+ """Test column name validation against whitelist."""
+ # Valid columns
+ self.assertTrue(self.builder._validate_column_name('eve_MAC'))
+ self.assertTrue(self.builder._validate_column_name('devName'))
+ self.assertTrue(self.builder._validate_column_name('eve_EventType'))
+
+ # Invalid columns
+ self.assertFalse(self.builder._validate_column_name('malicious_column'))
+ self.assertFalse(self.builder._validate_column_name('drop_table'))
+ self.assertFalse(self.builder._validate_column_name('\'; DROP TABLE users; --'))
+
+ def test_validate_operator(self):
+ """Test operator validation against whitelist."""
+ # Valid operators
+ self.assertTrue(self.builder._validate_operator('='))
+ self.assertTrue(self.builder._validate_operator('LIKE'))
+ self.assertTrue(self.builder._validate_operator('IN'))
+
+ # Invalid operators
+ self.assertFalse(self.builder._validate_operator('UNION'))
+ self.assertFalse(self.builder._validate_operator('; DROP'))
+ self.assertFalse(self.builder._validate_operator('EXEC'))
+
+ def test_build_simple_condition_valid(self):
+ """Test building valid simple conditions."""
+ sql, params = self.builder._build_simple_condition('AND', 'devName', '=', 'TestDevice')
+
+ self.assertIn('AND devName = :param_', sql)
+ self.assertEqual(len(params), 1)
+ self.assertIn('TestDevice', params.values())
+
+ def test_build_simple_condition_invalid_column(self):
+ """Test that invalid column names are rejected."""
+ with self.assertRaises(ValueError) as context:
+ self.builder._build_simple_condition('AND', 'invalid_column', '=', 'value')
+
+ self.assertIn('Invalid column name', str(context.exception))
+
+ def test_build_simple_condition_invalid_operator(self):
+ """Test that invalid operators are rejected."""
+ with self.assertRaises(ValueError) as context:
+ self.builder._build_simple_condition('AND', 'devName', 'UNION', 'value')
+
+ self.assertIn('Invalid operator', str(context.exception))
+
+ def test_build_in_condition_valid(self):
+ """Test building valid IN conditions."""
+ sql, params = self.builder._build_in_condition('AND', 'eve_EventType', 'IN', "'Connected', 'Disconnected'")
+
+ self.assertIn('AND eve_EventType IN', sql)
+ self.assertEqual(len(params), 2)
+ self.assertIn('Connected', params.values())
+ self.assertIn('Disconnected', params.values())
+
+ def test_build_null_condition(self):
+ """Test building NULL check conditions."""
+ sql, params = self.builder._build_null_condition('AND', 'devComments', 'IS NULL')
+
+ self.assertEqual(sql, 'AND devComments IS NULL')
+ self.assertEqual(len(params), 0)
+
+ def test_sql_injection_attempts(self):
+ """Test that various SQL injection attempts are blocked."""
+ injection_attempts = [
+ "'; DROP TABLE Devices; --",
+ "' UNION SELECT * FROM Settings --",
+ "' OR 1=1 --",
+ "'; INSERT INTO Events VALUES(1,2,3); --",
+ "' AND (SELECT COUNT(*) FROM sqlite_master) > 0 --",
+ "'; ATTACH DATABASE '/etc/passwd' AS pwn; --"
+ ]
+
+ for injection in injection_attempts:
+ with self.subTest(injection=injection):
+ with self.assertRaises(ValueError):
+ self.builder.build_safe_condition(f"AND devName = '{injection}'")
+
+ def test_legacy_condition_compatibility(self):
+ """Test backward compatibility with legacy condition formats."""
+ # Test simple condition
+ sql, params = self.builder.get_safe_condition_legacy("AND devName = 'TestDevice'")
+ self.assertIn('devName', sql)
+ self.assertIn('TestDevice', params.values())
+
+ # Test empty condition
+ sql, params = self.builder.get_safe_condition_legacy("")
+ self.assertEqual(sql, "")
+ self.assertEqual(params, {})
+
+ # Test invalid condition returns empty
+ sql, params = self.builder.get_safe_condition_legacy("INVALID SQL INJECTION")
+ self.assertEqual(sql, "")
+ self.assertEqual(params, {})
+
+ def test_device_name_filter(self):
+ """Test the device name filter helper method."""
+ sql, params = self.builder.build_device_name_filter("TestDevice")
+
+ self.assertIn('AND devName = :device_name_', sql)
+ self.assertIn('TestDevice', params.values())
+
+ def test_event_type_filter(self):
+ """Test the event type filter helper method."""
+ event_types = ['Connected', 'Disconnected']
+ sql, params = self.builder.build_event_type_filter(event_types)
+
+ self.assertIn('AND eve_EventType IN', sql)
+ self.assertEqual(len(params), 2)
+ self.assertIn('Connected', params.values())
+ self.assertIn('Disconnected', params.values())
+
+ def test_event_type_filter_whitelist(self):
+ """Test that event type filter enforces whitelist."""
+ # Valid event types
+ valid_types = ['Connected', 'New Device']
+ sql, params = self.builder.build_event_type_filter(valid_types)
+ self.assertEqual(len(params), 2)
+
+ # Mix of valid and invalid event types
+ mixed_types = ['Connected', 'InvalidEventType', 'Device Down']
+ sql, params = self.builder.build_event_type_filter(mixed_types)
+ self.assertEqual(len(params), 2) # Only valid types should be included
+
+ # All invalid event types
+ invalid_types = ['InvalidType1', 'InvalidType2']
+ sql, params = self.builder.build_event_type_filter(invalid_types)
+ self.assertEqual(sql, "")
+ self.assertEqual(params, {})
+
+
+class TestDatabaseParameterSupport(unittest.TestCase):
+ """Test that database layer supports parameterized queries."""
+
+ def setUp(self):
+ """Set up test database."""
+ self.temp_db = tempfile.NamedTemporaryFile(delete=False, suffix='.db')
+ self.temp_db.close()
+
+ # Create test database
+ self.conn = sqlite3.connect(self.temp_db.name)
+ self.conn.execute('''CREATE TABLE test_table (
+ id INTEGER PRIMARY KEY,
+ name TEXT,
+ value TEXT
+ )''')
+ self.conn.execute("INSERT INTO test_table (name, value) VALUES ('test1', 'value1')")
+ self.conn.execute("INSERT INTO test_table (name, value) VALUES ('test2', 'value2')")
+ self.conn.commit()
+
+ def tearDown(self):
+ """Clean up test database."""
+ self.conn.close()
+ os.unlink(self.temp_db.name)
+
+ def test_parameterized_query_execution(self):
+ """Test that parameterized queries work correctly."""
+ cursor = self.conn.cursor()
+
+ # Test named parameters
+ cursor.execute("SELECT * FROM test_table WHERE name = :name", {'name': 'test1'})
+ results = cursor.fetchall()
+
+ self.assertEqual(len(results), 1)
+ self.assertEqual(results[0][1], 'test1')
+
+ def test_parameterized_query_prevents_injection(self):
+ """Test that parameterized queries prevent SQL injection."""
+ cursor = self.conn.cursor()
+
+ # This should not cause SQL injection
+ malicious_input = "'; DROP TABLE test_table; --"
+ cursor.execute("SELECT * FROM test_table WHERE name = :name", {'name': malicious_input})
+ results = cursor.fetchall()
+
+ # The table should still exist and be queryable
+ cursor.execute("SELECT COUNT(*) FROM test_table")
+ count = cursor.fetchone()[0]
+ self.assertEqual(count, 2) # Original data should still be there
+
+
+class TestReportingSecurityIntegration(unittest.TestCase):
+ """Integration tests for the secure reporting functionality."""
+
+ def setUp(self):
+ """Set up test environment for reporting tests."""
+ self.mock_db = Mock()
+ self.mock_db.sql = Mock()
+ self.mock_db.get_table_as_json = Mock()
+
+ # Mock successful JSON response
+ mock_json_obj = Mock()
+ mock_json_obj.columnNames = ['MAC', 'Datetime', 'IP', 'Event Type', 'Device name', 'Comments']
+ mock_json_obj.json = {'data': []}
+ self.mock_db.get_table_as_json.return_value = mock_json_obj
+
+ @patch('messaging.reporting.get_setting_value')
+ def test_new_devices_section_security(self, mock_get_setting):
+ """Test that new devices section uses safe SQL building."""
+ # Mock settings
+ mock_get_setting.side_effect = lambda key: {
+ 'NTFPRCS_INCLUDED_SECTIONS': ['new_devices'],
+ 'NTFPRCS_new_dev_condition': "AND devName = 'TestDevice'"
+ }.get(key, '')
+
+ # Call the function
+ result = get_notifications(self.mock_db)
+
+ # Verify that get_table_as_json was called with parameters
+ self.mock_db.get_table_as_json.assert_called()
+ call_args = self.mock_db.get_table_as_json.call_args
+
+ # Should have been called with both query and parameters
+ self.assertEqual(len(call_args[0]), 1) # Query argument
+ self.assertEqual(len(call_args[1]), 1) # Parameters keyword argument
+
+ @patch('messaging.reporting.get_setting_value')
+ def test_events_section_security(self, mock_get_setting):
+ """Test that events section uses safe SQL building."""
+ # Mock settings
+ mock_get_setting.side_effect = lambda key: {
+ 'NTFPRCS_INCLUDED_SECTIONS': ['events'],
+ 'NTFPRCS_event_condition': "AND devName = 'TestDevice'"
+ }.get(key, '')
+
+ # Call the function
+ result = get_notifications(self.mock_db)
+
+ # Verify that get_table_as_json was called with parameters
+ self.mock_db.get_table_as_json.assert_called()
+
+ @patch('messaging.reporting.get_setting_value')
+ def test_malicious_condition_handling(self, mock_get_setting):
+ """Test that malicious conditions are safely handled."""
+ # Mock settings with malicious input
+ mock_get_setting.side_effect = lambda key: {
+ 'NTFPRCS_INCLUDED_SECTIONS': ['new_devices'],
+ 'NTFPRCS_new_dev_condition': "'; DROP TABLE Devices; --"
+ }.get(key, '')
+
+ # Call the function - should not raise an exception
+ result = get_notifications(self.mock_db)
+
+ # Should still call get_table_as_json (with safe fallback query)
+ self.mock_db.get_table_as_json.assert_called()
+
+ @patch('messaging.reporting.get_setting_value')
+ def test_empty_condition_handling(self, mock_get_setting):
+ """Test that empty conditions are handled gracefully."""
+ # Mock settings with empty condition
+ mock_get_setting.side_effect = lambda key: {
+ 'NTFPRCS_INCLUDED_SECTIONS': ['new_devices'],
+ 'NTFPRCS_new_dev_condition': ""
+ }.get(key, '')
+
+ # Call the function
+ result = get_notifications(self.mock_db)
+
+ # Should call get_table_as_json
+ self.mock_db.get_table_as_json.assert_called()
+
+
+class TestSecurityBenchmarks(unittest.TestCase):
+ """Performance and security benchmark tests."""
+
+ def setUp(self):
+ """Set up benchmark environment."""
+ self.builder = SafeConditionBuilder()
+
+ def test_performance_simple_condition(self):
+ """Test performance of simple condition building."""
+ import time
+
+ start_time = time.time()
+ for _ in range(1000):
+ sql, params = self.builder.build_safe_condition("AND devName = 'TestDevice'")
+ end_time = time.time()
+
+ execution_time = end_time - start_time
+ self.assertLess(execution_time, 1.0, "Simple condition building should be fast")
+
+ def test_memory_usage_parameter_generation(self):
+ """Test memory usage of parameter generation."""
+ import psutil
+ import os
+
+ process = psutil.Process(os.getpid())
+ initial_memory = process.memory_info().rss
+
+ # Generate many conditions
+ for i in range(100):
+ builder = SafeConditionBuilder()
+ sql, params = builder.build_safe_condition(f"AND devName = 'Device{i}'")
+
+ final_memory = process.memory_info().rss
+ memory_increase = final_memory - initial_memory
+
+ # Memory increase should be reasonable (less than 10MB)
+ self.assertLess(memory_increase, 10 * 1024 * 1024, "Memory usage should be reasonable")
+
+ def test_pattern_coverage(self):
+ """Test coverage of condition patterns."""
+ patterns_tested = [
+ "AND devName = 'value'",
+ "OR eve_EventType LIKE '%test%'",
+ "AND devComments IS NULL",
+ "AND eve_EventType IN ('Connected', 'Disconnected')",
+ ]
+
+ for pattern in patterns_tested:
+ with self.subTest(pattern=pattern):
+ try:
+ sql, params = self.builder.build_safe_condition(pattern)
+ self.assertIsInstance(sql, str)
+ self.assertIsInstance(params, dict)
+ except ValueError:
+ # Some patterns might be rejected, which is acceptable
+ pass
+
+
+if __name__ == '__main__':
+ # Run the test suite
+ unittest.main(verbosity=2)
\ No newline at end of file