feat: add (experimental) fine-tuning support with TRL (#9088)

* feat: add fine-tuning endpoint

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(experimental): add fine-tuning endpoint and TRL support

This changeset defines new GRPC signatues for Fine tuning backends, and
add TRL backend as initial fine-tuning engine. This implementation also
supports exporting to GGUF and automatically importing it to LocalAI
after fine-tuning.

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* commit TRL backend, stop by killing process

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* move fine-tune to generic features

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* add evals, reorder menu

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fix tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2026-03-21 02:08:02 +01:00
committed by GitHub
parent f7e3aab4fc
commit d9c1db2b87
49 changed files with 5652 additions and 110 deletions

View File

@@ -0,0 +1,26 @@
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
LLAMA_CPP_CONVERT_VERSION ?= master
.PHONY: trl
trl:
LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
.PHONY: run
run: trl
@echo "Running trl..."
bash run.sh
@echo "trl run."
.PHONY: test
test: trl
@echo "Testing trl..."
bash test.sh
@echo "trl tested."
.PHONY: protogen-clean
protogen-clean:
$(RM) backend_pb2_grpc.py backend_pb2.py
.PHONY: clean
clean: protogen-clean
rm -rf venv __pycache__

View File

@@ -0,0 +1,860 @@
#!/usr/bin/env python3
"""
TRL fine-tuning backend for LocalAI.
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
"""
import argparse
import json
import os
import queue
import signal
import sys
import threading
import time
import uuid
from concurrent import futures
import grpc
import backend_pb2
import backend_pb2_grpc
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
class ProgressCallback:
"""HuggingFace TrainerCallback that pushes progress updates to a queue."""
def __init__(self, job_id, progress_queue, total_epochs):
self.job_id = job_id
self.progress_queue = progress_queue
self.total_epochs = total_epochs
def get_callback(self):
from transformers import TrainerCallback
parent = self
class _Callback(TrainerCallback):
def __init__(self):
self._train_start_time = None
def on_train_begin(self, args, state, control, **kwargs):
self._train_start_time = time.time()
def on_log(self, args, state, control, logs=None, **kwargs):
if logs is None:
return
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
eta = 0.0
if state.global_step > 0 and total_steps > 0 and self._train_start_time:
elapsed = time.time() - self._train_start_time
remaining_steps = total_steps - state.global_step
if state.global_step > 0:
eta = remaining_steps * (elapsed / state.global_step)
extra_metrics = {}
for k, v in logs.items():
if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
extra_metrics[k] = float(v)
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(logs.get('epoch', 0)),
total_epochs=float(parent.total_epochs),
loss=float(logs.get('loss', 0)),
learning_rate=float(logs.get('learning_rate', 0)),
grad_norm=float(logs.get('grad_norm', 0)),
eval_loss=float(logs.get('eval_loss', 0)),
eta_seconds=float(eta),
progress_percent=float(progress),
status="training",
extra_metrics=extra_metrics,
)
parent.progress_queue.put(update)
def on_prediction_step(self, args, state, control, **kwargs):
"""Send periodic updates during evaluation so the UI doesn't freeze."""
if not hasattr(self, '_eval_update_counter'):
self._eval_update_counter = 0
self._eval_update_counter += 1
# Throttle: send an update every 10 prediction steps
if self._eval_update_counter % 10 != 0:
return
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(state.epoch or 0),
total_epochs=float(parent.total_epochs),
progress_percent=float(progress),
status="training",
message=f"Evaluating... (batch {self._eval_update_counter})",
)
parent.progress_queue.put(update)
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
"""Report eval results once evaluation is done."""
# Reset prediction counter for next eval round
self._eval_update_counter = 0
total_steps = state.max_steps if state.max_steps > 0 else 0
progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
eval_loss = 0.0
extra_metrics = {}
if metrics:
eval_loss = float(metrics.get('eval_loss', 0))
for k, v in metrics.items():
if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'):
extra_metrics[k] = float(v)
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=total_steps,
current_epoch=float(state.epoch or 0),
total_epochs=float(parent.total_epochs),
eval_loss=eval_loss,
progress_percent=float(progress),
status="training",
message=f"Evaluation complete at step {state.global_step}",
extra_metrics=extra_metrics,
)
parent.progress_queue.put(update)
def on_save(self, args, state, control, **kwargs):
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
status="saving",
message=f"Checkpoint saved at step {state.global_step}",
checkpoint_path=checkpoint_path,
)
parent.progress_queue.put(update)
def on_train_end(self, args, state, control, **kwargs):
update = backend_pb2.FineTuneProgressUpdate(
job_id=parent.job_id,
current_step=state.global_step,
total_steps=state.max_steps,
progress_percent=100.0,
status="completed",
message="Training completed",
)
parent.progress_queue.put(update)
return _Callback()
class ActiveJob:
"""Represents an active fine-tuning job."""
def __init__(self, job_id):
self.job_id = job_id
self.progress_queue = queue.Queue()
self.trainer = None
self.thread = None
self.model = None
self.tokenizer = None
self.error = None
self.completed = False
self.stopped = False
def _is_gated_repo_error(exc):
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
try:
from huggingface_hub.utils import GatedRepoError
if isinstance(exc, GatedRepoError):
return True
except ImportError:
pass
msg = str(exc).lower()
if "gated repo" in msg or "access to model" in msg:
return True
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
if exc.response.status_code in (401, 403):
return True
return False
class BackendServicer(backend_pb2_grpc.BackendServicer):
def __init__(self):
self.active_job = None
def Health(self, request, context):
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
def LoadModel(self, request, context):
"""Accept LoadModel — actual model loading happens in StartFineTune."""
return backend_pb2.Result(success=True, message="OK")
def StartFineTune(self, request, context):
if self.active_job is not None and not self.active_job.completed:
return backend_pb2.FineTuneJobResult(
job_id="",
success=False,
message="A fine-tuning job is already running",
)
job_id = request.job_id if request.job_id else str(uuid.uuid4())
job = ActiveJob(job_id)
self.active_job = job
# Start training in background thread
thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
job.thread = thread
thread.start()
return backend_pb2.FineTuneJobResult(
job_id=job_id,
success=True,
message="Fine-tuning job started",
)
def _run_training(self, request, job):
try:
self._do_training(request, job)
except Exception as e:
if _is_gated_repo_error(e):
msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
else:
msg = f"Training failed: {e}"
job.error = msg
job.completed = True
update = backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id,
status="failed",
message=msg,
)
job.progress_queue.put(update)
# Send sentinel
job.progress_queue.put(None)
def _do_training(self, request, job):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset
extra = dict(request.extra_options)
training_method = request.training_method or "sft"
training_type = request.training_type or "lora"
# Send loading status
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
))
# Determine device and dtype
device_map = "auto" if torch.cuda.is_available() else "cpu"
dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
# HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
# Load model
model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
if hf_token:
model_kwargs["token"] = hf_token
if extra.get("trust_remote_code", "false").lower() == "true":
model_kwargs["trust_remote_code"] = True
if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
from transformers import BitsAndBytesConfig
model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
job.model = model
job.tokenizer = tokenizer
# Apply LoRA if requested
if training_type == "lora":
from peft import LoraConfig, get_peft_model
lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
target_modules = list(request.target_modules) if request.target_modules else None
peft_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=target_modules or "all-linear",
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
# Load dataset
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="loading_dataset", message="Loading dataset",
))
dataset_split = request.dataset_split or "train"
if os.path.exists(request.dataset_source):
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
elif request.dataset_source.endswith('.csv'):
dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
else:
dataset = load_dataset(request.dataset_source, split=dataset_split)
else:
dataset = load_dataset(request.dataset_source, split=dataset_split)
# Eval dataset setup
eval_dataset = None
eval_strategy = extra.get("eval_strategy", "steps")
eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))
if eval_strategy != "no":
eval_split = extra.get("eval_split")
eval_dataset_source = extra.get("eval_dataset_source")
if eval_split:
# Load a specific split as eval dataset
if os.path.exists(request.dataset_source):
if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
elif request.dataset_source.endswith('.csv'):
eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
else:
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
else:
eval_dataset = load_dataset(request.dataset_source, split=eval_split)
elif eval_dataset_source:
# Load eval dataset from a separate source
eval_dataset = load_dataset(eval_dataset_source, split="train")
else:
# Auto-split from training set
eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
split = dataset.train_test_split(test_size=eval_split_ratio)
dataset = split["train"]
eval_dataset = split["test"]
if eval_strategy == "no":
eval_dataset = None
# Training config
output_dir = request.output_dir or f"./output-{job.job_id}"
num_epochs = request.num_epochs if request.num_epochs > 0 else 3
batch_size = request.batch_size if request.batch_size > 0 else 2
lr = request.learning_rate if request.learning_rate > 0 else 2e-4
grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
max_steps = request.max_steps if request.max_steps > 0 else -1
save_steps = request.save_steps if request.save_steps > 0 else 500
seed = request.seed if request.seed > 0 else 3407
optimizer = request.optimizer or "adamw_torch"
# Checkpoint save controls
save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited
save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no
# CPU vs GPU training args (can be overridden via extra_options)
use_cpu = not torch.cuda.is_available()
common_train_kwargs = {}
if use_cpu:
common_train_kwargs["use_cpu"] = True
common_train_kwargs["fp16"] = False
common_train_kwargs["bf16"] = False
common_train_kwargs["gradient_checkpointing"] = False
else:
common_train_kwargs["bf16"] = True
common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing
# Allow extra_options to override training kwargs
for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
if flag in extra:
common_train_kwargs[flag] = extra[flag].lower() == "true"
# Create progress callback
progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)
# Build save kwargs (shared across all methods)
_save_kwargs = {}
if save_strategy == "steps" and save_steps > 0:
_save_kwargs["save_steps"] = save_steps
_save_kwargs["save_strategy"] = "steps"
elif save_strategy == "epoch":
_save_kwargs["save_strategy"] = "epoch"
elif save_strategy == "no":
_save_kwargs["save_strategy"] = "no"
else:
_save_kwargs["save_steps"] = save_steps
_save_kwargs["save_strategy"] = "steps"
if save_total_limit:
_save_kwargs["save_total_limit"] = save_total_limit
# Eval kwargs
_eval_kwargs = {}
if eval_dataset is not None:
_eval_kwargs["eval_strategy"] = eval_strategy
_eval_kwargs["eval_steps"] = eval_steps
# Common training arguments shared by all methods
_common_args = dict(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
learning_rate=lr,
gradient_accumulation_steps=grad_accum,
warmup_steps=warmup_steps,
weight_decay=weight_decay,
max_steps=max_steps,
seed=seed,
optim=optimizer,
logging_steps=1,
report_to="none",
**_save_kwargs,
**common_train_kwargs,
**_eval_kwargs,
)
# Select trainer based on training method
if training_method == "sft":
from trl import SFTTrainer, SFTConfig
max_length = int(extra.get("max_seq_length", "512"))
packing = extra.get("packing", "false").lower() == "true"
training_args = SFTConfig(
max_length=max_length,
packing=packing,
**_common_args,
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "dpo":
from trl import DPOTrainer, DPOConfig
beta = float(extra.get("beta", "0.1"))
loss_type = extra.get("loss_type", "sigmoid")
max_length = int(extra.get("max_length", "512"))
training_args = DPOConfig(
beta=beta,
loss_type=loss_type,
max_length=max_length,
**_common_args,
)
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "grpo":
from trl import GRPOTrainer, GRPOConfig
num_generations = int(extra.get("num_generations", "4"))
max_completion_length = int(extra.get("max_completion_length", "256"))
training_args = GRPOConfig(
num_generations=num_generations,
max_completion_length=max_completion_length,
**_common_args,
)
# GRPO requires reward functions passed via extra_options as a JSON list
from reward_functions import build_reward_functions
reward_funcs = []
if extra.get("reward_funcs"):
reward_funcs = build_reward_functions(extra["reward_funcs"])
if not reward_funcs:
raise ValueError(
"GRPO requires at least one reward function. "
"Specify reward_functions in the request or "
"reward_funcs in extra_options."
)
trainer = GRPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
reward_funcs=reward_funcs,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "orpo":
from trl import ORPOTrainer, ORPOConfig
beta = float(extra.get("beta", "0.1"))
max_length = int(extra.get("max_length", "512"))
training_args = ORPOConfig(
beta=beta,
max_length=max_length,
**_common_args,
)
trainer = ORPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "kto":
from trl import KTOTrainer, KTOConfig
beta = float(extra.get("beta", "0.1"))
max_length = int(extra.get("max_length", "512"))
training_args = KTOConfig(
beta=beta,
max_length=max_length,
**_common_args,
)
trainer = KTOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "rloo":
from trl import RLOOTrainer, RLOOConfig
num_generations = int(extra.get("num_generations", "4"))
max_completion_length = int(extra.get("max_completion_length", "256"))
training_args = RLOOConfig(
num_generations=num_generations,
max_new_tokens=max_completion_length,
**_common_args,
)
trainer = RLOOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
elif training_method == "reward":
from trl import RewardTrainer, RewardConfig
max_length = int(extra.get("max_length", "512"))
training_args = RewardConfig(
max_length=max_length,
**_common_args,
)
trainer = RewardTrainer(
model=model,
args=training_args,
train_dataset=dataset,
eval_dataset=eval_dataset,
processing_class=tokenizer,
callbacks=[progress_cb.get_callback()],
)
else:
raise ValueError(f"Unsupported training method: {training_method}. "
"Supported: sft, dpo, grpo, orpo, kto, rloo, reward")
job.trainer = trainer
# Start training
job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
job_id=job.job_id, status="training", message="Training started",
))
resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
trainer.train(resume_from_checkpoint=resume_ckpt)
# Save final model
trainer.save_model(output_dir)
if tokenizer:
tokenizer.save_pretrained(output_dir)
job.completed = True
# Sentinel to signal stream end
job.progress_queue.put(None)
def FineTuneProgress(self, request, context):
if self.active_job is None or self.active_job.job_id != request.job_id:
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Job {request.job_id} not found")
return
job = self.active_job
while True:
try:
update = job.progress_queue.get(timeout=1.0)
if update is None:
break
yield update
if update.status in ("completed", "failed", "stopped"):
break
except queue.Empty:
if job.completed or job.stopped:
break
if not context.is_active():
break
continue
def StopFineTune(self, request, context):
# Stopping is handled by killing the process from Go via ShutdownModel.
return backend_pb2.Result(success=True, message="OK")
def ListCheckpoints(self, request, context):
output_dir = request.output_dir
if not os.path.isdir(output_dir):
return backend_pb2.ListCheckpointsResponse(checkpoints=[])
checkpoints = []
for entry in sorted(os.listdir(output_dir)):
if entry.startswith("checkpoint-"):
ckpt_path = os.path.join(output_dir, entry)
if not os.path.isdir(ckpt_path):
continue
step = 0
try:
step = int(entry.split("-")[1])
except (IndexError, ValueError):
pass
# Try to read trainer_state.json for metadata
loss = 0.0
epoch = 0.0
state_file = os.path.join(ckpt_path, "trainer_state.json")
if os.path.exists(state_file):
try:
with open(state_file) as f:
state = json.load(f)
if state.get("log_history"):
last_log = state["log_history"][-1]
loss = last_log.get("loss", 0.0)
epoch = last_log.get("epoch", 0.0)
except Exception:
pass
created_at = time.strftime(
"%Y-%m-%dT%H:%M:%SZ",
time.gmtime(os.path.getmtime(ckpt_path)),
)
checkpoints.append(backend_pb2.CheckpointInfo(
path=ckpt_path,
step=step,
epoch=float(epoch),
loss=float(loss),
created_at=created_at,
))
return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
def ExportModel(self, request, context):
export_format = request.export_format or "lora"
output_path = request.output_path
checkpoint_path = request.checkpoint_path
# Extract HF token for gated model access
extra = dict(request.extra_options) if request.extra_options else {}
hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
if not checkpoint_path or not os.path.isdir(checkpoint_path):
return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")
os.makedirs(output_path, exist_ok=True)
try:
if export_format == "lora":
# Just copy the adapter files
import shutil
for f in os.listdir(checkpoint_path):
src = os.path.join(checkpoint_path, f)
dst = os.path.join(output_path, f)
if os.path.isfile(src):
shutil.copy2(src, dst)
elif export_format in ("merged_16bit", "merged_4bit"):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
base_model_name = request.model
if not base_model_name:
return backend_pb2.Result(success=False, message="Base model name required for merge export")
dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
model = PeftModel.from_pretrained(base_model, checkpoint_path)
merged = model.merge_and_unload()
merged.save_pretrained(output_path)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
tokenizer.save_pretrained(output_path)
elif export_format == "gguf":
import torch
import subprocess
import shutil
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
base_model_name = request.model
if not base_model_name:
return backend_pb2.Result(success=False, message="Base model name required for GGUF export")
# Step 1: Merge LoRA into base model
merge_dir = os.path.join(output_path, "_hf_merged")
os.makedirs(merge_dir, exist_ok=True)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
model = PeftModel.from_pretrained(base_model, checkpoint_path)
merged = model.merge_and_unload()
merged.save_pretrained(merge_dir)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
tokenizer.save_pretrained(merge_dir)
# Ensure tokenizer.model (SentencePiece) is present in merge_dir.
# Gemma models need this file for GGUF conversion to use the
# SentencePiece path; without it, the script falls back to BPE
# handling which fails on unrecognized pre-tokenizer hashes.
sp_model_path = os.path.join(merge_dir, "tokenizer.model")
if not os.path.exists(sp_model_path):
sp_copied = False
# Method 1: Load the slow tokenizer which keeps the SP model file
try:
slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
import shutil as _shutil
_shutil.copy2(slow_tok.vocab_file, sp_model_path)
sp_copied = True
print(f"Copied tokenizer.model from slow tokenizer cache")
except Exception as e:
print(f"Slow tokenizer method failed: {e}")
# Method 2: Download from HF hub
if not sp_copied:
try:
from huggingface_hub import hf_hub_download
cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
import shutil as _shutil
_shutil.copy2(cached_sp, sp_model_path)
sp_copied = True
print(f"Copied tokenizer.model from HF hub")
except Exception as e:
print(f"HF hub download method failed: {e}")
if not sp_copied:
print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
"GGUF conversion may fail for SentencePiece models.")
# Free GPU memory before conversion
del merged, model, base_model
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Step 2: Convert to GGUF using convert_hf_to_gguf.py
quant = request.quantization_method or "auto"
outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
outtype = outtype_map.get(quant, "f16")
gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
gguf_path = os.path.join(output_path, gguf_filename)
script_dir = os.path.dirname(os.path.abspath(__file__))
convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
if not os.path.exists(convert_script):
return backend_pb2.Result(success=False,
message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")
# Log merge_dir contents for debugging conversion issues
merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
print(f"Merge dir contents: {merge_files}", flush=True)
env = os.environ.copy()
env["NO_LOCAL_GGUF"] = "1"
cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
if conv_result.returncode != 0:
diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
return backend_pb2.Result(success=False,
message=f"GGUF conversion failed: {diag}")
# Clean up intermediate merged model
shutil.rmtree(merge_dir, ignore_errors=True)
else:
return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")
except Exception as e:
if _is_gated_repo_error(e):
return backend_pb2.Result(success=False,
message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
"Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
return backend_pb2.Result(success=False, message=f"Export failed: {e}")
return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
def serve(address):
server = grpc.server(
futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
options=[
('grpc.max_message_length', 50 * 1024 * 1024),
('grpc.max_send_message_length', 50 * 1024 * 1024),
('grpc.max_receive_message_length', 50 * 1024 * 1024),
],
)
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)
# Handle graceful shutdown
def stop(signum, frame):
server.stop(0)
sys.exit(0)
signal.signal(signal.SIGTERM, stop)
signal.signal(signal.SIGINT, stop)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
args = parser.parse_args()
serve(args.addr)

View File

@@ -0,0 +1,37 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements
# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
curl -L --fail --retry 3 \
"https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
-o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
fi
# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
pip install "gguf>=0.16.0"
}
else
uv pip install "${GGUF_PIP_SPEC}" || {
echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
uv pip install "gguf>=0.16.0"
}
fi

