refactor: use DeviceInstance model instead of direct SQLite query

Replaces the raw sqlite3 query in get_netalertx_devices() with
DeviceInstance().getAll() as suggested in code review, applying the
archived/offline/new filters in Python. Removes the sqlite3 and
fullDbPath imports. Updates tests to mock DeviceInstance.getAll().

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Nathan Jacobson
2026-05-26 22:03:14 -04:00
parent ca7a699ce3
commit 93e534cef5
2 changed files with 64 additions and 110 deletions

View File

@@ -16,17 +16,17 @@ import sys
import json
import requests
from pytz import timezone
import sqlite3
from typing import Dict, List, Optional, Set, Tuple
# Define the installation path and extend the system path for plugin imports
INSTALL_PATH = os.getenv('NETALERTX_APP', '/app')
sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"])
from const import dataPath, logPath, fullDbPath # noqa: E402, E261
from const import dataPath, logPath # noqa: E402, E261
from plugin_helper import Plugin_Objects # noqa: E402, E261
from logger import mylog, Logger # noqa: E402, E261
from helper import get_setting_value # noqa: E402, E261
from models.device_instance import DeviceInstance # noqa: E402, E261
import conf # noqa: E402, E261
# ----------------------------
@@ -164,57 +164,35 @@ class AdGuardClient:
# ---------------------------------------------------------------------------
# Database helpers
# ---------------------------------------------------------------------------
def get_netalertx_devices(db_path: str, include_offline: bool, include_new: bool) -> List[dict]:
def get_netalertx_devices(include_offline: bool, include_new: bool) -> List[dict]:
"""
Query NetAlertX's Devices table and return a list of dicts with the
fields we care about: mac, name, last_ip, dev_type
Return filtered devices from NetAlertX using the DeviceInstance model.
Fields returned per device: mac, name, last_ip, dev_type
"""
devices = []
conn = None
try:
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
cur = conn.cursor()
clauses = ["devIsArchived = 0"]
if not include_offline:
clauses.append("devPresentLastScan = 1")
if not include_new:
clauses.append("devIsNew = 0")
where = "WHERE " + " AND ".join(clauses)
cur.execute(
f"""
SELECT devMac AS mac,
devName AS name,
devLastIP AS last_ip,
devType AS dev_type
FROM Devices
{where}
ORDER BY devMac
"""
)
for row in cur.fetchall():
mac = (row["mac"] or "").strip()
name = (row["name"] or "").strip()
last_ip = (row["last_ip"] or "").strip()
dev_type = (row["dev_type"] or "").strip()
# Skip completely empty rows
if not mac and not last_ip:
for d in DeviceInstance().getAll():
if d.get("devIsArchived", 0):
continue
if not include_offline and not d.get("devPresentLastScan", 1):
continue
if not include_new and d.get("devIsNew", 0):
continue
# Fall back to MAC as name when no friendly name is set
mac = (d.get("devMac", "") or "").strip()
last_ip = (d.get("devLastIP", "") or "").strip()
name = (d.get("devName", "") or "").strip()
dev_type = (d.get("devType", "") or "").strip()
if not mac and not last_ip:
continue
if not name:
name = mac or last_ip
devices.append({"mac": mac, "name": name, "last_ip": last_ip, "dev_type": dev_type})
except sqlite3.Error as exc:
mylog("verbose", [f"[{pluginName}] ERROR reading NetAlertX database: {exc}"])
finally:
if conn:
conn.close()
except Exception as exc:
mylog("verbose", [f"[{pluginName}] ERROR reading devices: {exc}"])
return devices
@@ -408,7 +386,7 @@ def main():
# ------------------------------------------------------------------
# Load devices from NetAlertX
# ------------------------------------------------------------------
devices = get_netalertx_devices(fullDbPath, include_offline, include_new)
devices = get_netalertx_devices(include_offline, include_new)
mylog("verbose", [f"[{pluginName}] Loaded {len(devices)} device(s) from NetAlertX database."])
if not devices:

View File

