Files
NetAlertX/test/test_device_atomicity.py

232 lines
8.5 KiB
Python

"""
Test for atomicity of device updates with source-tracking.
Verifies that:
1. If source-tracking fails, the device row is rolled back.
2. If source-tracking succeeds, device row and sources are both committed.
3. Database remains consistent in both scenarios.
"""
import sys
import os
import sqlite3
import tempfile
import unittest
from unittest.mock import patch
# Add server and plugins to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'server'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'front', 'plugins'))
from models.device_instance import DeviceInstance # noqa: E402 [flake8 lint suppression]
from plugin_helper import normalize_mac # noqa: E402 [flake8 lint suppression]
class TestDeviceAtomicity(unittest.TestCase):
    """Test atomic transactions for device updates with source-tracking.

    Every test runs against a throwaway, file-backed SQLite database that
    holds a minimal ``Devices`` table.  ``get_temp_db_connection`` inside
    ``models.device_instance`` is patched so the code under test opens
    connections to that database.
    """

    def setUp(self):
        """Create a temporary file-backed SQLite DB with a minimal schema.

        The database lives in a real temporary file, so the separate
        connections opened by ``_get_test_db_connection`` all see the same
        data.
        """
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.db')
        self.test_db = tmp_file
        self.test_db_path = tmp_file.name
        tmp_file.close()
        # Register removal right away: unittest does not call tearDown() when
        # setUp() raises, but registered cleanups still run, so the temp file
        # cannot leak if schema creation below fails.
        self.addCleanup(self._remove_test_db)

        conn = sqlite3.connect(self.test_db_path)
        try:
            # Devices table including the source-tracking (*Source) columns.
            conn.execute("""
                CREATE TABLE Devices (
                    devMac TEXT PRIMARY KEY,
                    devName TEXT,
                    devOwner TEXT,
                    devType TEXT,
                    devVendor TEXT,
                    devIcon TEXT,
                    devFavorite INTEGER DEFAULT 0,
                    devGroup TEXT,
                    devLocation TEXT,
                    devComments TEXT,
                    devParentMAC TEXT,
                    devParentPort TEXT,
                    devSSID TEXT,
                    devSite TEXT,
                    devStaticIP INTEGER DEFAULT 0,
                    devScan INTEGER DEFAULT 0,
                    devAlertEvents INTEGER DEFAULT 0,
                    devAlertDown INTEGER DEFAULT 0,
                    devParentRelType TEXT DEFAULT 'default',
                    devReqNicsOnline INTEGER DEFAULT 0,
                    devSkipRepeated INTEGER DEFAULT 0,
                    devIsNew INTEGER DEFAULT 0,
                    devIsArchived INTEGER DEFAULT 0,
                    devLastConnection TEXT,
                    devFirstConnection TEXT,
                    devLastIP TEXT,
                    devGUID TEXT,
                    devCustomProps TEXT,
                    devSourcePlugin TEXT,
                    devNameSource TEXT,
                    devTypeSource TEXT,
                    devVendorSource TEXT,
                    devIconSource TEXT,
                    devGroupSource TEXT,
                    devLocationSource TEXT,
                    devCommentsSource TEXT,
                    devMacSource TEXT
                )
            """)
            conn.commit()
        finally:
            conn.close()

    def _remove_test_db(self):
        """Delete the temporary database file; a missing file is not an error."""
        try:
            os.unlink(self.test_db_path)
        except FileNotFoundError:
            pass

    def _get_test_db_connection(self):
        """Return a new connection to the test DB (patched in for the real one)."""
        conn = sqlite3.connect(self.test_db_path)
        conn.row_factory = sqlite3.Row
        return conn

    def _patch_db(self):
        """Return a patcher that redirects the model's DB connection factory."""
        return patch(
            'models.device_instance.get_temp_db_connection',
            self._get_test_db_connection,
        )

    def _insert_device(self, mac, name, owner):
        """Seed the DB with a "Desktop" device whose name/type source is PLUGIN."""
        conn = self._get_test_db_connection()
        try:
            conn.execute(
                """
                INSERT INTO Devices (
                    devMac, devName, devOwner, devType,
                    devNameSource, devTypeSource
                ) VALUES (?, ?, ?, ?, ?, ?)
                """,
                (mac, name, owner, "Desktop", "PLUGIN", "PLUGIN"),
            )
            conn.commit()
        finally:
            conn.close()

    def _fetch_device(self, mac):
        """Return the ``Devices`` row for *mac*, or ``None`` if it is absent."""
        conn = self._get_test_db_connection()
        try:
            cur = conn.execute("SELECT * FROM Devices WHERE devMac = ?", (mac,))
            return cur.fetchone()
        finally:
            conn.close()

    def test_create_new_device_atomicity(self):
        """
        Test the success path of device creation: the device row and its
        source-tracking columns are committed together.
        """
        device_instance = DeviceInstance()
        test_mac = normalize_mac("aa:bb:cc:dd:ee:ff")

        # Patch at module level where it's used
        with self._patch_db():
            data = {
                "createNew": True,
                "devMac": test_mac,
                "devName": "Test Device",
                "devOwner": "John Doe",
                "devType": "Laptop",
            }
            result = device_instance.setDeviceData(test_mac, data)

        # Verify success
        self.assertTrue(result["success"], f"Device creation failed: {result}")

        # Verify device exists
        device = self._fetch_device(test_mac)
        self.assertIsNotNone(device, "Device was not created")
        self.assertEqual(device["devName"], "Test Device")

        # Verify source tracking was set
        self.assertEqual(device["devMacSource"], "NEWDEV")
        self.assertEqual(device["devNameSource"], "NEWDEV")

    def test_update_device_with_source_tracking_atomicity(self):
        """
        Test the success path of a device update: when source-tracking
        enforcement succeeds (mocked), the new field values are committed.
        """
        device_instance = DeviceInstance()
        test_mac = normalize_mac("aa:bb:cc:dd:ee:ff")
        self._insert_device(test_mac, "Old Name", "Old Owner")

        with self._patch_db():
            with patch('models.device_instance.enforce_source_on_user_update') as mock_enforce:
                mock_enforce.return_value = None
                data = {
                    "createNew": False,
                    "devMac": test_mac,
                    "devName": "New Name",
                    "devOwner": "New Owner",
                }
                result = device_instance.setDeviceData(test_mac, data)

        # Verify success
        self.assertTrue(result["success"], f"Device update failed: {result}")

        # Verify device was updated
        device = self._fetch_device(test_mac)
        self.assertIsNotNone(device, "Device row disappeared during update")
        self.assertEqual(device["devName"], "New Name")
        self.assertEqual(device["devOwner"], "New Owner")

    def test_source_tracking_failure_rolls_back_device(self):
        """
        Test that if enforce_source_on_user_update fails, the entire
        transaction is rolled back (device and sources).
        """
        device_instance = DeviceInstance()
        test_mac = normalize_mac("aa:bb:cc:dd:ee:ff")
        self._insert_device(test_mac, "Original Name", "Original Owner")

        with self._patch_db():
            with patch('models.device_instance.enforce_source_on_user_update') as mock_enforce:
                # Simulate source tracking failure
                mock_enforce.side_effect = Exception("Source tracking error")
                data = {
                    "createNew": False,
                    "devMac": test_mac,
                    "devName": "Failed Update",
                    "devOwner": "Failed Owner",
                }
                result = device_instance.setDeviceData(test_mac, data)

        # Verify error response
        self.assertFalse(result["success"])
        self.assertIn("Source tracking failed", result["error"])

        # Verify device was NOT updated (rollback successful)
        device = self._fetch_device(test_mac)
        self.assertIsNotNone(device, "Device row disappeared after failed update")
        self.assertEqual(
            device["devName"], "Original Name",
            "Device should not have been updated on source tracking failure",
        )
        self.assertEqual(
            device["devOwner"], "Original Owner",
            "Device should not have been updated on source tracking failure",
        )
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()