View File

@@ -0,0 +1,9 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece

View File

@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes

View File

@@ -0,0 +1,9 @@
torch==2.10.0
trl
peft
datasets>=3.0.0
transformers>=4.56.2
accelerate>=1.4.0
huggingface-hub>=1.3.0
sentencepiece
bitsandbytes

View File

@@ -0,0 +1,3 @@
grpcio==1.78.1
protobuf
certifi

View File

@@ -0,0 +1,236 @@
"""
Built-in reward functions and inline function compiler for GRPO training.
All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
"""
import json
import re
import math
import string
import functools
# ---------------------------------------------------------------------------
# Built-in reward functions
# ---------------------------------------------------------------------------
def format_reward(completions, **kwargs):
"""Checks for <think>...</think> followed by an answer. Returns 1.0 or 0.0."""
pattern = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL)
return [1.0 if pattern.search(c) else 0.0 for c in completions]
def reasoning_accuracy_reward(completions, **kwargs):
"""Extracts <answer>...</answer> content and compares to the expected answer."""
answers = kwargs.get("answer", [])
if not answers:
return [0.0] * len(completions)
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
scores = []
for i, c in enumerate(completions):
expected = answers[i] if i < len(answers) else ""
match = pattern.search(c)
if match:
extracted = match.group(1).strip()
scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0)
else:
scores.append(0.0)
return scores
def length_reward(completions, target_length=200, **kwargs):
"""Score based on proximity to target_length. Returns [0, 1]."""
scores = []
for c in completions:
length = len(c)
if target_length <= 0:
scores.append(0.0)
else:
diff = abs(length - target_length) / target_length
scores.append(max(0.0, 1.0 - diff))
return scores
def xml_tag_reward(completions, **kwargs):
"""Scores properly opened/closed XML tags (<think>, <answer>)."""
tags = ["think", "answer"]
scores = []
for c in completions:
tag_score = 0.0
for tag in tags:
if f"<{tag}>" in c and f"</{tag}>" in c:
tag_score += 0.5
scores.append(min(tag_score, 1.0))
return scores
def no_repetition_reward(completions, n=4, **kwargs):
"""Penalizes n-gram repetition. Returns [0, 1]."""
scores = []
for c in completions:
words = c.split()
if len(words) < n:
scores.append(1.0)
continue
ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
unique = len(set(ngrams))
total = len(ngrams)
scores.append(unique / total if total > 0 else 1.0)
return scores
def code_execution_reward(completions, **kwargs):
"""Checks Python code block syntax validity via compile(). Returns 1.0 or 0.0."""
pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
scores = []
for c in completions:
match = pattern.search(c)
if not match:
scores.append(0.0)
continue
code = match.group(1)
try:
compile(code, "<inline>", "exec")
scores.append(1.0)
except SyntaxError:
scores.append(0.0)
return scores
# ---------------------------------------------------------------------------
# Registry
# ---------------------------------------------------------------------------
BUILTIN_REGISTRY = {
"format_reward": format_reward,
"reasoning_accuracy_reward": reasoning_accuracy_reward,
"length_reward": length_reward,
"xml_tag_reward": xml_tag_reward,
"no_repetition_reward": no_repetition_reward,
"code_execution_reward": code_execution_reward,
}
# ---------------------------------------------------------------------------
# Inline function compiler
# ---------------------------------------------------------------------------
_SAFE_BUILTINS = {
"len": len, "int": int, "float": float, "str": str, "bool": bool,
"list": list, "dict": dict, "tuple": tuple, "set": set,
"range": range, "enumerate": enumerate, "zip": zip,
"map": map, "filter": filter, "sorted": sorted,
"min": min, "max": max, "sum": sum, "abs": abs, "round": round,
"any": any, "all": all, "isinstance": isinstance, "type": type,
"print": print, "True": True, "False": False, "None": None,
"ValueError": ValueError, "TypeError": TypeError,
"KeyError": KeyError, "IndexError": IndexError,
}
def compile_inline_reward(name, code):
"""Compile user-provided code into a reward function.
The code should be the body of a function that receives
`completions` (list[str]) and `**kwargs`, and returns list[float].
Available modules: re, math, json, string.
"""
func_source = (
f"def _user_reward_{name}(completions, **kwargs):\n"
+ "\n".join(f" {line}" for line in code.splitlines())
)
restricted_globals = {
"__builtins__": _SAFE_BUILTINS,
"re": re,
"math": math,
"json": json,
"string": string,
}
try:
compiled = compile(func_source, f"<inline-reward-{name}>", "exec")
except SyntaxError as e:
raise ValueError(f"Syntax error in inline reward function '{name}': {e}")
exec(compiled, restricted_globals)
func = restricted_globals[f"_user_reward_{name}"]
# Validate with a quick smoke test
try:
result = func(["test"], answer=["test"])
if not isinstance(result, list):
raise ValueError(
f"Inline reward function '{name}' must return a list, got {type(result).__name__}"
)
except Exception as e:
if "must return a list" in str(e):
raise
# Other errors during smoke test are acceptable (e.g. missing kwargs)
pass
return func
# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------
def build_reward_functions(specs_json):
"""Parse a JSON list of reward function specs and return a list of callables.
Each spec is a dict with:
- type: "builtin" or "inline"
- name: function name
- code: (inline only) Python function body
- params: (optional) dict of string params applied via functools.partial
"""
if isinstance(specs_json, str):
specs = json.loads(specs_json)
else:
specs = specs_json
if not isinstance(specs, list):
raise ValueError("reward_funcs must be a JSON array of reward function specs")
reward_funcs = []
for spec in specs:
spec_type = spec.get("type", "builtin")
name = spec.get("name", "")
params = spec.get("params", {})
if spec_type == "builtin":
if name not in BUILTIN_REGISTRY:
available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
raise ValueError(
f"Unknown builtin reward function '{name}'. Available: {available}"
)
func = BUILTIN_REGISTRY[name]
if params:
# Convert string params to appropriate types
typed_params = {}
for k, v in params.items():
try:
typed_params[k] = int(v)
except (ValueError, TypeError):
try:
typed_params[k] = float(v)
except (ValueError, TypeError):
typed_params[k] = v
func = functools.partial(func, **typed_params)
reward_funcs.append(func)
elif spec_type == "inline":
code = spec.get("code", "")
if not code.strip():
raise ValueError(f"Inline reward function '{name}' has no code")
func = compile_inline_reward(name, code)
reward_funcs.append(func)
else:
raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")
return reward_funcs

