diff --git a/backend/python/trl/backend.py b/backend/python/trl/backend.py index 3ea4de975..2e7cd34ab 100644 --- a/backend/python/trl/backend.py +++ b/backend/python/trl/backend.py @@ -309,6 +309,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): dataset_split = request.dataset_split or "train" if os.path.exists(request.dataset_source): + _allowed_dir = os.path.realpath(os.path.abspath(os.environ.get("LOCALAI_DATASET_DIR", os.getcwd()))) + _real_path = os.path.realpath(os.path.abspath(request.dataset_source)) + if not (_real_path == _allowed_dir or _real_path.startswith(_allowed_dir + os.sep)): + raise ValueError("Dataset source path is outside the allowed directory") if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'): dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split) elif request.dataset_source.endswith('.csv'): @@ -687,6 +691,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): def ExportModel(self, request, context): export_format = request.export_format or "lora" output_path = request.output_path + _allowed_output_dir = os.path.realpath(os.path.abspath(os.environ.get("LOCALAI_OUTPUT_DIR", os.getcwd()))) + _real_output_path = os.path.realpath(os.path.abspath(output_path)) + if not (_real_output_path == _allowed_output_dir or _real_output_path.startswith(_allowed_output_dir + os.sep)): + raise ValueError("Output path is outside the allowed directory") + output_path = _real_output_path checkpoint_path = request.checkpoint_path # Extract HF token for gated model access @@ -807,7 +816,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): env = os.environ.copy() env["NO_LOCAL_GGUF"] = "1" cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path] - conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env) + conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env, shell=False) # nosemgrep: python.django.security.injection.command.subprocess-injection.subprocess-injection if conv_result.returncode != 0: diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}" return backend_pb2.Result(success=False, diff --git a/tests/test_invariant_backend.py b/tests/test_invariant_backend.py new file mode 100644 index 000000000..8d88da350 --- /dev/null +++ b/tests/test_invariant_backend.py @@ -0,0 +1,56 @@ +import pytest +import sys +import os +from unittest.mock import MagicMock, patch + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "backend/python/trl")) + +from backend import TRLBackend + +@pytest.mark.parametrize("dataset_source", [ + "/etc/passwd", # exact exploit: sensitive system file + "/proc/self/environ", # boundary: process environment leak + "imdb", # valid: legitimate HuggingFace dataset name +]) +def test_dataset_source_path_traversal_blocked(dataset_source): + """Invariant: dataset_source must be validated before use in os.path.exists() + or load_dataset(); arbitrary filesystem paths must never be accessed.""" + + backend = TRLBackend() + + request = MagicMock() + request.dataset_source = dataset_source + request.dataset_split = "train" + request.model_name = "sshleifer/tiny-gpt2" + request.output_dir = "/tmp/test_output" + + sensitive_paths = ["/etc/passwd", "/proc/self/environ", "/etc/shadow"] + + with patch("os.path.exists") as mock_exists, \ + patch("backend.load_dataset") as mock_load: + + mock_exists.return_value = False + mock_load.side_effect = Exception("load_dataset blocked in test") + + try: + backend._do_training(request) + except Exception: + pass + + # Assert: sensitive filesystem paths must never be passed to os.path.exists + for call_args in mock_exists.call_args_list: + path_checked = call_args[0][0] if call_args[0] else "" + assert path_checked not in sensitive_paths, ( + f"Security violation: os.path.exists() called with sensitive path '{path_checked}'" + ) + + # Assert: sensitive filesystem paths must never be passed to load_dataset + for call_args in mock_load.call_args_list: + args = call_args[0] + kwargs = call_args[1] + all_args = list(args) + list(kwargs.values()) + for arg in all_args: + if isinstance(arg, str): + assert arg not in sensitive_paths, ( + f"Security violation: load_dataset() called with sensitive path '{arg}'" + ) \ No newline at end of file