mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-31 21:25:59 -04:00
feat: add (experimental) fine-tuning support with TRL (#9088)
* feat: add fine-tuning endpoint Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * feat(experimental): add fine-tuning endpoint and TRL support This changeset defines new GRPC signatures for fine-tuning backends, and adds a TRL backend as the initial fine-tuning engine. This implementation also supports exporting to GGUF and automatically importing it to LocalAI after fine-tuning. Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * commit TRL backend, stop by killing process Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * move fine-tune to generic features Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * add evals, reorder menu Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * Fix tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
f7e3aab4fc
commit
d9c1db2b87
26
backend/python/trl/Makefile
Normal file
26
backend/python/trl/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
LLAMA_CPP_CONVERT_VERSION ?= master

# Install the backend (delegates to install.sh, forwarding the llama.cpp version).
.PHONY: trl
trl:
	LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh

# Install, then start the backend.
.PHONY: run
run: trl
	@echo "Running trl..."
	bash run.sh
	@echo "trl run."

# Install, then run the backend test suite.
.PHONY: test
test: trl
	@echo "Testing trl..."
	bash test.sh
	@echo "trl tested."

# Remove the generated protobuf/gRPC stubs.
.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

# Full cleanup: generated stubs plus the virtualenv and bytecode caches.
.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
|
||||
860
backend/python/trl/backend.py
Normal file
860
backend/python/trl/backend.py
Normal file
@@ -0,0 +1,860 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
TRL fine-tuning backend for LocalAI.
|
||||
|
||||
Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
|
||||
using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent import futures
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
# Sleep granularity for the keep-alive loop in serve().
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# gRPC thread-pool size; overridable via the PYTHON_GRPC_MAX_WORKERS env var.
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '4'))
|
||||
|
||||
|
||||
class ProgressCallback:
    """HuggingFace TrainerCallback that pushes progress updates to a queue.

    The gRPC streaming handler (FineTuneProgress) consumes the queue on a
    different thread, so all hooks communicate exclusively through
    progress_queue.put() — they never touch gRPC state directly.
    """

    def __init__(self, job_id, progress_queue, total_epochs):
        # job_id: identifier echoed into every FineTuneProgressUpdate.
        self.job_id = job_id
        # progress_queue: queue.Queue shared with the streaming RPC consumer.
        self.progress_queue = progress_queue
        # total_epochs: configured epoch count, reported as-is in updates.
        self.total_epochs = total_epochs

    def get_callback(self):
        """Build and return the actual TrainerCallback instance.

        The transformers import is deferred so this module can be imported
        without transformers installed; the inner class closes over `parent`
        to reach the queue and job metadata.
        """
        from transformers import TrainerCallback

        parent = self

        class _Callback(TrainerCallback):
            def __init__(self):
                # Wall-clock start of training; set in on_train_begin and
                # used for the ETA estimate in on_log.
                self._train_start_time = None

            def on_train_begin(self, args, state, control, **kwargs):
                self._train_start_time = time.time()

            def on_log(self, args, state, control, logs=None, **kwargs):
                # Emitted every logging step; forwards loss/lr/eta metrics.
                if logs is None:
                    return
                total_steps = state.max_steps if state.max_steps > 0 else 0
                progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
                eta = 0.0
                if state.global_step > 0 and total_steps > 0 and self._train_start_time:
                    # Linear extrapolation: remaining steps * average step time.
                    elapsed = time.time() - self._train_start_time
                    remaining_steps = total_steps - state.global_step
                    if state.global_step > 0:
                        eta = remaining_steps * (elapsed / state.global_step)

                # Forward any other numeric log entries (method-specific
                # metrics) that aren't already first-class proto fields.
                extra_metrics = {}
                for k, v in logs.items():
                    if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
                        extra_metrics[k] = float(v)

                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    total_steps=total_steps,
                    current_epoch=float(logs.get('epoch', 0)),
                    total_epochs=float(parent.total_epochs),
                    loss=float(logs.get('loss', 0)),
                    learning_rate=float(logs.get('learning_rate', 0)),
                    grad_norm=float(logs.get('grad_norm', 0)),
                    eval_loss=float(logs.get('eval_loss', 0)),
                    eta_seconds=float(eta),
                    progress_percent=float(progress),
                    status="training",
                    extra_metrics=extra_metrics,
                )
                parent.progress_queue.put(update)

            def on_prediction_step(self, args, state, control, **kwargs):
                """Send periodic updates during evaluation so the UI doesn't freeze."""
                # Counter is created lazily because this hook can fire before
                # on_evaluate ever initialized it.
                if not hasattr(self, '_eval_update_counter'):
                    self._eval_update_counter = 0
                self._eval_update_counter += 1
                # Throttle: send an update every 10 prediction steps
                if self._eval_update_counter % 10 != 0:
                    return
                total_steps = state.max_steps if state.max_steps > 0 else 0
                progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    total_steps=total_steps,
                    current_epoch=float(state.epoch or 0),
                    total_epochs=float(parent.total_epochs),
                    progress_percent=float(progress),
                    status="training",
                    message=f"Evaluating... (batch {self._eval_update_counter})",
                )
                parent.progress_queue.put(update)

            def on_evaluate(self, args, state, control, metrics=None, **kwargs):
                """Report eval results once evaluation is done."""
                # Reset prediction counter for next eval round
                self._eval_update_counter = 0

                total_steps = state.max_steps if state.max_steps > 0 else 0
                progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0

                eval_loss = 0.0
                extra_metrics = {}
                if metrics:
                    eval_loss = float(metrics.get('eval_loss', 0))
                    # Forward the remaining numeric eval metrics verbatim.
                    for k, v in metrics.items():
                        if isinstance(v, (int, float)) and k not in ('eval_loss', 'epoch'):
                            extra_metrics[k] = float(v)

                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    total_steps=total_steps,
                    current_epoch=float(state.epoch or 0),
                    total_epochs=float(parent.total_epochs),
                    eval_loss=eval_loss,
                    progress_percent=float(progress),
                    status="training",
                    message=f"Evaluation complete at step {state.global_step}",
                    extra_metrics=extra_metrics,
                )
                parent.progress_queue.put(update)

            def on_save(self, args, state, control, **kwargs):
                # Announce checkpoint location; path mirrors the Trainer's
                # "checkpoint-<step>" naming convention.
                checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    status="saving",
                    message=f"Checkpoint saved at step {state.global_step}",
                    checkpoint_path=checkpoint_path,
                )
                parent.progress_queue.put(update)

            def on_train_end(self, args, state, control, **kwargs):
                # Final "completed" update; consumers treat this status as terminal.
                update = backend_pb2.FineTuneProgressUpdate(
                    job_id=parent.job_id,
                    current_step=state.global_step,
                    total_steps=state.max_steps,
                    progress_percent=100.0,
                    status="completed",
                    message="Training completed",
                )
                parent.progress_queue.put(update)

        return _Callback()