10
backend/python/trl/run.sh Normal file
View File

@@ -0,0 +1,10 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
startBackend $@

View File

@@ -0,0 +1,58 @@
"""
Test script for the TRL fine-tuning gRPC backend.
"""
import unittest
import subprocess
import time
import grpc
import backend_pb2
import backend_pb2_grpc
class TestBackendServicer(unittest.TestCase):
"""Tests for the TRL fine-tuning gRPC service."""
def setUp(self):
self.service = subprocess.Popen(
["python3", "backend.py", "--addr", "localhost:50051"]
)
time.sleep(10)
def tearDown(self):
self.service.kill()
self.service.wait()
def test_server_startup(self):
"""Test that the server starts and responds to health checks."""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.Health(backend_pb2.HealthMessage())
self.assertEqual(response.message, b'OK')
except Exception as err:
print(err)
self.fail("Server failed to start")
finally:
self.tearDown()
def test_list_checkpoints_empty(self):
"""Test listing checkpoints on a non-existent directory."""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
response = stub.ListCheckpoints(
backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
)
self.assertEqual(len(response.checkpoints), 0)
except Exception as err:
print(err)
self.fail("ListCheckpoints service failed")
finally:
self.tearDown()
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,11 @@
#!/bin/bash
set -e
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
source $backend_dir/common/libbackend.sh
else
source $backend_dir/../common/libbackend.sh
fi
runUnittests