mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 13:15:51 -04:00
* feat: add fine-tuning endpoint Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(experimental): add fine-tuning endpoint and TRL support This changeset defines new gRPC signatures for fine-tuning backends, and adds the TRL backend as the initial fine-tuning engine. This implementation also supports exporting to GGUF and automatically importing it to LocalAI after fine-tuning. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * commit TRL backend, stop by killing process Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * move fine-tune to generic features Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * add evals, reorder menu Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Fix tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
"""
|
|
Test script for the TRL fine-tuning gRPC backend.
|
|
"""
|
|
import unittest
|
|
import subprocess
|
|
import time
|
|
|
|
import grpc
|
|
import backend_pb2
|
|
import backend_pb2_grpc
|
|
|
|
|
|
class TestBackendServicer(unittest.TestCase):
    """Tests for the TRL fine-tuning gRPC service.

    unittest invokes :meth:`setUp` before and :meth:`tearDown` after every
    test method, so each test runs against its own freshly started
    ``backend.py`` server process. Test methods must NOT call ``setUp`` /
    ``tearDown`` themselves. Doing so spawns a second server on the same
    address and overwrites ``self.service``, which orphans the first process
    so that it is never killed.
    """

    # Address the backend subprocess listens on and the tests connect to.
    ADDRESS = "localhost:50051"
    # Upper bound, in seconds, that setUp waits for the server to accept
    # connections before letting the test run anyway.
    STARTUP_TIMEOUT = 30

    def setUp(self):
        """Start ``backend.py`` in a subprocess and wait until it is reachable.

        Instead of sleeping a fixed amount of time, block until a gRPC
        channel to ``ADDRESS`` reports ready, or until ``STARTUP_TIMEOUT``
        seconds elapse. On timeout the test still runs; its own RPC then
        reports the failure, and ``tearDown`` still reaps the subprocess.
        """
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", self.ADDRESS]
        )
        with grpc.insecure_channel(self.ADDRESS) as channel:
            try:
                grpc.channel_ready_future(channel).result(
                    timeout=self.STARTUP_TIMEOUT
                )
            except grpc.FutureTimeoutError:
                # Deliberately not fatal here: raising from setUp would skip
                # tearDown and leak the subprocess started above.
                pass

    def tearDown(self):
        """Kill the backend subprocess and wait for it to exit."""
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """Test that the server starts and responds to health checks."""
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            try:
                response = stub.Health(backend_pb2.HealthMessage())
            except grpc.RpcError as err:
                # Only transport/RPC errors mean "not started"; a wrong reply
                # is reported by the assertion below with its real diff.
                self.fail(f"Server failed to start: {err}")
        self.assertEqual(response.message, b'OK')

    def test_list_checkpoints_empty(self):
        """Test listing checkpoints on a non-existent directory.

        The service is expected to answer with an empty checkpoint list
        rather than an RPC error.
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            try:
                response = stub.ListCheckpoints(
                    backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
                )
            except grpc.RpcError as err:
                self.fail(f"ListCheckpoints service failed: {err}")
        self.assertEqual(len(response.checkpoints), 0)
|
|
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner,
    # which collects the TestCase classes defined in this module.
    unittest.main(module='__main__')
|