@@ -10,7 +10,6 @@ automatically before the script is imported.
import json
import os
import sqlite3
import sys
import tempfile
import types
@@ -40,6 +39,8 @@ _stub("const", dataPath=_tmp_log, logPath=_tmp_log, fullDbPath=os.path.join(_tmp
_stub("plugin_helper", Plugin_Objects=MagicMock)
_stub("logger", mylog=lambda *a: None, Logger=MagicMock)
_stub("helper", get_setting_value=lambda k: "")
_stub("models", )
_stub("models.device_instance", DeviceInstance=MagicMock)
# Stub requests only when it isn't installed (e.g. bare system Python locally).
# In the container and CI, the real package is present and will be used.
@@ -74,37 +75,18 @@ from script import ( # noqa: E402
# Helpers
# ---------------------------------------------------------------------------
def _make_db(path: str, rows: list[dict]) -> None:
"""Create a minimal Devices table and populate it with *rows*."""
conn = sqlite3.connect(path)
conn.execute(
"""
CREATE TABLE IF NOT EXISTS Devices (
devMac TEXT,
devName TEXT,
devLastIP TEXT,
devType TEXT,
devIsArchived INTEGER DEFAULT 0,
devPresentLastScan INTEGER DEFAULT 1,
devIsNew INTEGER DEFAULT 0
)
"""
)
for row in rows:
conn.execute(
"INSERT INTO Devices VALUES (?,?,?,?,?,?,?)",
(
row.get("devMac", ""),
row.get("devName", ""),
row.get("devLastIP", ""),
row.get("devType", ""),
row.get("devIsArchived", 0),
row.get("devPresentLastScan", 1),
row.get("devIsNew", 0),
),
)
conn.commit()
conn.close()
def _raw_device(**overrides) -> dict:
"""Build a raw DeviceInstance.getAll() style dict."""
base = {
"devMac": "AA:BB:CC:00:00:01",
"devName": "PC",
"devLastIP": "10.0.0.1",
"devType": "desktop",
"devIsArchived": 0,
"devPresentLastScan": 1,
"devIsNew": 0,
}
return {**base, **overrides}
def _mock_agrd(existing=None) -> MagicMock:
@@ -229,59 +211,53 @@ class TestManagedNames:
class TestGetNetalertxDevices:
def test_basic_query(self, tmp_path):
db = str(tmp_path / "na.db")
_make_db(db, [{"devMac": "AA:BB:CC:00:00:01", "devName": "PC", "devLastIP": "10.0.0.1", "devType": "desktop"}])
result = get_netalertx_devices(db, include_offline=True, include_new=True)
def _call(self, rows, include_offline=True, include_new=True):
with patch("script.DeviceInstance") as mock_di:
mock_di.return_value.getAll.return_value = rows
return get_netalertx_devices(include_offline=include_offline, include_new=include_new)
def test_basic_query(self):
result = self._call([_raw_device()])
assert len(result) == 1
assert result[0]["name"] == "PC"
assert result[0]["mac"] == "AA:BB:CC:00:00:01"
def test_archived_devices_excluded(self, tmp_path):
db = str(tmp_path / "na.db")
_make_db(db, [
{"devMac": "AA:00:00:00:00:01", "devName": "Active", "devLastIP": "10.0.0.1", "devIsArchived": 0},
{"devMac": "AA:00:00:00:00:02", "devName": "Archived", "devLastIP": "10.0.0.2", "devIsArchived": 1},
def test_archived_devices_excluded(self):
result = self._call([
_raw_device(devMac="AA:00:00:00:00:01", devName="Active", devIsArchived=0),
_raw_device(devMac="AA:00:00:00:00:02", devName="Archived", devIsArchived=1),
])
result = get_netalertx_devices(db, include_offline=True, include_new=True)
assert len(result) == 1
assert result[0]["name"] == "Active"
def test_offline_excluded_when_flag_false(self, tmp_path):
db = str(tmp_path / "na.db")
_make_db(db, [
{"devMac": "AA:00:00:00:00:01", "devName": "Online", "devLastIP": "10.0.0.1", "devPresentLastScan": 1},
{"devMac": "AA:00:00:00:00:02", "devName": "Offline", "devLastIP": "10.0.0.2", "devPresentLastScan": 0},
])
result = get_netalertx_devices(db, include_offline=False, include_new=True)
def test_offline_excluded_when_flag_false(self):
result = self._call([
_raw_device(devMac="AA:00:00:00:00:01", devName="Online", devPresentLastScan=1),
_raw_device(devMac="AA:00:00:00:00:02", devName="Offline", devPresentLastScan=0),
], include_offline=False)
assert len(result) == 1
assert result[0]["name"] == "Online"
def test_new_devices_excluded_when_flag_false(self, tmp_path):
db = str(tmp_path / "na.db")
_make_db(db, [
{"devMac": "AA:00:00:00:00:01", "devName": "Known", "devLastIP": "10.0.0.1", "devIsNew": 0},
{"devMac": "AA:00:00:00:00:02", "devName": "Unknown", "devLastIP": "10.0.0.2", "devIsNew": 1},
])
result = get_netalertx_devices(db, include_offline=True, include_new=False)
def test_new_devices_excluded_when_flag_false(self):
result = self._call([
_raw_device(devMac="AA:00:00:00:00:01", devName="Known", devIsNew=0),
_raw_device(devMac="AA:00:00:00:00:02", devName="Unknown", devIsNew=1),
], include_new=False)
assert len(result) == 1
assert result[0]["name"] == "Known"
def test_nameless_device_falls_back_to_mac(self, tmp_path):
db = str(tmp_path / "na.db")
_make_db(db, [{"devMac": "BB:CC:DD:EE:FF:00", "devName": "", "devLastIP": "10.0.0.5"}])
result = get_netalertx_devices(db, include_offline=True, include_new=True)
def test_nameless_device_falls_back_to_mac(self):
result = self._call([_raw_device(devMac="BB:CC:DD:EE:FF:00", devName="", devLastIP="10.0.0.5")])
assert result[0]["name"] == "BB:CC:DD:EE:FF:00"
def test_row_with_no_mac_and_no_ip_skipped(self, tmp_path):
db = str(tmp_path / "na.db")
_make_db(db, [{"devMac": "", "devName": "Ghost", "devLastIP": ""}])
result = get_netalertx_devices(db, include_offline=True, include_new=True)
def test_row_with_no_mac_and_no_ip_skipped(self):
result = self._call([_raw_device(devMac="", devName="Ghost", devLastIP="")])
assert result == []
def test_missing_db_returns_empty_list(self, tmp_path):
result = get_netalertx_devices(str(tmp_path / "missing.db"), True, True)
assert result == []
def test_exception_returns_empty_list(self):
with patch("script.DeviceInstance") as mock_di:
mock_di.return_value.getAll.side_effect = Exception("db error")
assert get_netalertx_devices(True, True) == []
# ---------------------------------------------------------------------------