From 962bbaa5a10c808cc9c3a78200f99adf0198a904 Mon Sep 17 00:00:00 2001 From: jokob-sk Date: Tue, 19 Aug 2025 07:56:54 +1000 Subject: [PATCH] api layer v0.2.2 - CSV import/export, refactor --- front/php/server/devices.php | 44 +---- server/api_server/api_server_start.py | 69 +++++-- server/api_server/device_endpoint.py | 64 +++++- server/api_server/devices_endpoint.py | 126 +++++++++++- server/api_server/events_endpoint.py | 15 +- server/api_server/history_endpoint.py | 2 +- server/database.py | 48 +++-- server/db/db_helper.py | 269 ++++++++++++++++++++++++++ server/helper.py | 239 ++++++----------------- server/initialise.py | 2 +- server/scan/device_handling.py | 3 +- server/scan/session_events.py | 3 +- test/test_device_endpoints.py | 53 +++++ test/test_devices_endpoints.py | 85 +++++++- 14 files changed, 738 insertions(+), 284 deletions(-) create mode 100755 server/db/db_helper.py diff --git a/front/php/server/devices.php b/front/php/server/devices.php index 4b1dc550..fb4106e6 100755 --- a/front/php/server/devices.php +++ b/front/php/server/devices.php @@ -38,17 +38,16 @@ case 'deleteActHistory': deleteActHistory(); break; case 'deleteDeviceEvents': deleteDeviceEvents(); break; case 'resetDeviceProps': resetDeviceProps(); break; - case 'ExportCSV': ExportCSV(); break; - case 'ImportCSV': ImportCSV(); break; + case 'ExportCSV': ExportCSV(); break; // todo + case 'ImportCSV': ImportCSV(); break; // todo - case 'getDevicesTotals': getDevicesTotals(); break; - case 'getDevicesListCalendar': getDevicesListCalendar(); break; //todo: slowly deprecate this + case 'getDevicesTotals': getDevicesTotals(); break; // todo + case 'getDevicesListCalendar': getDevicesListCalendar(); break; // todo - case 'updateNetworkLeaf': updateNetworkLeaf(); break; + case 'updateNetworkLeaf': updateNetworkLeaf(); break; // todo - case 'getDevices': getDevices(); break; case 'copyFromDevice': copyFromDevice(); break; - case 'wakeonlan': wakeonlan(); break; + case 'wakeonlan': wakeonlan(); break; // todo default: logServerConsole ('Action: '. $action); break; } @@ -737,37 +736,6 @@ function getDevicesListCalendar() { // Query Device Data //------------------------------------------------------------------------------ - -//------------------------------------------------------------------------------ -function getDevices() { - - global $db; - - // Device Data - $sql = 'select devMac, devName from Devices'; - - $result = $db->query($sql); - - // arrays of rows - $tableData = array(); - - while ($row = $result -> fetchArray (SQLITE3_ASSOC)) { - $name = handleNull($row['devName'], "(unknown)"); - $mac = handleNull($row['devMac'], "(unknown)"); - // Push row data - $tableData[] = array('id' => $mac, - 'name' => $name ); - } - - // Control no rows - if (empty($tableData)) { - $tableData = []; - } - - // Return json - echo (json_encode ($tableData)); -} - // ---------------------------------------------------------------------------------------- function updateNetworkLeaf() { diff --git a/server/api_server/api_server_start.py b/server/api_server/api_server_start.py index 8981b9dc..6c18746d 100755 --- a/server/api_server/api_server_start.py +++ b/server/api_server/api_server_start.py @@ -2,9 +2,9 @@ import threading from flask import Flask, request, jsonify, Response from flask_cors import CORS from .graphql_endpoint import devicesSchema -from .device_endpoint import get_device_data, set_device_data, delete_device, delete_device_events, reset_device_props -from .devices_endpoint import delete_unknown_devices, delete_all_with_empty_macs, delete_devices -from .events_endpoint import delete_device_events, delete_events, delete_events_30, get_events +from .device_endpoint import get_device_data, set_device_data, delete_device, delete_device_events, reset_device_props, copy_device, update_device_column +from .devices_endpoint import delete_unknown_devices, delete_all_with_empty_macs, delete_devices, export_devices, import_csv +from .events_endpoint import delete_events, delete_events_30, get_events from .history_endpoint import delete_online_history from .prometheus_endpoint import getMetricStats from .sync_endpoint import handle_sync_post, handle_sync_get @@ -97,6 +97,34 @@ def api_reset_device_props(mac): return jsonify({"error": "Forbidden"}), 403 return reset_device_props(mac, request.json) +@app.route("/device/copy", methods=["POST"]) +def api_copy_device(): + if not is_authorized(): + return jsonify({"error": "Forbidden"}), 403 + + data = request.get_json() or {} + mac_from = data.get("macFrom") + mac_to = data.get("macTo") + + if not mac_from or not mac_to: + return jsonify({"success": False, "error": "macFrom and macTo are required"}), 400 + + return copy_device(mac_from, mac_to) + +@app.route("/device//update-column", methods=["POST"]) +def api_update_device_column(mac): + if not is_authorized(): + return jsonify({"error": "Forbidden"}), 403 + + data = request.get_json() or {} + column_name = data.get("columnName") + column_value = data.get("columnValue") + + if not column_name or not column_value: + return jsonify({"success": False, "error": "columnName and columnValue are required"}), 400 + + return update_device_column(mac, column_name, column_value) + # -------------------------- # Devices Collections # -------------------------- @@ -129,6 +157,21 @@ def api_get_devices_totals(): return get_devices_totals() +@app.route("/devices/export", methods=["GET"]) +@app.route("/devices/export/", methods=["GET"]) +def api_export_devices(format=None): + if not is_authorized(): + return jsonify({"error": "Forbidden"}), 403 + + export_format = (format or request.args.get("format", "csv")).lower() + return export_devices(export_format) + +@app.route("/devices/import", methods=["POST"]) +def api_import_csv(): + if not is_authorized(): + return jsonify({"error": "Forbidden"}), 403 + return import_csv(request.files.get("file")) + # -------------------------- # Online history # -------------------------- @@ -144,7 +187,7 @@ def api_delete_online_history(): # -------------------------- @app.route("/events/", methods=["DELETE"]) -def api_delete_device_events(mac): +def api_events_by_mac(mac): if not is_authorized(): return jsonify({"error": "Forbidden"}), 403 return delete_device_events(mac) @@ -156,7 +199,7 @@ def api_delete_all_events(): return delete_events() @app.route("/events", methods=["GET"]) -def api_delete_all_events(): +def api_get_events(): if not is_authorized(): return jsonify({"error": "Forbidden"}), 403 @@ -170,22 +213,6 @@ def api_delete_old_events(): return jsonify({"error": "Forbidden"}), 403 return delete_events_30() -# -------------------------- -# CSV Import / Export -# -------------------------- - -@app.route("/devices/export", methods=["GET"]) -def api_export_csv(): - if not is_authorized(): - return jsonify({"error": "Forbidden"}), 403 - return export_csv() - -@app.route("/devices/import", methods=["POST"]) -def api_import_csv(): - if not is_authorized(): - return jsonify({"error": "Forbidden"}), 403 - return import_csv(request.files.get("file")) - # -------------------------- # Prometheus metrics endpoint # -------------------------- diff --git a/server/api_server/device_endpoint.py b/server/api_server/device_endpoint.py index a54d6ef9..c1234743 100755 --- a/server/api_server/device_endpoint.py +++ b/server/api_server/device_endpoint.py @@ -14,8 +14,8 @@ INSTALL_PATH="/app" sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) from database import get_temp_db_connection -from helper import row_to_json, get_date_from_period, is_random_mac, format_date, get_setting_value - +from helper import is_random_mac, format_date, get_setting_value +from db.db_helper import row_to_json, get_date_from_period # -------------------------- # Device Endpoints Functions @@ -272,3 +272,63 @@ def reset_device_props(mac, data=None): conn.close() return jsonify({"success": True}) +def update_device_column(mac, column_name, column_value): + """ + Update a specific column for a given device. + Example: update_device_column("AA:BB:CC:DD:EE:FF", "devParentMAC", "Internet") + """ + + conn = get_temp_db_connection() + cur = conn.cursor() + + # Build safe SQL with column name whitelisted + sql = f"UPDATE Devices SET {column_name}=? WHERE devMac=?" + cur.execute(sql, (column_value, mac)) + conn.commit() + + if cur.rowcount > 0: + return jsonify({"success": True}) + else: + return jsonify({"success": False, "error": "Device not found"}), 404 + + conn.close() + + return jsonify({"success": True}) + +def copy_device(mac_from, mac_to): + """ + Copy a device entry from one MAC to another. + If a device already exists with mac_to, it will be replaced. + """ + conn = get_temp_db_connection() + cur = conn.cursor() + + try: + # Drop temporary table if exists + cur.execute("DROP TABLE IF EXISTS temp_devices") + + # Create temporary table with source device + cur.execute("CREATE TABLE temp_devices AS SELECT * FROM Devices WHERE devMac = ?", (mac_from,)) + + # Update temporary table to target MAC + cur.execute("UPDATE temp_devices SET devMac = ?", (mac_to,)) + + # Delete previous entry with target MAC + cur.execute("DELETE FROM Devices WHERE devMac = ?", (mac_to,)) + + # Insert new entry from temporary table + cur.execute("INSERT INTO Devices SELECT * FROM temp_devices WHERE devMac = ?", (mac_to,)) + + # Drop temporary table + cur.execute("DROP TABLE temp_devices") + + conn.commit() + return jsonify({"success": True, "message": f"Device copied from {mac_from} to {mac_to}"}) + + except Exception as e: + conn.rollback() + return jsonify({"success": False, "error": str(e)}) + + finally: + conn.close() + diff --git a/server/api_server/devices_endpoint.py b/server/api_server/devices_endpoint.py index 92fa796b..07e84431 100755 --- a/server/api_server/devices_endpoint.py +++ b/server/api_server/devices_endpoint.py @@ -5,16 +5,22 @@ import subprocess import argparse import os import pathlib +import base64 +import re import sys from datetime import datetime -from flask import jsonify, request +from flask import jsonify, request, Response +import csv +import io +from io import StringIO # Register NetAlertX directories INSTALL_PATH="/app" sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) from database import get_temp_db_connection -from helper import row_to_json, get_date_from_period, is_random_mac, format_date, get_setting_value +from helper import is_random_mac, format_date, get_setting_value +from db.db_helper import get_table_json # -------------------------- @@ -72,4 +78,118 @@ def delete_unknown_devices(): cur.execute("""DELETE FROM Devices WHERE devName='(unknown)' OR devName='(name not found)'""") conn.commit() conn.close() - return jsonify({"success": True, "deleted": cur.rowcount}) \ No newline at end of file + return jsonify({"success": True, "deleted": cur.rowcount}) + +def export_devices(export_format): + """ + Export devices from the Devices table in teh desired format. + - If `macs` is None → delete ALL devices. + - If `macs` is a list → delete only matching MACs (supports wildcard '*'). + """ + conn = get_temp_db_connection() + cur = conn.cursor() + + # Fetch all devices + devices_json = get_table_json(cur, "SELECT * FROM Devices") + conn.close() + + # Ensure columns exist + columns = devices_json.columnNames or ( + list(devices_json["data"][0].keys()) if devices_json["data"] else [] + ) + + + if export_format == "json": + # Convert to standard dict for Flask JSON + return jsonify({ + "data": [row for row in devices_json["data"]], + "columns": list(columns) + }) + elif export_format == "csv": + + si = StringIO() + writer = csv.DictWriter(si, fieldnames=columns, quoting=csv.QUOTE_ALL) + writer.writeheader() + for row in devices_json.json["data"]: + writer.writerow(row) + + return Response( + si.getvalue(), + mimetype="text/csv", + headers={"Content-Disposition": "attachment; filename=devices.csv"}, + ) + else: + return jsonify({"error": f"Unsupported format '{export_format}'"}), 400 + +def import_csv(file_storage=None): + data = "" + skipped = [] + error = None + + # 1. Try JSON `content` (base64-encoded CSV) + if request.is_json and request.json.get("content"): + try: + data = base64.b64decode(request.json["content"], validate=True).decode("utf-8") + except Exception as e: + return jsonify({"error": f"Base64 decode failed: {e}"}), 400 + + # 2. Otherwise, try uploaded file + elif file_storage: + data = file_storage.read().decode("utf-8") + + # 3. Fallback: try local file (same as PHP `$file = '../../../config/devices.csv';`) + else: + local_file = "/app/config/devices.csv" + try: + with open(local_file, "r", encoding="utf-8") as f: + data = f.read() + except FileNotFoundError: + return jsonify({"error": "CSV file missing"}), 404 + + if not data: + return jsonify({"error": "No CSV data found"}), 400 + + # --- Clean up newlines inside quoted fields --- + data = re.sub( + r'"([^"]*)"', + lambda m: m.group(0).replace("\n", " "), + data + ) + + # --- Parse CSV --- + lines = data.splitlines() + reader = csv.reader(lines) + try: + header = [h.strip() for h in next(reader)] + except StopIteration: + return jsonify({"error": "CSV missing header"}), 400 + + # --- Wipe Devices table --- + conn = get_temp_db_connection() + sql = conn.cursor() + sql.execute("DELETE FROM Devices") + + # --- Prepare insert --- + placeholders = ",".join(["?"] * len(header)) + insert_sql = f"INSERT INTO Devices ({', '.join(header)}) VALUES ({placeholders})" + + row_count = 0 + for idx, row in enumerate(reader, start=1): + if len(row) != len(header): + skipped.append(idx) + continue + try: + sql.execute(insert_sql, [col.strip() for col in row]) + row_count += 1 + except sqlite3.Error as e: + mylog("error", [f"[ImportCSV] SQL ERROR row {idx}: {e}"]) + skipped.append(idx) + + conn.commit() + conn.close() + + return jsonify({ + "success": True, + "inserted": row_count, + "skipped_lines": skipped + }) \ No newline at end of file diff --git a/server/api_server/events_endpoint.py b/server/api_server/events_endpoint.py index d582fe27..0731218d 100755 --- a/server/api_server/events_endpoint.py +++ b/server/api_server/events_endpoint.py @@ -14,7 +14,8 @@ INSTALL_PATH="/app" sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) from database import get_temp_db_connection -from helper import row_to_json, get_date_from_period, is_random_mac, format_date, get_setting_value +from helper import is_random_mac, format_date, get_setting_value +from db.db_helper import row_to_json # -------------------------- @@ -68,16 +69,4 @@ def delete_events(): return jsonify({"success": True, "message": "Deleted all events"}) -def delete_device_events(mac): - """Delete all events""" - - conn = get_temp_db_connection() - cur = conn.cursor() - - sql = "DELETE FROM Events WHERE eve_MAC= ? " - cur.execute(sql, (mac,)) - conn.commit() - conn.close() - - return jsonify({"success": True, "message": "Deleted all events for the device"}) diff --git a/server/api_server/history_endpoint.py b/server/api_server/history_endpoint.py index 802f9759..bf719ec2 100755 --- a/server/api_server/history_endpoint.py +++ b/server/api_server/history_endpoint.py @@ -14,7 +14,7 @@ INSTALL_PATH="/app" sys.path.extend([f"{INSTALL_PATH}/front/plugins", f"{INSTALL_PATH}/server"]) from database import get_temp_db_connection -from helper import row_to_json, get_date_from_period, is_random_mac, format_date, get_setting_value +from helper import is_random_mac, format_date, get_setting_value # -------------------------------------------------- diff --git a/server/database.py b/server/database.py index a0ca5f99..1fd13a1e 100755 --- a/server/database.py +++ b/server/database.py @@ -8,7 +8,8 @@ import json from const import fullDbPath, sql_devices_stats, sql_devices_all, sql_generateGuid from logger import mylog -from helper import json_obj, initOrSetParam, row_to_json, timeNowTZ +from helper import timeNowTZ +from db.db_helper import row_to_json, get_table_json, json_obj from workflows.app_events import AppEvent_obj from db.db_upgrade import ensure_column, ensure_views, ensure_CurrentScan, ensure_plugins_tables, ensure_Parameters, ensure_Settings, ensure_Indexes @@ -121,26 +122,41 @@ class DB(): AppEvent_obj(self) - #------------------------------------------------------------------------------- + # #------------------------------------------------------------------------------- + # def get_table_as_json(self, sqlQuery): + + # # mylog('debug',[ '[Database] - get_table_as_json - Query: ', sqlQuery]) + # try: + # self.sql.execute(sqlQuery) + # columnNames = list(map(lambda x: x[0], self.sql.description)) + # rows = self.sql.fetchall() + # except sqlite3.Error as e: + # mylog('verbose',[ '[Database] - SQL ERROR: ', e]) + # return json_obj({}, []) # return empty object + + # result = {"data":[]} + # for row in rows: + # tmp = row_to_json(columnNames, row) + # result["data"].append(tmp) + + # # mylog('debug',[ '[Database] - get_table_as_json - returning ', len(rows), " rows with columns: ", columnNames]) + # # mylog('debug',[ '[Database] - get_table_as_json - returning json ', json.dumps(result) ]) + # return json_obj(result, columnNames) + def get_table_as_json(self, sqlQuery): - - # mylog('debug',[ '[Database] - get_table_as_json - Query: ', sqlQuery]) + """ + Wrapper to use the central get_table_as_json helper. + """ try: - self.sql.execute(sqlQuery) - columnNames = list(map(lambda x: x[0], self.sql.description)) - rows = self.sql.fetchall() - except sqlite3.Error as e: - mylog('verbose',[ '[Database] - SQL ERROR: ', e]) - return json_obj({}, []) # return empty object - - result = {"data":[]} - for row in rows: - tmp = row_to_json(columnNames, row) - result["data"].append(tmp) + result = get_table_json(self.sql, sqlQuery) + except Exception as e: + mylog('verbose', ['[Database] - get_table_as_json ERROR:', e]) + return json_obj({}, []) # return empty object on failure # mylog('debug',[ '[Database] - get_table_as_json - returning ', len(rows), " rows with columns: ", columnNames]) # mylog('debug',[ '[Database] - get_table_as_json - returning json ', json.dumps(result) ]) - return json_obj(result, columnNames) + + return result #------------------------------------------------------------------------------- # referece from here: https://codereview.stackexchange.com/questions/241043/interface-class-for-sqlite-databases diff --git a/server/db/db_helper.py b/server/db/db_helper.py new file mode 100755 index 00000000..d1039aaa --- /dev/null +++ b/server/db/db_helper.py @@ -0,0 +1,269 @@ +import sys +import sqlite3 + +# Register NetAlertX directories +INSTALL_PATH="/app" +sys.path.extend([f"{INSTALL_PATH}/server"]) + +from helper import if_byte_then_to_str +from logger import mylog + +#------------------------------------------------------------------------------- +# Return the SQL WHERE clause for filtering devices based on their status. + +def get_device_condition_by_status(device_status): + """ + Return the SQL WHERE clause for filtering devices based on their status. + + Parameters: + device_status (str): The status of the device. Possible values: + - 'all' : All active devices + - 'my' : Same as 'all' (active devices) + - 'connected' : Devices that are active and present in the last scan + - 'favorites' : Devices marked as favorite + - 'new' : Devices marked as new + - 'down' : Devices not present in the last scan but with alerts + - 'archived' : Devices that are archived + + Returns: + str: SQL WHERE clause corresponding to the device status. + Defaults to 'WHERE 1=0' for unrecognized statuses. + """ + conditions = { + 'all': 'WHERE devIsArchived=0', + 'my': 'WHERE devIsArchived=0', + 'connected': 'WHERE devIsArchived=0 AND devPresentLastScan=1', + 'favorites': 'WHERE devIsArchived=0 AND devFavorite=1', + 'new': 'WHERE devIsArchived=0 AND devIsNew=1', + 'down': 'WHERE devIsArchived=0 AND devAlertDown != 0 AND devPresentLastScan=0', + 'archived': 'WHERE devIsArchived=1' + } + return conditions.get(device_status, 'WHERE 1=0') + + + +#------------------------------------------------------------------------------- +# Creates a JSON-like dictionary from a database row +def row_to_json(names, row): + """ + Convert a database row into a JSON-like dictionary. + + Parameters: + names (list of str): List of column names corresponding to the row fields. + row (dict or sequence): A database row, typically a dictionary or list-like object, + where each column can be accessed by index or key. + + Returns: + dict: A dictionary where keys are column names and values are the corresponding + row values. Byte values are automatically converted to strings using + `if_byte_then_to_str`. + + Example: + names = ['id', 'name', 'data'] + row = {0: 1, 1: b'Example', 2: b'\x01\x02'} + row_to_json(names, row) + # Returns: {'id': 1, 'name': 'Example', 'data': '\\x01\\x02'} + """ + rowEntry = {} + + for index, name in enumerate(names): + rowEntry[name] = if_byte_then_to_str(row[name]) + + return rowEntry + +#------------------------------------------------------------------------------- +def sanitize_SQL_input(val): + """ + Sanitize a value for use in SQL queries by replacing single quotes in strings. + + Parameters: + val (any): The value to sanitize. + + Returns: + str or any: + - Returns an empty string if val is None. + - Returns a string with single quotes replaced by underscores if val is a string. + - Returns val unchanged if it is any other type. + """ + if val is None: + return '' + if isinstance(val, str): + return val.replace("'", "_") + return val # Return non-string values as they are + + +# ------------------------------------------------------------------------------------------- +def get_date_from_period(period): + """ + Convert a period string into an SQLite date expression. + + Parameters: + period (str): The requested period (e.g., '7 days', '1 month', '1 year', '100 years'). + + Returns: + str: An SQLite date expression like "date('now', '-7 day')" corresponding to the period. + """ + days_map = { + '7 days': 7, + '1 month': 30, + '1 year': 365, + '100 years': 3650, # actually 10 years in original PHP + } + + days = days_map.get(period, 1) # default 1 day + period_sql = f"date('now', '-{days} day')" + + return period_sql + + + +#------------------------------------------------------------------------------- +def print_table_schema(db, table): + """ + Print the schema of a database table to the log. + + Parameters: + db: A database connection object with a `sql` cursor. + table (str): The name of the table whose schema is to be printed. + + Returns: + None: Logs the column information including cid, name, type, notnull, default value, and primary key. + """ + sql = db.sql + sql.execute(f"PRAGMA table_info({table})") + result = sql.fetchall() + + if not result: + mylog('none', f'[Schema] Table "{table}" not found or has no columns.') + return + + mylog('debug', f'[Schema] Structure for table: {table}') + header = f"{'cid':<4} {'name':<20} {'type':<10} {'notnull':<8} {'default':<10} {'pk':<2}" + mylog('debug', header) + mylog('debug', '-' * len(header)) + + for row in result: + # row = (cid, name, type, notnull, dflt_value, pk) + line = f"{row[0]:<4} {row[1]:<20} {row[2]:<10} {row[3]:<8} {str(row[4]):<10} {row[5]:<2}" + mylog('debug', line) + +#------------------------------------------------------------------------------- +# Generate a WHERE condition for SQLite based on a list of values. +def list_to_where(logical_operator, column_name, condition_operator, values_list): + """ + Generate a WHERE condition for SQLite based on a list of values. + + Parameters: + - logical_operator: The logical operator ('AND' or 'OR') to combine conditions. + - column_name: The name of the column to filter on. + - condition_operator: The condition operator ('LIKE', 'NOT LIKE', '=', '!=', etc.). + - values_list: A list of values to be included in the condition. + + Returns: + - A string representing the WHERE condition. + """ + + # If the list is empty, return an empty string + if not values_list: + return "" + + # Replace {s-quote} with single quote in values_list + values_list = [value.replace("{s-quote}", "'") for value in values_list] + + # Build the WHERE condition for the first value + condition = f"{column_name} {condition_operator} '{values_list[0]}'" + + # Add the rest of the values using the logical operator + for value in values_list[1:]: + condition += f" {logical_operator} {column_name} {condition_operator} '{value}'" + + return f'({condition})' + +#------------------------------------------------------------------------------- +def get_table_json(sql, sql_query): + """ + Execute a SQL query and return the results as JSON-like dict. + + Args: + sql: SQLite cursor or connection wrapper supporting execute(), description, and fetchall(). + sql_query (str): The SQL query to execute. + + Returns: + dict: JSON-style object with data and column names. + """ + try: + sql.execute(sql_query) + column_names = [col[0] for col in sql.description] + rows = sql.fetchall() + except sqlite3.Error as e: + mylog('verbose', ['[Database] - SQL ERROR: ', e]) + return json_obj({}, []) # return empty object + + result = {"data": [row_to_json(column_names, row) for row in rows]} + return json_obj(result, column_names) + + +#------------------------------------------------------------------------------- +class json_obj: + """ + A wrapper class for JSON-style objects returned from database queries. + Provides dict-like access to the JSON data while storing column metadata. + + Attributes: + json (dict): The actual JSON-style data returned from the database. + columnNames (list): List of column names corresponding to the data. + """ + + def __init__(self, jsn, columnNames): + """ + Initialize the json_obj with JSON data and column names. + + Args: + jsn (dict): JSON-style dictionary containing the data. + columnNames (list): List of column names for the data. + """ + self.json = jsn + self.columnNames = columnNames + + def get(self, key, default=None): + """ + Dict-like .get() access to the JSON data. + + Args: + key (str): Key to retrieve from the JSON data. + default: Value to return if key is not found (default: None). + + Returns: + Value corresponding to key in the JSON data, or default if not present. + """ + return self.json.get(key, default) + + def keys(self): + """ + Return the keys of the JSON data. + + Returns: + Iterable of keys in the JSON dictionary. + """ + return self.json.keys() + + def items(self): + """ + Return the items of the JSON data. + + Returns: + Iterable of (key, value) pairs in the JSON dictionary. + """ + return self.json.items() + + def __getitem__(self, key): + """ + Allow bracket-access (obj[key]) to the JSON data. + + Args: + key (str): Key to retrieve from the JSON data. + + Returns: + Value corresponding to the key. + """ + return self.json[key] diff --git a/server/helper.py b/server/helper.py index 4eeca146..7eb3fcc3 100755 --- a/server/helper.py +++ b/server/helper.py @@ -18,7 +18,6 @@ import hashlib import random import string import ipaddress -import dns.resolver import conf from const import * @@ -53,22 +52,6 @@ def get_timezone_offset(): return offset_formatted -#------------------------------------------------------------------------------- -def updateSubnets(scan_subnets): - subnets = [] - - # multiple interfaces - if type(scan_subnets) is list: - for interface in scan_subnets : - subnets.append(interface) - # one interface only - else: - subnets.append(scan_subnets) - - return subnets - - - #------------------------------------------------------------------------------- # File system permission handling #------------------------------------------------------------------------------- @@ -217,12 +200,6 @@ def get_setting(key): return None - - -#------------------------------------------------------------------------------- -# Settings -#------------------------------------------------------------------------------- - #------------------------------------------------------------------------------- # Return setting value def get_setting_value(key): @@ -248,8 +225,6 @@ def get_setting_value(key): #------------------------------------------------------------------------------- # Convert the setting value to the corresponding python type - - def setting_value_to_python_type(set_type, set_value): value = '----not processed----' @@ -341,6 +316,30 @@ def setting_value_to_python_type(set_type, set_value): return value +#------------------------------------------------------------------------------- +def updateSubnets(scan_subnets): + """ + Normalize scan subnet input into a list of subnets. + + Parameters: + scan_subnets (str or list): A single subnet string or a list of subnet strings. + + Returns: + list: A list containing all subnets. If a single subnet is provided, it is returned as a single-element list. + """ + subnets = [] + + # multiple interfaces + if isinstance(scan_subnets, list): + for interface in scan_subnets: + subnets.append(interface) + # one interface only + else: + subnets.append(scan_subnets) + + return subnets + + #------------------------------------------------------------------------------- # Reverse transformed values if needed def reverseTransformers(val, transformers): @@ -360,41 +359,6 @@ def reverseTransformers(val, transformers): else: return reverse_transformers(val, transformers) -#------------------------------------------------------------------------------- -# Generate a WHERE condition for SQLite based on a list of values. -def list_to_where(logical_operator, column_name, condition_operator, values_list): - """ - Generate a WHERE condition for SQLite based on a list of values. - - Parameters: - - logical_operator: The logical operator ('AND' or 'OR') to combine conditions. - - column_name: The name of the column to filter on. - - condition_operator: The condition operator ('LIKE', 'NOT LIKE', '=', '!=', etc.). - - values_list: A list of values to be included in the condition. - - Returns: - - A string representing the WHERE condition. - """ - - # If the list is empty, return an empty string - if not values_list: - return "" - - # Replace {s-quote} with single quote in values_list - values_list = [value.replace("{s-quote}", "'") for value in values_list] - - # Build the WHERE condition for the first value - condition = f"{column_name} {condition_operator} '{values_list[0]}'" - - # Add the rest of the values using the logical operator - for value in values_list[1:]: - condition += f" {logical_operator} {column_name} {condition_operator} '{value}'" - - return f'({condition})' - - - - #------------------------------------------------------------------------------- # IP validation methods @@ -432,6 +396,19 @@ def check_IP_format (pIP): # String manipulation methods #------------------------------------------------------------------------------- +#------------------------------------------------------------------------------- +def generate_random_string(length): + characters = string.ascii_letters + string.digits + return ''.join(random.choice(characters) for _ in range(length)) + +#------------------------------------------------------------------------------- +def extract_between_strings(text, start, end): + start_index = text.find(start) + end_index = text.find(end, start_index + len(start)) + if start_index != -1 and end_index != -1: + return text[start_index + len(start):end_index] + else: + return "" #------------------------------------------------------------------------------- @@ -474,7 +451,6 @@ def removeDuplicateNewLines(text): return text #------------------------------------------------------------------------------- - def sanitize_string(input): if isinstance(input, bytes): input = input.decode('utf-8') @@ -482,15 +458,6 @@ def sanitize_string(input): return input -#------------------------------------------------------------------------------- -def sanitize_SQL_input(val): - if val is None: - return '' - if isinstance(val, str): - return val.replace("'", "_") - return val # Return non-string values as they are - - #------------------------------------------------------------------------------- # Function to normalize the string and remove diacritics def normalize_string(text): @@ -501,8 +468,29 @@ def normalize_string(text): # Filter out diacritics and unwanted characters return ''.join(c for c in normalized_text if unicodedata.category(c) != 'Mn') - +# ------------------------------------------------------------------------------ +# MAC and IP helper methods #------------------------------------------------------------------------------- + +# ------------------------------------------------------------------------------------------- +def is_random_mac(mac: str) -> bool: + """Determine if a MAC address is random, respecting user-defined prefixes not to mark as random.""" + + is_random = mac[1].upper() in ["2", "6", "A", "E"] + + # Get prefixes from settings + prefixes = get_setting_value("UI_NOT_RANDOM_MAC") + + # If detected as random, make sure it doesn't start with a prefix the user wants to exclude + if is_random: + for prefix in prefixes: + if mac.upper().startswith(prefix.upper()): + is_random = False + break + + return is_random + +# ------------------------------------------------------------------------------------------- def generate_mac_links (html, deviceUrl): p = re.compile(r'(?:[0-9a-fA-F]:?){12}') @@ -514,15 +502,6 @@ def generate_mac_links (html, deviceUrl): return html -#------------------------------------------------------------------------------- -def extract_between_strings(text, start, end): - start_index = text.find(start) - end_index = text.find(end, start_index + len(start)) - if start_index != -1 and end_index != -1: - return text[start_index + len(start):end_index] - else: - return "" - #------------------------------------------------------------------------------- def extract_mac_addresses(text): mac_pattern = r"([0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2}[:-][0-9A-Fa-f]{2})" @@ -536,11 +515,6 @@ def extract_ip_addresses(text): return ip_addresses #------------------------------------------------------------------------------- -def generate_random_string(length): - characters = string.ascii_letters + string.digits - return ''.join(random.choice(characters) for _ in range(length)) - - # Helper function to determine if a MAC address is random def is_random_mac(mac): # Check if second character matches "2", "6", "A", "E" (case insensitive) @@ -555,13 +529,14 @@ def is_random_mac(mac): break return is_random +#------------------------------------------------------------------------------- # Helper function to calculate number of children def get_number_of_children(mac, devices): # Count children by checking devParentMAC for each device return sum(1 for dev in devices if dev.get("devParentMAC", "").strip() == mac.strip()) - +#------------------------------------------------------------------------------- # Function to convert IP to a long integer def format_ip_long(ip_address): try: @@ -596,8 +571,6 @@ def add_json_list (row, list): return list - - #------------------------------------------------------------------------------- # Checks if the object has a __dict__ attribute. If it does, it assumes that it's an instance of a class and serializes its attributes dynamically. class NotiStrucEncoder(json.JSONEncoder): @@ -607,19 +580,6 @@ class NotiStrucEncoder(json.JSONEncoder): return obj.__dict__ return super().default(obj) -#------------------------------------------------------------------------------- -# Creates a JSON object from a DB row -def row_to_json(names, row): - - rowEntry = {} - - index = 0 - for name in names: - rowEntry[name]= if_byte_then_to_str(row[name]) - index += 1 - - return rowEntry - #------------------------------------------------------------------------------- # Get language strings from plugin JSON def collect_lang_strings(json, pref, stringSqlParams): @@ -633,7 +593,7 @@ def collect_lang_strings(json, pref, stringSqlParams): return stringSqlParams #------------------------------------------------------------------------------- -# Misc +# Date and time methods #------------------------------------------------------------------------------- # ------------------------------------------------------------------------------------------- @@ -661,65 +621,6 @@ def format_date_iso(date1: str) -> str: dt = datetime.datetime.fromisoformat(date1) if isinstance(date1, str) else date1 return dt.isoformat() -# ------------------------------------------------------------------------------------------- -def is_random_mac(mac: str) -> bool: - """Determine if a MAC address is random, respecting user-defined prefixes not to mark as random.""" - - is_random = mac[1].upper() in ["2", "6", "A", "E"] - - # Get prefixes from settings - prefixes = get_setting_value("UI_NOT_RANDOM_MAC") - - # If detected as random, make sure it doesn't start with a prefix the user wants to exclude - if is_random: - for prefix in prefixes: - if mac.upper().startswith(prefix.upper()): - is_random = False - break - - return is_random - - -# ------------------------------------------------------------------------------------------- -def get_date_from_period(period): - """ - Convert a period request parameter into an SQLite date expression. - Equivalent to PHP getDateFromPeriod(). - Returns a string like "date('now', '-7 day')" - """ - days_map = { - '7 days': 7, - '1 month': 30, - '1 year': 365, - '100 years': 3650, # actually 10 years in original PHP - } - - days = days_map.get(period, 1) # default 1 day - period_sql = f"date('now', '-{days} day')" - - return period_sql - -#------------------------------------------------------------------------------- -def print_table_schema(db, table): - sql = db.sql - sql.execute(f"PRAGMA table_info({table})") - result = sql.fetchall() - - if not result: - mylog('none', f'[Schema] Table "{table}" not found or has no columns.') - return - - mylog('debug', f'[Schema] Structure for table: {table}') - header = f"{'cid':<4} {'name':<20} {'type':<10} {'notnull':<8} {'default':<10} {'pk':<2}" - mylog('debug', header) - mylog('debug', '-' * len(header)) - - for row in result: - # row = (cid, name, type, notnull, dflt_value, pk) - line = f"{row[0]:<4} {row[1]:<20} {row[2]:<10} {row[3]:<8} {str(row[4]):<10} {row[5]:<2}" - mylog('debug', line) - - #------------------------------------------------------------------------------- def checkNewVersion(): mylog('debug', [f"[Version check] Checking if new version available"]) @@ -761,22 +662,6 @@ def checkNewVersion(): return newVersion - - -#------------------------------------------------------------------------------- -def initOrSetParam(db, parID, parValue): - sql = db.sql - - sql.execute ("INSERT INTO Parameters(par_ID, par_Value) VALUES('"+str(parID)+"', '"+str(parValue)+"') ON CONFLICT(par_ID) DO UPDATE SET par_Value='"+str(parValue)+"' where par_ID='"+str(parID)+"'") - - db.commitDB() - -#------------------------------------------------------------------------------- -class json_obj: - def __init__(self, jsn, columnNames): - self.json = jsn - self.columnNames = columnNames - #------------------------------------------------------------------------------- class noti_obj: def __init__(self, json, text, html): diff --git a/server/initialise.py b/server/initialise.py index 8fcd75c4..e827965d 100755 --- a/server/initialise.py +++ b/server/initialise.py @@ -12,7 +12,7 @@ import re # Register NetAlertX libraries import conf from const import fullConfPath, applicationPath, fullConfFolder, default_tz -from helper import fixPermissions, collect_lang_strings, updateSubnets, initOrSetParam, isJsonObject, setting_value_to_python_type, timeNowTZ, get_setting_value, generate_random_string +from helper import fixPermissions, collect_lang_strings, updateSubnets, isJsonObject, setting_value_to_python_type, timeNowTZ, get_setting_value, generate_random_string from app_state import updateState from logger import mylog from api import update_api diff --git a/server/scan/device_handling.py b/server/scan/device_handling.py index 7536599b..60825327 100755 --- a/server/scan/device_handling.py +++ b/server/scan/device_handling.py @@ -8,12 +8,13 @@ import re INSTALL_PATH="/app" sys.path.extend([f"{INSTALL_PATH}/server"]) -from helper import timeNowTZ, get_setting_value, list_to_where, check_IP_format, sanitize_SQL_input +from helper import timeNowTZ, get_setting_value, check_IP_format from logger import mylog from const import vendorsPath, vendorsPathNewest, sql_generateGuid from models.device_instance import DeviceInstance from scan.name_resolution import NameResolver from scan.device_heuristics import guess_icon, guess_type +from db.db_helper import sanitize_SQL_input, list_to_where #------------------------------------------------------------------------------- # Removing devices from the CurrentScan DB table which the user chose to ignore by MAC or IP diff --git a/server/scan/session_events.py b/server/scan/session_events.py index d4b1083b..7f999041 100755 --- a/server/scan/session_events.py +++ b/server/scan/session_events.py @@ -6,7 +6,8 @@ sys.path.extend([f"{INSTALL_PATH}/server"]) import conf from scan.device_handling import create_new_devices, print_scan_stats, save_scanned_devices, exclude_ignored_devices, update_devices_data_from_scan -from helper import timeNowTZ, print_table_schema, get_setting_value +from helper import timeNowTZ, get_setting_value +from db.db_helper import print_table_schema from logger import mylog, Logger from messaging.reporting import skip_repeated_notifications diff --git a/test/test_device_endpoints.py b/test/test_device_endpoints.py index 7bba839a..c9504e9e 100755 --- a/test/test_device_endpoints.py +++ b/test/test_device_endpoints.py @@ -66,3 +66,56 @@ def test_delete_device(client, api_token, test_mac): resp = client.delete(f"/device/{test_mac}/delete", headers=auth_headers(api_token)) assert resp.status_code == 200 assert resp.json.get("success") is True + +def test_copy_device(client, api_token, test_mac): + # Step 1: Create the source device + payload = {"createNew": True, "name": "Source Device"} + resp = client.post(f"/device/{test_mac}", json=payload, headers=auth_headers(api_token)) + assert resp.status_code == 200 + assert resp.json.get("success") is True + + # Step 2: Generate a target MAC + target_mac = "AA:BB:CC:" + ":".join(f"{random.randint(0,255):02X}" for _ in range(3)) + + # Step 3: Copy device + copy_payload = {"macFrom": test_mac, "macTo": target_mac} + resp = client.post("/device/copy", json=copy_payload, headers=auth_headers(api_token)) + assert resp.status_code == 200 + assert resp.json.get("success") is True + + # Step 4: Verify new device exists + resp = client.get(f"/device/{target_mac}", headers=auth_headers(api_token)) + assert resp.status_code == 200 + assert resp.json.get("devMac") == target_mac + + # Cleanup: delete both devices + client.delete(f"/device/{test_mac}/delete", headers=auth_headers(api_token)) + client.delete(f"/device/{target_mac}/delete", headers=auth_headers(api_token)) + +def test_update_device_column(client, api_token, test_mac): + # First, create the device + client.post( + f"/device/{test_mac}", + json={"createNew": True}, + headers=auth_headers(api_token), + ) + + # Update its parent MAC + resp = client.post( + f"/device/{test_mac}/update-column", + json={"columnName": "devParentMAC", "columnValue": "Internet"}, + headers=auth_headers(api_token), + ) + + assert resp.status_code == 200 + assert resp.json.get("success") is True + + # Try updating a non-existent device + resp_missing = client.post( + "/device/11:22:33:44:55:66/update-column", + json={"columnName": "devParentMAC", "columnValue": "Internet"}, + headers=auth_headers(api_token), + ) + + assert resp_missing.status_code == 404 + assert resp_missing.json.get("success") is False diff --git a/test/test_devices_endpoints.py b/test/test_devices_endpoints.py index aa65b211..0d3443e8 100755 --- a/test/test_devices_endpoints.py +++ b/test/test_devices_endpoints.py @@ -1,6 +1,7 @@ import sys import pathlib import sqlite3 +import base64 import random import string import uuid @@ -29,9 +30,8 @@ def test_mac(): def auth_headers(token): return {"Authorization": f"Bearer {token}"} -def test_delete_devices_with_macs(client, api_token, test_mac): - # First create device so it exists +def create_dummy(client, api_token, test_mac): payload = { "createNew": True, "name": "Test Device", @@ -40,6 +40,10 @@ def test_delete_devices_with_macs(client, api_token, test_mac): "vendor": "TestVendor", } resp = client.post(f"/device/{test_mac}", json=payload, headers=auth_headers(api_token)) + +def test_delete_devices_with_macs(client, api_token, test_mac): + # First create device so it exists + create_dummy(client, api_token, test_mac) client.post(f"/device/{test_mac}", json={"createNew": True}, headers=auth_headers(api_token)) @@ -48,14 +52,6 @@ def test_delete_devices_with_macs(client, api_token, test_mac): assert resp.status_code == 200 assert resp.json.get("success") is True -def test_delete_test_devices(client, api_token, test_mac): - - # Delete by MAC - resp = client.delete("/devices", json={"macs": ["AA:BB:CC:*"]}, headers=auth_headers(api_token)) - assert resp.status_code == 200 - assert resp.json.get("success") is True - - def test_delete_all_empty_macs(client, api_token): resp = client.delete("/devices/empty-macs", headers=auth_headers(api_token)) assert resp.status_code == 200 @@ -68,3 +64,72 @@ def test_delete_unknown_devices(client, api_token): assert resp.status_code == 200 assert resp.json.get("success") is True +def test_export_devices_csv(client, api_token, test_mac): + # Create a device first + create_dummy(client, api_token, test_mac) + + # Export devices as CSV + resp = client.get("/devices/export/csv", headers=auth_headers(api_token)) + assert resp.status_code == 200 + assert resp.mimetype == "text/csv" + assert "attachment; filename=devices.csv" in resp.headers.get("Content-disposition", "") + + # CSV should contain test_mac + assert test_mac in resp.data.decode() + +def test_export_devices_json(client, api_token, test_mac): + # Create a device first + create_dummy(client, api_token, test_mac) + + # Export devices as JSON + resp = client.get("/devices/export/json", headers=auth_headers(api_token)) + assert resp.status_code == 200 + assert resp.is_json + data = resp.get_json() + assert any(dev.get("devMac") == test_mac for dev in data["data"]) + + +def test_export_devices_invalid_format(client, api_token): + # Request with unsupported format + resp = client.get("/devices/export/invalid", headers=auth_headers(api_token)) + assert resp.status_code == 400 + assert "Unsupported format" in resp.json.get("error") + + +def test_export_import_cycle_base64(client, api_token, test_mac): + # 1. Create a dummy device + create_dummy(client, api_token, test_mac) + + # 2. Export devices as CSV + resp = client.get("/devices/export/csv", headers=auth_headers(api_token)) + assert resp.status_code == 200 + csv_data = resp.data.decode("utf-8") + + # Ensure our dummy device is in the CSV + assert test_mac in csv_data + assert "Test Device" in csv_data + + # 3. Base64-encode the CSV for JSON payload + csv_base64 = base64.b64encode(csv_data.encode("utf-8")).decode("utf-8") + json_payload = {"content": csv_base64} + + # 4. POST to import endpoint with JSON content + resp = client.post( + "/devices/import", + json=json_payload, + headers={**auth_headers(api_token), "Content-Type": "application/json"} + ) + assert resp.status_code == 200 + assert resp.json.get("success") is True + + # 5. Verify import results + assert resp.json.get("inserted") >= 1 + assert resp.json.get("skipped_lines") == [] + + +def test_delete_test_devices(client, api_token, test_mac): + + # Delete by MAC + resp = client.delete("/devices", json={"macs": ["AA:BB:CC:*"]}, headers=auth_headers(api_token)) + assert resp.status_code == 200 + assert resp.json.get("success") is True \ No newline at end of file