mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 13:15:51 -04:00
feat: add (experimental) fine-tuning support with TRL (#9088)
* feat: add fine-tuning endpoint Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(experimental): add fine-tuning endpoint and TRL support This changeset defines new GRPC signatues for Fine tuning backends, and add TRL backend as initial fine-tuning engine. This implementation also supports exporting to GGUF and automatically importing it to LocalAI after fine-tuning. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * commit TRL backend, stop by killing process Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * move fine-tune to generic features Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * add evals, reorder menu Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Fix tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
f7e3aab4fc
commit
d9c1db2b87
58
backend/python/trl/test.py
Normal file
58
backend/python/trl/test.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Test script for the TRL fine-tuning gRPC backend.
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
"""Tests for the TRL fine-tuning gRPC service."""
|
||||
|
||||
def setUp(self):
|
||||
self.service = subprocess.Popen(
|
||||
["python3", "backend.py", "--addr", "localhost:50051"]
|
||||
)
|
||||
time.sleep(10)
|
||||
|
||||
def tearDown(self):
|
||||
self.service.kill()
|
||||
self.service.wait()
|
||||
|
||||
def test_server_startup(self):
|
||||
"""Test that the server starts and responds to health checks."""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.Health(backend_pb2.HealthMessage())
|
||||
self.assertEqual(response.message, b'OK')
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Server failed to start")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_list_checkpoints_empty(self):
|
||||
"""Test listing checkpoints on a non-existent directory."""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.ListCheckpoints(
|
||||
backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
|
||||
)
|
||||
self.assertEqual(len(response.checkpoints), 0)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("ListCheckpoints service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user