Files
LocalAI/backend/python/trl/test.py
Ettore Di Giacinto d9c1db2b87 feat: add (experimental) fine-tuning support with TRL (#9088)
* feat: add fine-tuning endpoint

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(experimental): add fine-tuning endpoint and TRL support

This changeset defines new gRPC signatures for fine-tuning backends and
adds the TRL backend as the initial fine-tuning engine. This implementation
also supports exporting to GGUF and automatically importing the result into
LocalAI after fine-tuning.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* commit TRL backend, stop by killing process

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* move fine-tune to generic features

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* add evals, reorder menu

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fix tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2026-03-21 02:08:02 +01:00

59 lines
1.7 KiB
Python

"""
Test script for the TRL fine-tuning gRPC backend.
"""
import subprocess
import sys
import time
import unittest

import grpc

import backend_pb2
import backend_pb2_grpc
class TestBackendServicer(unittest.TestCase):
    """Tests for the TRL fine-tuning gRPC service.

    Each test gets a fresh ``backend.py`` server process. unittest invokes
    ``setUp`` before and ``tearDown`` after every test method automatically,
    so the test bodies must not call them again (doing so would spawn a
    second server and leak the first one).
    """

    # Address the spawned backend listens on and the tests connect to.
    ADDRESS = "localhost:50051"
    # Upper bound, in seconds, on how long to wait for the backend to start
    # accepting connections.
    STARTUP_TIMEOUT = 60.0

    def setUp(self):
        """Start ``backend.py`` in a subprocess and wait until it is reachable."""
        # sys.executable runs the backend with the same interpreter (and hence
        # the same virtualenv) as the test runner, instead of whichever
        # ``python3`` happens to be first on PATH.
        self.service = subprocess.Popen(
            [sys.executable, "backend.py", "--addr", self.ADDRESS]
        )
        try:
            self._wait_for_server()
        except BaseException:
            # unittest does not call tearDown when setUp raises, so stop the
            # process here to avoid leaking it.
            self.tearDown()
            raise

    def tearDown(self):
        """Kill the backend process and reap it."""
        self.service.kill()
        self.service.wait()

    def _wait_for_server(self):
        """Block until the backend accepts gRPC connections.

        Raises:
            RuntimeError: if the process exits before becoming reachable, or
                is not reachable within ``STARTUP_TIMEOUT`` seconds.
        """
        deadline = time.monotonic() + self.STARTUP_TIMEOUT
        while time.monotonic() < deadline:
            if self.service.poll() is not None:
                raise RuntimeError(
                    "backend.py exited during startup with code "
                    f"{self.service.returncode}"
                )
            # A fresh channel per attempt avoids waiting out gRPC's reconnect
            # backoff on a channel whose earlier connection attempts failed.
            with grpc.insecure_channel(self.ADDRESS) as channel:
                ready = grpc.channel_ready_future(channel)
                try:
                    ready.result(timeout=1)
                    return
                except grpc.FutureTimeoutError:
                    # Stop watching connectivity before the channel closes.
                    ready.cancel()
        raise RuntimeError(
            f"backend.py did not become reachable at {self.ADDRESS} "
            f"within {self.STARTUP_TIMEOUT} seconds"
        )

    def test_server_startup(self):
        """Test that the server starts and responds to health checks."""
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            try:
                response = stub.Health(backend_pb2.HealthMessage())
            except grpc.RpcError as err:
                self.fail(f"Server failed to start: {err}")
        # Kept outside the try so a wrong reply surfaces as a real assertion
        # failure rather than being reported as a startup failure.
        self.assertEqual(response.message, b'OK')

    def test_list_checkpoints_empty(self):
        """Test listing checkpoints on a non-existent directory."""
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            try:
                response = stub.ListCheckpoints(
                    backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
                )
            except grpc.RpcError as err:
                self.fail(f"ListCheckpoints service failed: {err}")
        self.assertEqual(len(response.checkpoints), 0)
if __name__ == "__main__":
    # Discover and run every TestCase in this module when run as a script.
    unittest.main()