|
||||
|
||||
|
||||
class ActiveJob:
    """Book-keeping for a single in-flight fine-tuning job."""

    def __init__(self, job_id):
        # Identity plus the channel used to stream progress to the gRPC caller.
        self.job_id = job_id
        self.progress_queue = queue.Queue()
        # Training artifacts, populated lazily once the worker thread runs.
        self.trainer = None
        self.model = None
        self.tokenizer = None
        self.thread = None
        # Terminal state: error description and completion/stop flags.
        self.error = None
        self.completed = False
        self.stopped = False
|
||||
|
||||
|
||||
def _is_gated_repo_error(exc):
|
||||
"""Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
|
||||
try:
|
||||
from huggingface_hub.utils import GatedRepoError
|
||||
if isinstance(exc, GatedRepoError):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
msg = str(exc).lower()
|
||||
if "gated repo" in msg or "access to model" in msg:
|
||||
return True
|
||||
if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
|
||||
if exc.response.status_code in (401, 403):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer exposing the fine-tuning API.

    Tracks at most ONE active fine-tuning job per backend process; training
    itself runs on a daemon thread so RPCs stay responsive.
    """

    def __init__(self):
        # The single job this process is allowed to run; None when idle.
        self.active_job = None

    def Health(self, request, context):
        """Liveness probe."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """Accept LoadModel — actual model loading happens in StartFineTune."""
        return backend_pb2.Result(success=True, message="OK")

    def StartFineTune(self, request, context):
        """Kick off a fine-tuning job on a background thread.

        Rejects the request if another job is still running; otherwise
        returns immediately with the (possibly generated) job id.
        """
        if self.active_job is not None and not self.active_job.completed:
            return backend_pb2.FineTuneJobResult(
                job_id="",
                success=False,
                message="A fine-tuning job is already running",
            )

        # Honor a caller-supplied job id; otherwise mint a fresh UUID.
        job_id = request.job_id if request.job_id else str(uuid.uuid4())
        job = ActiveJob(job_id)
        self.active_job = job

        # Start training in background thread
        thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
        job.thread = thread
        thread.start()

        return backend_pb2.FineTuneJobResult(
            job_id=job_id,
            success=True,
            message="Fine-tuning job started",
        )
|
||||
|
||||
    def _run_training(self, request, job):
        """Thread target: run training and convert failures into a 'failed' update.

        On success, _do_training reports progress and pushes the final None
        sentinel itself; this wrapper only handles the error path.
        """
        try:
            self._do_training(request, job)
        except Exception as e:
            # Gated-repo failures get an actionable message instead of a raw error.
            if _is_gated_repo_error(e):
                msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
                       "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
            else:
                msg = f"Training failed: {e}"
            job.error = msg
            job.completed = True
            update = backend_pb2.FineTuneProgressUpdate(
                job_id=job.job_id,
                status="failed",
                message=msg,
            )
            job.progress_queue.put(update)
            # Send sentinel
            job.progress_queue.put(None)
|
||||
|
||||
    def _do_training(self, request, job):
        """Perform the actual fine-tuning run for *job*.

        Pipeline: load model/tokenizer → optionally wrap with a LoRA adapter →
        load train/eval datasets → build the TRL trainer matching
        request.training_method (sft/dpo/grpo/orpo/kto/rloo/reward) → train →
        save the final model. Progress flows through job.progress_queue; a
        None sentinel marks end-of-stream on success. Raises on any failure —
        _run_training translates exceptions into a 'failed' update.
        """
        # Heavy imports deferred so the gRPC process starts fast.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from datasets import load_dataset, Dataset

        # extra_options is a proto map of string->string; all overrides below
        # are therefore parsed from strings.
        extra = dict(request.extra_options)
        training_method = request.training_method or "sft"
        training_type = request.training_type or "lora"

        # Send loading status
        job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
        ))

        # Determine device and dtype (CPU path forces float32; GPU uses bf16)
        device_map = "auto" if torch.cuda.is_available() else "cpu"
        dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16

        # HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
        hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")

        # Load model
        model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
        if hf_token:
            model_kwargs["token"] = hf_token
        if extra.get("trust_remote_code", "false").lower() == "true":
            model_kwargs["trust_remote_code"] = True
        # 4-bit quantized loading only makes sense on GPU.
        if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
            from transformers import BitsAndBytesConfig
            model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)

        model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
        if tokenizer.pad_token is None:
            # Many causal-LM tokenizers ship without a pad token.
            tokenizer.pad_token = tokenizer.eos_token

        job.model = model
        job.tokenizer = tokenizer

        # Apply LoRA if requested
        if training_type == "lora":
            from peft import LoraConfig, get_peft_model
            # Proto scalar fields default to 0 when unset; treat <=0 as "use default".
            lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
            lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
            lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0

            target_modules = list(request.target_modules) if request.target_modules else None
            peft_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=target_modules or "all-linear",
                bias="none",
                task_type="CAUSAL_LM",
            )
            model = get_peft_model(model, peft_config)

        # Load dataset
        job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id, status="loading_dataset", message="Loading dataset",
        ))

        dataset_split = request.dataset_split or "train"
        # Local files route through the json/csv loaders; anything else is
        # treated as a Hub dataset id.
        if os.path.exists(request.dataset_source):
            if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
                dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
            elif request.dataset_source.endswith('.csv'):
                dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
            else:
                dataset = load_dataset(request.dataset_source, split=dataset_split)
        else:
            dataset = load_dataset(request.dataset_source, split=dataset_split)

        # Eval dataset setup
        eval_dataset = None
        eval_strategy = extra.get("eval_strategy", "steps")
        # Default eval cadence follows the checkpoint cadence.
        eval_steps = int(extra.get("eval_steps", str(request.save_steps if request.save_steps > 0 else 500)))

        if eval_strategy != "no":
            eval_split = extra.get("eval_split")
            eval_dataset_source = extra.get("eval_dataset_source")
            if eval_split:
                # Load a specific split as eval dataset
                if os.path.exists(request.dataset_source):
                    if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
                        eval_dataset = load_dataset("json", data_files=request.dataset_source, split=eval_split)
                    elif request.dataset_source.endswith('.csv'):
                        eval_dataset = load_dataset("csv", data_files=request.dataset_source, split=eval_split)
                    else:
                        eval_dataset = load_dataset(request.dataset_source, split=eval_split)
                else:
                    eval_dataset = load_dataset(request.dataset_source, split=eval_split)
            elif eval_dataset_source:
                # Load eval dataset from a separate source
                eval_dataset = load_dataset(eval_dataset_source, split="train")
            else:
                # Auto-split from training set
                eval_split_ratio = float(extra.get("eval_split_ratio", "0.1"))
                split = dataset.train_test_split(test_size=eval_split_ratio)
                dataset = split["train"]
                eval_dataset = split["test"]

        if eval_strategy == "no":
            eval_dataset = None

        # Training config — unset (0/empty) proto fields fall back to defaults.
        output_dir = request.output_dir or f"./output-{job.job_id}"
        num_epochs = request.num_epochs if request.num_epochs > 0 else 3
        batch_size = request.batch_size if request.batch_size > 0 else 2
        lr = request.learning_rate if request.learning_rate > 0 else 2e-4
        grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
        warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
        weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
        max_steps = request.max_steps if request.max_steps > 0 else -1
        save_steps = request.save_steps if request.save_steps > 0 else 500
        seed = request.seed if request.seed > 0 else 3407
        optimizer = request.optimizer or "adamw_torch"

        # Checkpoint save controls
        save_total_limit = int(extra.get("save_total_limit", "0")) or None  # 0 = unlimited
        save_strategy = extra.get("save_strategy", "steps")  # steps, epoch, no

        # CPU vs GPU training args (can be overridden via extra_options)
        use_cpu = not torch.cuda.is_available()
        common_train_kwargs = {}
        if use_cpu:
            # Mixed precision and gradient checkpointing are GPU-only here.
            common_train_kwargs["use_cpu"] = True
            common_train_kwargs["fp16"] = False
            common_train_kwargs["bf16"] = False
            common_train_kwargs["gradient_checkpointing"] = False
        else:
            common_train_kwargs["bf16"] = True
            common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing

        # Allow extra_options to override training kwargs
        for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
            if flag in extra:
                common_train_kwargs[flag] = extra[flag].lower() == "true"

        # Create progress callback
        progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)

        # Build save kwargs (shared across all methods)
        _save_kwargs = {}
        if save_strategy == "steps" and save_steps > 0:
            _save_kwargs["save_steps"] = save_steps
            _save_kwargs["save_strategy"] = "steps"
        elif save_strategy == "epoch":
            _save_kwargs["save_strategy"] = "epoch"
        elif save_strategy == "no":
            _save_kwargs["save_strategy"] = "no"
        else:
            # Unrecognized strategy string: fall back to step-based saving.
            _save_kwargs["save_steps"] = save_steps
            _save_kwargs["save_strategy"] = "steps"
        if save_total_limit:
            _save_kwargs["save_total_limit"] = save_total_limit

        # Eval kwargs
        _eval_kwargs = {}
        if eval_dataset is not None:
            _eval_kwargs["eval_strategy"] = eval_strategy
            _eval_kwargs["eval_steps"] = eval_steps

        # Common training arguments shared by all methods
        _common_args = dict(
            output_dir=output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=lr,
            gradient_accumulation_steps=grad_accum,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            max_steps=max_steps,
            seed=seed,
            optim=optimizer,
            logging_steps=1,
            report_to="none",
            **_save_kwargs,
            **common_train_kwargs,
            **_eval_kwargs,
        )

        # Select trainer based on training method
        if training_method == "sft":
            from trl import SFTTrainer, SFTConfig

            max_length = int(extra.get("max_seq_length", "512"))
            packing = extra.get("packing", "false").lower() == "true"

            training_args = SFTConfig(
                max_length=max_length,
                packing=packing,
                **_common_args,
            )

            trainer = SFTTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "dpo":
            from trl import DPOTrainer, DPOConfig

            beta = float(extra.get("beta", "0.1"))
            loss_type = extra.get("loss_type", "sigmoid")
            max_length = int(extra.get("max_length", "512"))

            training_args = DPOConfig(
                beta=beta,
                loss_type=loss_type,
                max_length=max_length,
                **_common_args,
            )

            trainer = DPOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "grpo":
            from trl import GRPOTrainer, GRPOConfig

            num_generations = int(extra.get("num_generations", "4"))
            max_completion_length = int(extra.get("max_completion_length", "256"))

            training_args = GRPOConfig(
                num_generations=num_generations,
                max_completion_length=max_completion_length,
                **_common_args,
            )

            # GRPO requires reward functions passed via extra_options as a JSON list
            from reward_functions import build_reward_functions

            reward_funcs = []
            if extra.get("reward_funcs"):
                reward_funcs = build_reward_functions(extra["reward_funcs"])

            if not reward_funcs:
                raise ValueError(
                    "GRPO requires at least one reward function. "
                    "Specify reward_functions in the request or "
                    "reward_funcs in extra_options."
                )

            trainer = GRPOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                processing_class=tokenizer,
                reward_funcs=reward_funcs,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "orpo":
            from trl import ORPOTrainer, ORPOConfig

            beta = float(extra.get("beta", "0.1"))
            max_length = int(extra.get("max_length", "512"))

            training_args = ORPOConfig(
                beta=beta,
                max_length=max_length,
                **_common_args,
            )

            trainer = ORPOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "kto":
            from trl import KTOTrainer, KTOConfig

            beta = float(extra.get("beta", "0.1"))
            max_length = int(extra.get("max_length", "512"))

            training_args = KTOConfig(
                beta=beta,
                max_length=max_length,
                **_common_args,
            )

            trainer = KTOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "rloo":
            from trl import RLOOTrainer, RLOOConfig

            num_generations = int(extra.get("num_generations", "4"))
            max_completion_length = int(extra.get("max_completion_length", "256"))

            training_args = RLOOConfig(
                num_generations=num_generations,
                max_new_tokens=max_completion_length,
                **_common_args,
            )

            trainer = RLOOTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        elif training_method == "reward":
            from trl import RewardTrainer, RewardConfig

            max_length = int(extra.get("max_length", "512"))

            training_args = RewardConfig(
                max_length=max_length,
                **_common_args,
            )

            trainer = RewardTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                eval_dataset=eval_dataset,
                processing_class=tokenizer,
                callbacks=[progress_cb.get_callback()],
            )

        else:
            raise ValueError(f"Unsupported training method: {training_method}. "
                             "Supported: sft, dpo, grpo, orpo, kto, rloo, reward")

        job.trainer = trainer

        # Start training
        job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id, status="training", message="Training started",
        ))

        resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
        trainer.train(resume_from_checkpoint=resume_ckpt)

        # Save final model
        trainer.save_model(output_dir)
        if tokenizer:
            tokenizer.save_pretrained(output_dir)

        job.completed = True
        # Sentinel to signal stream end
        job.progress_queue.put(None)
|
||||
|
||||
    def FineTuneProgress(self, request, context):
        """Server-streaming RPC: yield progress updates until the job ends.

        Terminates on the None sentinel, on a terminal status
        (completed/failed/stopped), when the job flags itself done, or when
        the client disconnects.
        """
        if self.active_job is None or self.active_job.job_id != request.job_id:
            context.set_code(grpc.StatusCode.NOT_FOUND)
            context.set_details(f"Job {request.job_id} not found")
            return

        job = self.active_job
        while True:
            try:
                # Short timeout so we periodically re-check liveness below.
                update = job.progress_queue.get(timeout=1.0)
                if update is None:
                    # Producer's end-of-stream sentinel.
                    break
                yield update
                if update.status in ("completed", "failed", "stopped"):
                    break
            except queue.Empty:
                # No update in the last second: stop if the job ended or the
                # client went away; otherwise keep polling.
                if job.completed or job.stopped:
                    break
                if not context.is_active():
                    break
                continue
|
||||
|
||||
    def StopFineTune(self, request, context):
        """Acknowledge a stop request; no in-process cancellation happens here."""
        # Stopping is handled by killing the process from Go via ShutdownModel.
        return backend_pb2.Result(success=True, message="OK")
|
||||
|
||||
def ListCheckpoints(self, request, context):
|
||||
output_dir = request.output_dir
|
||||
if not os.path.isdir(output_dir):
|
||||
return backend_pb2.ListCheckpointsResponse(checkpoints=[])
|
||||
|
||||
checkpoints = []
|
||||
for entry in sorted(os.listdir(output_dir)):
|
||||
if entry.startswith("checkpoint-"):
|
||||
ckpt_path = os.path.join(output_dir, entry)
|
||||
if not os.path.isdir(ckpt_path):
|
||||
continue
|
||||
step = 0
|
||||
try:
|
||||
step = int(entry.split("-")[1])
|
||||
except (IndexError, ValueError):
|
||||
pass
|
||||
|
||||
# Try to read trainer_state.json for metadata
|
||||
loss = 0.0
|
||||
epoch = 0.0
|
||||
state_file = os.path.join(ckpt_path, "trainer_state.json")
|
||||
if os.path.exists(state_file):
|
||||
try:
|
||||
with open(state_file) as f:
|
||||
state = json.load(f)
|
||||
if state.get("log_history"):
|
||||
last_log = state["log_history"][-1]
|
||||
loss = last_log.get("loss", 0.0)
|
||||
epoch = last_log.get("epoch", 0.0)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
created_at = time.strftime(
|
||||
"%Y-%m-%dT%H:%M:%SZ",
|
||||
time.gmtime(os.path.getmtime(ckpt_path)),
|
||||
)
|
||||
|
||||
checkpoints.append(backend_pb2.CheckpointInfo(
|
||||
path=ckpt_path,
|
||||
step=step,
|
||||
epoch=float(epoch),
|
||||
loss=float(loss),
|
||||
created_at=created_at,
|
||||
))
|
||||
|
||||
return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
|
||||
|
||||
    def ExportModel(self, request, context):
        """Export a fine-tuned checkpoint in one of three formats.

        Formats:
          - "lora": copy the adapter files from the checkpoint as-is.
          - "merged_16bit"/"merged_4bit": merge the LoRA adapter into the base
            model and save the merged HF model.
          - "gguf": merge, then convert to GGUF via convert_hf_to_gguf.py.

        Returns a Result proto; all failures are reported via the message
        rather than raised.
        """
        export_format = request.export_format or "lora"
        output_path = request.output_path
        checkpoint_path = request.checkpoint_path

        # Extract HF token for gated model access
        extra = dict(request.extra_options) if request.extra_options else {}
        hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")

        if not checkpoint_path or not os.path.isdir(checkpoint_path):
            return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")

        os.makedirs(output_path, exist_ok=True)

        try:
            if export_format == "lora":
                # Just copy the adapter files
                import shutil
                for f in os.listdir(checkpoint_path):
                    src = os.path.join(checkpoint_path, f)
                    dst = os.path.join(output_path, f)
                    if os.path.isfile(src):
                        shutil.copy2(src, dst)

            elif export_format in ("merged_16bit", "merged_4bit"):
                import torch
                from transformers import AutoModelForCausalLM, AutoTokenizer
                from peft import PeftModel

                base_model_name = request.model
                if not base_model_name:
                    return backend_pb2.Result(success=False, message="Base model name required for merge export")

                # NOTE(review): merged_4bit uses float32 here, not an actual
                # 4-bit format — confirm this is the intended behavior.
                dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
                base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
                model = PeftModel.from_pretrained(base_model, checkpoint_path)
                merged = model.merge_and_unload()
                merged.save_pretrained(output_path)

                tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
                tokenizer.save_pretrained(output_path)

            elif export_format == "gguf":
                import torch
                import subprocess
                import shutil
                from transformers import AutoModelForCausalLM, AutoTokenizer
                from peft import PeftModel

                base_model_name = request.model
                if not base_model_name:
                    return backend_pb2.Result(success=False, message="Base model name required for GGUF export")

                # Step 1: Merge LoRA into base model
                merge_dir = os.path.join(output_path, "_hf_merged")
                os.makedirs(merge_dir, exist_ok=True)

                base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
                model = PeftModel.from_pretrained(base_model, checkpoint_path)
                merged = model.merge_and_unload()
                merged.save_pretrained(merge_dir)

                tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
                tokenizer.save_pretrained(merge_dir)

                # Ensure tokenizer.model (SentencePiece) is present in merge_dir.
                # Gemma models need this file for GGUF conversion to use the
                # SentencePiece path; without it, the script falls back to BPE
                # handling which fails on unrecognized pre-tokenizer hashes.
                sp_model_path = os.path.join(merge_dir, "tokenizer.model")
                if not os.path.exists(sp_model_path):
                    sp_copied = False
                    # Method 1: Load the slow tokenizer which keeps the SP model file
                    try:
                        slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
                        if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
                            import shutil as _shutil
                            _shutil.copy2(slow_tok.vocab_file, sp_model_path)
                            sp_copied = True
                            print(f"Copied tokenizer.model from slow tokenizer cache")
                    except Exception as e:
                        print(f"Slow tokenizer method failed: {e}")
                    # Method 2: Download from HF hub
                    if not sp_copied:
                        try:
                            from huggingface_hub import hf_hub_download
                            cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
                            import shutil as _shutil
                            _shutil.copy2(cached_sp, sp_model_path)
                            sp_copied = True
                            print(f"Copied tokenizer.model from HF hub")
                        except Exception as e:
                            print(f"HF hub download method failed: {e}")
                    if not sp_copied:
                        print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
                              "GGUF conversion may fail for SentencePiece models.")

                # Free GPU memory before conversion
                del merged, model, base_model
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Step 2: Convert to GGUF using convert_hf_to_gguf.py
                quant = request.quantization_method or "auto"
                # Unknown quantization names fall back to f16.
                outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
                outtype = outtype_map.get(quant, "f16")

                gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
                gguf_path = os.path.join(output_path, gguf_filename)

                # The conversion script is fetched next to this file by install.sh.
                script_dir = os.path.dirname(os.path.abspath(__file__))
                convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
                if not os.path.exists(convert_script):
                    return backend_pb2.Result(success=False,
                                              message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")

                # Log merge_dir contents for debugging conversion issues
                merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
                print(f"Merge dir contents: {merge_files}", flush=True)

                env = os.environ.copy()
                env["NO_LOCAL_GGUF"] = "1"
                cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
                # 1h timeout: conversion of large models can be slow.
                conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
                if conv_result.returncode != 0:
                    # Surface the tail of stdout/stderr for diagnosis.
                    diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
                    return backend_pb2.Result(success=False,
                                              message=f"GGUF conversion failed: {diag}")

                # Clean up intermediate merged model
                shutil.rmtree(merge_dir, ignore_errors=True)
            else:
                return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")

        except Exception as e:
            if _is_gated_repo_error(e):
                return backend_pb2.Result(success=False,
                                          message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
                                          "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
            return backend_pb2.Result(success=False, message=f"Export failed: {e}")

        return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
|
||||
|
||||
|
||||
def serve(address):
    """Start the gRPC server for the TRL fine-tuning backend and block forever."""
    max_msg_bytes = 50 * 1024 * 1024
    channel_options = [
        ('grpc.max_message_length', max_msg_bytes),
        ('grpc.max_send_message_length', max_msg_bytes),
        ('grpc.max_receive_message_length', max_msg_bytes),
    ]
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=channel_options,
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)

    # Handle graceful shutdown: stop serving immediately on SIGTERM/SIGINT.
    def stop(signum, frame):
        server.stop(0)
        sys.exit(0)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, stop)

    # Keep the main thread alive; gRPC handlers run on the worker pool.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse the listen address and run the backend server.
    arg_parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
    arg_parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
    serve(arg_parser.parse_args().addr)
|
||||
37
backend/python/trl/install.sh
Normal file
37
backend/python/trl/install.sh
Normal file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash
set -e

# Resolve the backend directory and load the shared backend helpers.
# Expansions are quoted so paths containing spaces keep working.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
installRequirements

# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
if [ ! -f "${CONVERT_SCRIPT}" ]; then
    echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
    # Best-effort: a failed download only disables GGUF export, it must not
    # abort the whole install (hence the || echo despite set -e).
    curl -L --fail --retry 3 \
        "https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
        -o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
fi

# Install gguf package from the same llama.cpp commit to keep them in sync
GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
if [ "x${USE_PIP:-}" == "xtrue" ]; then
    pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        pip install "gguf>=0.16.0"
    }
else
    uv pip install "${GGUF_PIP_SPEC}" || {
        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
        uv pip install "gguf>=0.16.0"
    }
fi
|
||||
9
backend/python/trl/requirements-cpu.txt
Normal file
9
backend/python/trl/requirements-cpu.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||
torch==2.10.0
|
||||
trl
|
||||
peft
|
||||
datasets>=3.0.0
|
||||
transformers>=4.56.2
|
||||
accelerate>=1.4.0
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
9
backend/python/trl/requirements-cublas12.txt
Normal file
9
backend/python/trl/requirements-cublas12.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
torch==2.10.0
|
||||
trl
|
||||
peft
|
||||
datasets>=3.0.0
|
||||
transformers>=4.56.2
|
||||
accelerate>=1.4.0
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
bitsandbytes
|
||||
9
backend/python/trl/requirements-cublas13.txt
Normal file
9
backend/python/trl/requirements-cublas13.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
torch==2.10.0
|
||||
trl
|
||||
peft
|
||||
datasets>=3.0.0
|
||||
transformers>=4.56.2
|
||||
accelerate>=1.4.0
|
||||
huggingface-hub>=1.3.0
|
||||
sentencepiece
|
||||
bitsandbytes
|
||||
3
backend/python/trl/requirements.txt
Normal file
3
backend/python/trl/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
grpcio==1.78.1
|
||||
protobuf
|
||||
certifi
|
||||
236
backend/python/trl/reward_functions.py
Normal file
236
backend/python/trl/reward_functions.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""
|
||||
Built-in reward functions and inline function compiler for GRPO training.
|
||||
|
||||
All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import math
|
||||
import string
|
||||
import functools
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Built-in reward functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def format_reward(completions, **kwargs):
    """Reward 1.0 for completions with <think>...</think> followed by
    non-whitespace answer text, 0.0 otherwise."""
    think_then_answer = re.compile(r"<think>.*?</think>\s*\S", re.DOTALL)
    scores = []
    for completion in completions:
        scores.append(1.0 if think_then_answer.search(completion) else 0.0)
    return scores
|
||||
|
||||
|
||||
def reasoning_accuracy_reward(completions, **kwargs):
    """Score each completion 1.0 when its <answer>...</answer> content matches
    the expected answer (case-insensitive, whitespace-trimmed), else 0.0."""
    answers = kwargs.get("answer", [])
    if not answers:
        # No ground truth supplied: nothing can be rewarded.
        return [0.0] * len(completions)

    answer_re = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    def score_one(idx, completion):
        # Fall back to "" when fewer expected answers than completions.
        expected = answers[idx] if idx < len(answers) else ""
        found = answer_re.search(completion)
        if not found:
            return 0.0
        got = found.group(1).strip().lower()
        return 1.0 if got == str(expected).strip().lower() else 0.0

    return [score_one(i, c) for i, c in enumerate(completions)]
|
||||
|
||||
|
||||
def length_reward(completions, target_length=200, **kwargs):
    """Reward proximity of each completion's character count to target_length.

    Scores fall linearly from 1.0 (exact match) down to 0.0 (off by at
    least target_length). A non-positive target yields all zeros.
    """
    if target_length <= 0:
        return [0.0 for _ in completions]
    return [
        max(0.0, 1.0 - abs(len(c) - target_length) / target_length)
        for c in completions
    ]
|
||||
|
||||
|
||||
def xml_tag_reward(completions, **kwargs):
    """Reward 0.5 per expected tag pair (<think>, <answer>) whose opening and
    closing forms both appear, capped at 1.0."""
    expected_tags = ("think", "answer")
    scores = []
    for text in completions:
        # sum(..., 0.0) keeps the score a float even with no matches.
        pair_score = sum(
            (0.5 for tag in expected_tags if f"<{tag}>" in text and f"</{tag}>" in text),
            0.0,
        )
        scores.append(min(pair_score, 1.0))
    return scores
|
||||
|
||||
|
||||
def no_repetition_reward(completions, n=4, **kwargs):
    """Reward the fraction of distinct word n-grams in each completion.

    1.0 means no repeated n-grams (or too few words to form one); lower
    values indicate repetition.
    """
    def distinct_ngram_ratio(text):
        tokens = text.split()
        total = len(tokens) - n + 1
        if total < 1:
            # Too few words to form a single n-gram: treat as unrepetitive.
            return 1.0
        grams = {tuple(tokens[i:i + n]) for i in range(total)}
        return len(grams) / total

    return [distinct_ngram_ratio(c) for c in completions]
|
||||
|
||||
|
||||
def code_execution_reward(completions, **kwargs):
    """Reward 1.0 when a completion contains a ```python fenced block that
    compiles as valid Python syntax, 0.0 otherwise. Code is never executed,
    only compiled."""
    fence_re = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)

    def is_valid_python(text):
        found = fence_re.search(text)
        if found is None:
            return False
        try:
            compile(found.group(1), "<inline>", "exec")
        except SyntaxError:
            return False
        return True

    return [1.0 if is_valid_python(c) else 0.0 for c in completions]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Maps reward-function names (as referenced in training requests) to their
# implementations. Keys intentionally equal each function's __name__.
BUILTIN_REGISTRY = {
    fn.__name__: fn
    for fn in (
        format_reward,
        reasoning_accuracy_reward,
        length_reward,
        xml_tag_reward,
        no_repetition_reward,
        code_execution_reward,
    )
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inline function compiler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SAFE_BUILTINS = {
|
||||
"len": len, "int": int, "float": float, "str": str, "bool": bool,
|
||||
"list": list, "dict": dict, "tuple": tuple, "set": set,
|
||||
"range": range, "enumerate": enumerate, "zip": zip,
|
||||
"map": map, "filter": filter, "sorted": sorted,
|
||||
"min": min, "max": max, "sum": sum, "abs": abs, "round": round,
|
||||
"any": any, "all": all, "isinstance": isinstance, "type": type,
|
||||
"print": print, "True": True, "False": False, "None": None,
|
||||
"ValueError": ValueError, "TypeError": TypeError,
|
||||
"KeyError": KeyError, "IndexError": IndexError,
|
||||
}
|
||||
|
||||
|
||||
def compile_inline_reward(name, code):
    """Compile user-provided code into a reward function.

    *code* is treated as the body of a function taking ``completions``
    (list[str]) plus ``**kwargs`` and returning ``list[float]``. It runs
    under a restricted builtins whitelist with only the re, math, json
    and string modules available.

    Raises ValueError on syntax errors or when a smoke test shows the
    function returns a non-list.
    """
    func_name = f"_user_reward_{name}"
    indented_body = "\n".join(f"    {line}" for line in code.splitlines())
    func_source = f"def {func_name}(completions, **kwargs):\n" + indented_body

    # Sandbox: whitelisted builtins plus a few safe stdlib modules.
    restricted_globals = {
        "__builtins__": _SAFE_BUILTINS,
        "re": re,
        "math": math,
        "json": json,
        "string": string,
    }

    try:
        compiled = compile(func_source, f"<inline-reward-{name}>", "exec")
    except SyntaxError as e:
        raise ValueError(f"Syntax error in inline reward function '{name}': {e}")

    exec(compiled, restricted_globals)
    func = restricted_globals[func_name]

    # Smoke test: only the "wrong return type" failure is fatal; runtime
    # errors (e.g. a kwarg this probe call doesn't supply) are tolerated.
    try:
        probe = func(["test"], answer=["test"])
        if not isinstance(probe, list):
            raise ValueError(
                f"Inline reward function '{name}' must return a list, got {type(probe).__name__}"
            )
    except Exception as e:
        if "must return a list" in str(e):
            raise
        pass

    return func
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_reward_functions(specs_json):
    """Parse a JSON list of reward function specs and return a list of callables.

    Each spec is a dict with:
    - type: "builtin" or "inline"
    - name: function name
    - code: (inline only) Python function body
    - params: (optional) dict of string params applied via functools.partial
    """
    specs = json.loads(specs_json) if isinstance(specs_json, str) else specs_json
    if not isinstance(specs, list):
        raise ValueError("reward_funcs must be a JSON array of reward function specs")

    def coerce(value):
        # Params arrive as strings; prefer int, then float, else keep as-is.
        try:
            return int(value)
        except (ValueError, TypeError):
            try:
                return float(value)
            except (ValueError, TypeError):
                return value

    def from_builtin(name, params):
        # Look up a registered builtin, binding any typed params.
        if name not in BUILTIN_REGISTRY:
            available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
            raise ValueError(
                f"Unknown builtin reward function '{name}'. Available: {available}"
            )
        func = BUILTIN_REGISTRY[name]
        if params:
            func = functools.partial(func, **{k: coerce(v) for k, v in params.items()})
        return func

    def from_inline(name, code):
        # Compile user-supplied function body into a callable.
        if not code.strip():
            raise ValueError(f"Inline reward function '{name}' has no code")
        return compile_inline_reward(name, code)

    reward_funcs = []
    for spec in specs:
        spec_type = spec.get("type", "builtin")
        name = spec.get("name", "")
        if spec_type == "builtin":
            reward_funcs.append(from_builtin(name, spec.get("params", {})))
        elif spec_type == "inline":
            reward_funcs.append(from_inline(name, spec.get("code", "")))
        else:
            raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")
    return reward_funcs
|
||||
10
backend/python/trl/run.sh
Normal file
10
backend/python/trl/run.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash

# Locate the backend directory and load the shared backend helpers.
# Expansions are quoted so paths containing spaces keep working.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# Forward all CLI arguments; "$@" preserves arguments containing spaces
# (a bare $@ would re-split them).
startBackend "$@"
|
||||
58
backend/python/trl/test.py
Normal file
58
backend/python/trl/test.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Test script for the TRL fine-tuning gRPC backend.
|
||||
"""
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import grpc
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
    """Tests for the TRL fine-tuning gRPC service.

    setUp/tearDown are invoked automatically by unittest around each test,
    starting and stopping one backend server process per test. The test
    bodies must NOT call them again: the explicit calls previously here
    spawned a second server on the same port and leaked the first process.
    """

    def setUp(self):
        # Launch the backend as a subprocess and give it time to bind.
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", "localhost:50051"]
        )
        time.sleep(10)

    def tearDown(self):
        # Kill unconditionally and reap the child to avoid zombies.
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """Test that the server starts and responds to health checks."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_list_checkpoints_empty(self):
        """Test listing checkpoints on a non-existent directory."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.ListCheckpoints(
                    backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
                )
                self.assertEqual(len(response.checkpoints), 0)
        except Exception as err:
            print(err)
            self.fail("ListCheckpoints service failed")
|
||||
|
||||
|
||||
# Allow running this module directly: discover and run the tests above.
if __name__ == '__main__':
    unittest.main()
|
||||
11
backend/python/trl/test.sh
Normal file
11
backend/python/trl/test.sh
Normal file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
set -e

# Locate the backend directory and load the shared backend helpers.
# Expansions are quoted so paths containing spaces keep working.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

runUnittests
|
||||
Reference in New Issue
Block a user