diff --git a/backend/python/trl/Makefile b/backend/python/trl/Makefile new file mode 100644 index 000000000..ababb961c --- /dev/null +++ b/backend/python/trl/Makefile @@ -0,0 +1,26 @@ +# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export) +LLAMA_CPP_CONVERT_VERSION ?= master + +.PHONY: trl +trl: + LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh + +.PHONY: run +run: trl + @echo "Running trl..." + bash run.sh + @echo "trl run." + +.PHONY: test +test: trl + @echo "Testing trl..." + bash test.sh + @echo "trl tested." + +.PHONY: protogen-clean +protogen-clean: + $(RM) backend_pb2_grpc.py backend_pb2.py + +.PHONY: clean +clean: protogen-clean + rm -rf venv __pycache__ diff --git a/backend/python/trl/backend.py b/backend/python/trl/backend.py new file mode 100644 index 000000000..6594c91a5 --- /dev/null +++ b/backend/python/trl/backend.py @@ -0,0 +1,759 @@ +#!/usr/bin/env python3 +""" +TRL fine-tuning backend for LocalAI. + +Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO) +using standard HuggingFace transformers + PEFT. Works on both CPU and GPU. +""" +import argparse +import json +import os +import queue +import signal +import sys +import threading +import time +import uuid +from concurrent import futures + +import grpc +import backend_pb2 +import backend_pb2_grpc + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 +MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) + + +class ProgressCallback: + """HuggingFace TrainerCallback that pushes progress updates to a queue.""" + + def __init__(self, job_id, progress_queue, total_epochs): + self.job_id = job_id + self.progress_queue = progress_queue + self.total_epochs = total_epochs + + def get_callback(self): + from transformers import TrainerCallback + + parent = self + + class _Callback(TrainerCallback): + def on_log(self, args, state, control, logs=None, **kwargs): + if logs is None: + return + total_steps = state.max_steps if state.max_steps > 0 else 0 + progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0 + eta = 0.0 + if state.global_step > 0 and total_steps > 0: + elapsed = time.time() - state.logging_steps # approximate + remaining_steps = total_steps - state.global_step + if state.global_step > 1: + eta = remaining_steps * (elapsed / state.global_step) + + extra_metrics = {} + for k, v in logs.items(): + if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'): + extra_metrics[k] = float(v) + + update = backend_pb2.FineTuneProgressUpdate( + job_id=parent.job_id, + current_step=state.global_step, + total_steps=total_steps, + current_epoch=float(logs.get('epoch', 0)), + total_epochs=float(parent.total_epochs), + loss=float(logs.get('loss', 0)), + learning_rate=float(logs.get('learning_rate', 0)), + grad_norm=float(logs.get('grad_norm', 0)), + eval_loss=float(logs.get('eval_loss', 0)), + eta_seconds=float(eta), + progress_percent=float(progress), + status="training", + extra_metrics=extra_metrics, + ) + parent.progress_queue.put(update) + + def on_save(self, args, state, control, **kwargs): + checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}") + update = backend_pb2.FineTuneProgressUpdate( + job_id=parent.job_id, + current_step=state.global_step, + status="saving", + message=f"Checkpoint saved at step {state.global_step}", + checkpoint_path=checkpoint_path, + ) + parent.progress_queue.put(update) + + def on_train_end(self, args, state, control, **kwargs): + update = backend_pb2.FineTuneProgressUpdate( + job_id=parent.job_id, + current_step=state.global_step, + total_steps=state.max_steps, + progress_percent=100.0, + status="completed", + message="Training completed", + ) + parent.progress_queue.put(update) + + return _Callback() + + +class ActiveJob: + """Represents an active fine-tuning job.""" + + def __init__(self, job_id): + self.job_id = job_id + self.progress_queue = queue.Queue() + self.trainer = None + self.thread = None + self.model = None + self.tokenizer = None + self.error = None + self.completed = False + self.stopped = False + + +def _is_gated_repo_error(exc): + """Check if an exception is caused by a gated HuggingFace repo requiring authentication.""" + try: + from huggingface_hub.utils import GatedRepoError + if isinstance(exc, GatedRepoError): + return True + except ImportError: + pass + msg = str(exc).lower() + if "gated repo" in msg or "access to model" in msg: + return True + if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'): + if exc.response.status_code in (401, 403): + return True + return False + + +class BackendServicer(backend_pb2_grpc.BackendServicer): + def __init__(self): + self.active_job = None + + def Health(self, request, context): + return backend_pb2.Reply(message=bytes("OK", 'utf-8')) + + def LoadModel(self, request, context): + """Accept LoadModel — actual model loading happens in StartFineTune.""" + return backend_pb2.Result(success=True, message="OK") + + def StartFineTune(self, request, context): + if self.active_job is not None and not self.active_job.completed: + return backend_pb2.FineTuneJobResult( + job_id="", + success=False, + message="A fine-tuning job is already running", + ) + + job_id = request.job_id if request.job_id else str(uuid.uuid4()) + job = ActiveJob(job_id) + self.active_job = job + + # Start training in background thread + thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True) + job.thread = thread + thread.start() + + return backend_pb2.FineTuneJobResult( + job_id=job_id, + success=True, + message="Fine-tuning job started", + ) + + def _run_training(self, request, job): + try: + self._do_training(request, job) + except Exception as e: + if _is_gated_repo_error(e): + msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. " + "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.") + else: + msg = f"Training failed: {e}" + job.error = msg + job.completed = True + update = backend_pb2.FineTuneProgressUpdate( + job_id=job.job_id, + status="failed", + message=msg, + ) + job.progress_queue.put(update) + # Send sentinel + job.progress_queue.put(None) + + def _do_training(self, request, job): + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + from datasets import load_dataset, Dataset + + extra = dict(request.extra_options) + training_method = request.training_method or "sft" + training_type = request.training_type or "lora" + + # Send loading status + job.progress_queue.put(backend_pb2.FineTuneProgressUpdate( + job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}", + )) + + # Determine device and dtype + device_map = "auto" if torch.cuda.is_available() else "cpu" + dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16 + + # HuggingFace token for gated repos (from extra_options or HF_TOKEN env) + hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN") + + # Load model + model_kwargs = {"device_map": device_map, "torch_dtype": dtype} + if hf_token: + model_kwargs["token"] = hf_token + if extra.get("trust_remote_code", "false").lower() == "true": + model_kwargs["trust_remote_code"] = True + if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available(): + from transformers import BitsAndBytesConfig + model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True) + + model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + job.model = model + job.tokenizer = tokenizer + + # Apply LoRA if requested + if training_type == "lora": + from peft import LoraConfig, get_peft_model + lora_r = request.adapter_rank if request.adapter_rank > 0 else 16 + lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16 + lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0 + + target_modules = list(request.target_modules) if request.target_modules else None + peft_config = LoraConfig( + r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + target_modules=target_modules or "all-linear", + bias="none", + task_type="CAUSAL_LM", + ) + model = get_peft_model(model, peft_config) + + # Load dataset + job.progress_queue.put(backend_pb2.FineTuneProgressUpdate( + job_id=job.job_id, status="loading_dataset", message="Loading dataset", + )) + + dataset_split = request.dataset_split or "train" + if os.path.exists(request.dataset_source): + if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'): + dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split) + elif request.dataset_source.endswith('.csv'): + dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split) + else: + dataset = load_dataset(request.dataset_source, split=dataset_split) + else: + dataset = load_dataset(request.dataset_source, split=dataset_split) + + # Training config + output_dir = request.output_dir or f"./output-{job.job_id}" + num_epochs = request.num_epochs if request.num_epochs > 0 else 3 + batch_size = request.batch_size if request.batch_size > 0 else 2 + lr = request.learning_rate if request.learning_rate > 0 else 2e-4 + grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4 + warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5 + weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01 + max_steps = request.max_steps if request.max_steps > 0 else -1 + save_steps = request.save_steps if request.save_steps > 0 else 500 + seed = request.seed if request.seed > 0 else 3407 + optimizer = request.optimizer or "adamw_torch" + + # Checkpoint save controls + save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited + save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no + + # CPU vs GPU training args (can be overridden via extra_options) + use_cpu = not torch.cuda.is_available() + common_train_kwargs = {} + if use_cpu: + common_train_kwargs["use_cpu"] = True + common_train_kwargs["fp16"] = False + common_train_kwargs["bf16"] = False + common_train_kwargs["gradient_checkpointing"] = False + else: + common_train_kwargs["bf16"] = True + common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing + + # Allow extra_options to override training kwargs + for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"): + if flag in extra: + common_train_kwargs[flag] = extra[flag].lower() == "true" + + # Create progress callback + progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs) + + # Build save kwargs (shared across all methods) + _save_kwargs = {} + if save_strategy == "steps" and save_steps > 0: + _save_kwargs["save_steps"] = save_steps + _save_kwargs["save_strategy"] = "steps" + elif save_strategy == "epoch": + _save_kwargs["save_strategy"] = "epoch" + elif save_strategy == "no": + _save_kwargs["save_strategy"] = "no" + else: + _save_kwargs["save_steps"] = save_steps + _save_kwargs["save_strategy"] = "steps" + if save_total_limit: + _save_kwargs["save_total_limit"] = save_total_limit + + # Common training arguments shared by all methods + _common_args = dict( + output_dir=output_dir, + num_train_epochs=num_epochs, + per_device_train_batch_size=batch_size, + learning_rate=lr, + gradient_accumulation_steps=grad_accum, + warmup_steps=warmup_steps, + weight_decay=weight_decay, + max_steps=max_steps, + seed=seed, + optim=optimizer, + logging_steps=1, + report_to="none", + **_save_kwargs, + **common_train_kwargs, + ) + + # Select trainer based on training method + if training_method == "sft": + from trl import SFTTrainer, SFTConfig + + max_length = int(extra.get("max_seq_length", "512")) + packing = extra.get("packing", "false").lower() == "true" + + training_args = SFTConfig( + max_length=max_length, + packing=packing, + **_common_args, + ) + + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "dpo": + from trl import DPOTrainer, DPOConfig + + beta = float(extra.get("beta", "0.1")) + loss_type = extra.get("loss_type", "sigmoid") + max_length = int(extra.get("max_length", "512")) + + training_args = DPOConfig( + beta=beta, + loss_type=loss_type, + max_length=max_length, + **_common_args, + ) + + trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "grpo": + from trl import GRPOTrainer, GRPOConfig + + num_generations = int(extra.get("num_generations", "4")) + max_completion_length = int(extra.get("max_completion_length", "256")) + + training_args = GRPOConfig( + num_generations=num_generations, + max_completion_length=max_completion_length, + **_common_args, + ) + + # GRPO requires reward functions passed via extra_options as a JSON list + from reward_functions import build_reward_functions + + reward_funcs = [] + if extra.get("reward_funcs"): + reward_funcs = build_reward_functions(extra["reward_funcs"]) + + if not reward_funcs: + raise ValueError( + "GRPO requires at least one reward function. " + "Specify reward_functions in the request or " + "reward_funcs in extra_options." + ) + + trainer = GRPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + reward_funcs=reward_funcs, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "orpo": + from trl import ORPOTrainer, ORPOConfig + + beta = float(extra.get("beta", "0.1")) + max_length = int(extra.get("max_length", "512")) + + training_args = ORPOConfig( + beta=beta, + max_length=max_length, + **_common_args, + ) + + trainer = ORPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "kto": + from trl import KTOTrainer, KTOConfig + + beta = float(extra.get("beta", "0.1")) + max_length = int(extra.get("max_length", "512")) + + training_args = KTOConfig( + beta=beta, + max_length=max_length, + **_common_args, + ) + + trainer = KTOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "rloo": + from trl import RLOOTrainer, RLOOConfig + + num_generations = int(extra.get("num_generations", "4")) + max_completion_length = int(extra.get("max_completion_length", "256")) + + training_args = RLOOConfig( + num_generations=num_generations, + max_new_tokens=max_completion_length, + **_common_args, + ) + + trainer = RLOOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + elif training_method == "reward": + from trl import RewardTrainer, RewardConfig + + max_length = int(extra.get("max_length", "512")) + + training_args = RewardConfig( + max_length=max_length, + **_common_args, + ) + + trainer = RewardTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + callbacks=[progress_cb.get_callback()], + ) + + else: + raise ValueError(f"Unsupported training method: {training_method}. " + "Supported: sft, dpo, grpo, orpo, kto, rloo, reward") + + job.trainer = trainer + + # Start training + job.progress_queue.put(backend_pb2.FineTuneProgressUpdate( + job_id=job.job_id, status="training", message="Training started", + )) + + resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None + trainer.train(resume_from_checkpoint=resume_ckpt) + + # Save final model + trainer.save_model(output_dir) + if tokenizer: + tokenizer.save_pretrained(output_dir) + + job.completed = True + # Sentinel to signal stream end + job.progress_queue.put(None) + + def FineTuneProgress(self, request, context): + if self.active_job is None or self.active_job.job_id != request.job_id: + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(f"Job {request.job_id} not found") + return + + job = self.active_job + while True: + try: + update = job.progress_queue.get(timeout=1.0) + if update is None: + break + yield update + if update.status in ("completed", "failed", "stopped"): + break + except queue.Empty: + if job.completed or job.stopped: + break + if not context.is_active(): + break + continue + + def StopFineTune(self, request, context): + # No-op: stopping is handled by killing the backend process from Go. + # This stub remains to satisfy the proto-generated gRPC interface. + return backend_pb2.Result(success=True, message="No-op (process kill used instead)") + + def ListCheckpoints(self, request, context): + output_dir = request.output_dir + if not os.path.isdir(output_dir): + return backend_pb2.ListCheckpointsResponse(checkpoints=[]) + + checkpoints = [] + for entry in sorted(os.listdir(output_dir)): + if entry.startswith("checkpoint-"): + ckpt_path = os.path.join(output_dir, entry) + if not os.path.isdir(ckpt_path): + continue + step = 0 + try: + step = int(entry.split("-")[1]) + except (IndexError, ValueError): + pass + + # Try to read trainer_state.json for metadata + loss = 0.0 + epoch = 0.0 + state_file = os.path.join(ckpt_path, "trainer_state.json") + if os.path.exists(state_file): + try: + with open(state_file) as f: + state = json.load(f) + if state.get("log_history"): + last_log = state["log_history"][-1] + loss = last_log.get("loss", 0.0) + epoch = last_log.get("epoch", 0.0) + except Exception: + pass + + created_at = time.strftime( + "%Y-%m-%dT%H:%M:%SZ", + time.gmtime(os.path.getmtime(ckpt_path)), + ) + + checkpoints.append(backend_pb2.CheckpointInfo( + path=ckpt_path, + step=step, + epoch=float(epoch), + loss=float(loss), + created_at=created_at, + )) + + return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints) + + def ExportModel(self, request, context): + export_format = request.export_format or "lora" + output_path = request.output_path + checkpoint_path = request.checkpoint_path + + # Extract HF token for gated model access + extra = dict(request.extra_options) if request.extra_options else {} + hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN") + + if not checkpoint_path or not os.path.isdir(checkpoint_path): + return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}") + + os.makedirs(output_path, exist_ok=True) + + try: + if export_format == "lora": + # Just copy the adapter files + import shutil + for f in os.listdir(checkpoint_path): + src = os.path.join(checkpoint_path, f) + dst = os.path.join(output_path, f) + if os.path.isfile(src): + shutil.copy2(src, dst) + + elif export_format in ("merged_16bit", "merged_4bit"): + import torch + from transformers import AutoModelForCausalLM, AutoTokenizer + from peft import PeftModel + + base_model_name = request.model + if not base_model_name: + return backend_pb2.Result(success=False, message="Base model name required for merge export") + + dtype = torch.float16 if export_format == "merged_16bit" else torch.float32 + base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token) + model = PeftModel.from_pretrained(base_model, checkpoint_path) + merged = model.merge_and_unload() + merged.save_pretrained(output_path) + + tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token) + tokenizer.save_pretrained(output_path) + + elif export_format == "gguf": + import torch + import subprocess + import shutil + from transformers import AutoModelForCausalLM, AutoTokenizer + from peft import PeftModel + + base_model_name = request.model + if not base_model_name: + return backend_pb2.Result(success=False, message="Base model name required for GGUF export") + + # Step 1: Merge LoRA into base model + merge_dir = os.path.join(output_path, "_hf_merged") + os.makedirs(merge_dir, exist_ok=True) + + base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token) + model = PeftModel.from_pretrained(base_model, checkpoint_path) + merged = model.merge_and_unload() + merged.save_pretrained(merge_dir) + + tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token) + tokenizer.save_pretrained(merge_dir) + + # Ensure tokenizer.model (SentencePiece) is present in merge_dir. + # Gemma models need this file for GGUF conversion to use the + # SentencePiece path; without it, the script falls back to BPE + # handling which fails on unrecognized pre-tokenizer hashes. + sp_model_path = os.path.join(merge_dir, "tokenizer.model") + if not os.path.exists(sp_model_path): + sp_copied = False + # Method 1: Load the slow tokenizer which keeps the SP model file + try: + slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token) + if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file): + import shutil as _shutil + _shutil.copy2(slow_tok.vocab_file, sp_model_path) + sp_copied = True + print(f"Copied tokenizer.model from slow tokenizer cache") + except Exception as e: + print(f"Slow tokenizer method failed: {e}") + # Method 2: Download from HF hub + if not sp_copied: + try: + from huggingface_hub import hf_hub_download + cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token) + import shutil as _shutil + _shutil.copy2(cached_sp, sp_model_path) + sp_copied = True + print(f"Copied tokenizer.model from HF hub") + except Exception as e: + print(f"HF hub download method failed: {e}") + if not sp_copied: + print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. " + "GGUF conversion may fail for SentencePiece models.") + + # Free GPU memory before conversion + del merged, model, base_model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Step 2: Convert to GGUF using convert_hf_to_gguf.py + quant = request.quantization_method or "auto" + outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"} + outtype = outtype_map.get(quant, "f16") + + gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf" + gguf_path = os.path.join(output_path, gguf_filename) + + script_dir = os.path.dirname(os.path.abspath(__file__)) + convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py") + if not os.path.exists(convert_script): + return backend_pb2.Result(success=False, + message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.") + + # Log merge_dir contents for debugging conversion issues + merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else [] + print(f"Merge dir contents: {merge_files}", flush=True) + + env = os.environ.copy() + env["NO_LOCAL_GGUF"] = "1" + cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path] + conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env) + if conv_result.returncode != 0: + diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}" + return backend_pb2.Result(success=False, + message=f"GGUF conversion failed: {diag}") + + # Clean up intermediate merged model + shutil.rmtree(merge_dir, ignore_errors=True) + else: + return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}") + + except Exception as e: + if _is_gated_repo_error(e): + return backend_pb2.Result(success=False, + message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. " + "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.") + return backend_pb2.Result(success=False, message=f"Export failed: {e}") + + return backend_pb2.Result(success=True, message=f"Model exported to {output_path}") + + +def serve(address): + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), + options=[ + ('grpc.max_message_length', 50 * 1024 * 1024), + ('grpc.max_send_message_length', 50 * 1024 * 1024), + ('grpc.max_receive_message_length', 50 * 1024 * 1024), + ], + ) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True) + + # Handle graceful shutdown + def stop(signum, frame): + server.stop(0) + sys.exit(0) + + signal.signal(signal.SIGTERM, stop) + signal.signal(signal.SIGINT, stop) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend") + parser.add_argument("--addr", default="localhost:50051", help="gRPC server address") + args = parser.parse_args() + serve(args.addr) diff --git a/backend/python/trl/install.sh b/backend/python/trl/install.sh new file mode 100644 index 000000000..6963e60ed --- /dev/null +++ b/backend/python/trl/install.sh @@ -0,0 +1,37 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" +installRequirements + +# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version +LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}" +CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py" +if [ ! -f "${CONVERT_SCRIPT}" ]; then + echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..." + curl -L --fail --retry 3 \ + "https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \ + -o "${CONVERT_SCRIPT}" || echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available." +fi + +# Install gguf package from the same llama.cpp commit to keep them in sync +GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py" +echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..." +if [ "x${USE_PIP:-}" == "xtrue" ]; then + pip install "${GGUF_PIP_SPEC}" || { + echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..." + pip install "gguf>=0.16.0" + } +else + uv pip install "${GGUF_PIP_SPEC}" || { + echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..." + uv pip install "gguf>=0.16.0" + } +fi diff --git a/backend/python/trl/requirements-cpu.txt b/backend/python/trl/requirements-cpu.txt new file mode 100644 index 000000000..c67858542 --- /dev/null +++ b/backend/python/trl/requirements-cpu.txt @@ -0,0 +1,9 @@ +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.10.0 +trl +peft +datasets>=3.0.0 +transformers>=4.56.2 +accelerate>=1.4.0 +huggingface-hub>=1.3.0 +sentencepiece diff --git a/backend/python/trl/requirements-cublas12.txt b/backend/python/trl/requirements-cublas12.txt new file mode 100644 index 000000000..05f29591c --- /dev/null +++ b/backend/python/trl/requirements-cublas12.txt @@ -0,0 +1,9 @@ +torch==2.10.0 +trl +peft +datasets>=3.0.0 +transformers>=4.56.2 +accelerate>=1.4.0 +huggingface-hub>=1.3.0 +sentencepiece +bitsandbytes diff --git a/backend/python/trl/requirements-cublas13.txt b/backend/python/trl/requirements-cublas13.txt new file mode 100644 index 000000000..05f29591c --- /dev/null +++ b/backend/python/trl/requirements-cublas13.txt @@ -0,0 +1,9 @@ +torch==2.10.0 +trl +peft +datasets>=3.0.0 +transformers>=4.56.2 +accelerate>=1.4.0 +huggingface-hub>=1.3.0 +sentencepiece +bitsandbytes diff --git a/backend/python/trl/requirements.txt b/backend/python/trl/requirements.txt new file mode 100644 index 000000000..0834a8fcd --- /dev/null +++ b/backend/python/trl/requirements.txt @@ -0,0 +1,3 @@ +grpcio==1.78.1 +protobuf +certifi diff --git a/backend/python/trl/reward_functions.py b/backend/python/trl/reward_functions.py new file mode 100644 index 000000000..12074f80c --- /dev/null +++ b/backend/python/trl/reward_functions.py @@ -0,0 +1,236 @@ +""" +Built-in reward functions and inline function compiler for GRPO training. + +All reward functions follow TRL's signature: (completions, **kwargs) -> list[float] +""" + +import json +import re +import math +import string +import functools + + +# --------------------------------------------------------------------------- +# Built-in reward functions +# --------------------------------------------------------------------------- + +def format_reward(completions, **kwargs): + """Checks for ... followed by an answer. Returns 1.0 or 0.0.""" + pattern = re.compile(r".*?\s*\S", re.DOTALL) + return [1.0 if pattern.search(c) else 0.0 for c in completions] + + +def reasoning_accuracy_reward(completions, **kwargs): + """Extracts ... content and compares to the expected answer.""" + answers = kwargs.get("answer", []) + if not answers: + return [0.0] * len(completions) + + pattern = re.compile(r"(.*?)", re.DOTALL) + scores = [] + for i, c in enumerate(completions): + expected = answers[i] if i < len(answers) else "" + match = pattern.search(c) + if match: + extracted = match.group(1).strip() + scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0) + else: + scores.append(0.0) + return scores + + +def length_reward(completions, target_length=200, **kwargs): + """Score based on proximity to target_length. Returns [0, 1].""" + scores = [] + for c in completions: + length = len(c) + if target_length <= 0: + scores.append(0.0) + else: + diff = abs(length - target_length) / target_length + scores.append(max(0.0, 1.0 - diff)) + return scores + + +def xml_tag_reward(completions, **kwargs): + """Scores properly opened/closed XML tags (, ).""" + tags = ["think", "answer"] + scores = [] + for c in completions: + tag_score = 0.0 + for tag in tags: + if f"<{tag}>" in c and f"" in c: + tag_score += 0.5 + scores.append(min(tag_score, 1.0)) + return scores + + +def no_repetition_reward(completions, n=4, **kwargs): + """Penalizes n-gram repetition. Returns [0, 1].""" + scores = [] + for c in completions: + words = c.split() + if len(words) < n: + scores.append(1.0) + continue + ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)] + unique = len(set(ngrams)) + total = len(ngrams) + scores.append(unique / total if total > 0 else 1.0) + return scores + + +def code_execution_reward(completions, **kwargs): + """Checks Python code block syntax validity via compile(). Returns 1.0 or 0.0.""" + pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL) + scores = [] + for c in completions: + match = pattern.search(c) + if not match: + scores.append(0.0) + continue + code = match.group(1) + try: + compile(code, "", "exec") + scores.append(1.0) + except SyntaxError: + scores.append(0.0) + return scores + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + +BUILTIN_REGISTRY = { + "format_reward": format_reward, + "reasoning_accuracy_reward": reasoning_accuracy_reward, + "length_reward": length_reward, + "xml_tag_reward": xml_tag_reward, + "no_repetition_reward": no_repetition_reward, + "code_execution_reward": code_execution_reward, +} + + +# --------------------------------------------------------------------------- +# Inline function compiler +# --------------------------------------------------------------------------- + +_SAFE_BUILTINS = { + "len": len, "int": int, "float": float, "str": str, "bool": bool, + "list": list, "dict": dict, "tuple": tuple, "set": set, + "range": range, "enumerate": enumerate, "zip": zip, + "map": map, "filter": filter, "sorted": sorted, + "min": min, "max": max, "sum": sum, "abs": abs, "round": round, + "any": any, "all": all, "isinstance": isinstance, "type": type, + "print": print, "True": True, "False": False, "None": None, + "ValueError": ValueError, "TypeError": TypeError, + "KeyError": KeyError, "IndexError": IndexError, +} + + +def compile_inline_reward(name, code): + """Compile user-provided code into a reward function. + + The code should be the body of a function that receives + `completions` (list[str]) and `**kwargs`, and returns list[float]. + + Available modules: re, math, json, string. + """ + func_source = ( + f"def _user_reward_{name}(completions, **kwargs):\n" + + "\n".join(f" {line}" for line in code.splitlines()) + ) + + restricted_globals = { + "__builtins__": _SAFE_BUILTINS, + "re": re, + "math": math, + "json": json, + "string": string, + } + + try: + compiled = compile(func_source, f"", "exec") + except SyntaxError as e: + raise ValueError(f"Syntax error in inline reward function '{name}': {e}") + + exec(compiled, restricted_globals) + func = restricted_globals[f"_user_reward_{name}"] + + # Validate with a quick smoke test + try: + result = func(["test"], answer=["test"]) + if not isinstance(result, list): + raise ValueError( + f"Inline reward function '{name}' must return a list, got {type(result).__name__}" + ) + except Exception as e: + if "must return a list" in str(e): + raise + # Other errors during smoke test are acceptable (e.g. missing kwargs) + pass + + return func + + +# --------------------------------------------------------------------------- +# Dispatcher +# --------------------------------------------------------------------------- + +def build_reward_functions(specs_json): + """Parse a JSON list of reward function specs and return a list of callables. + + Each spec is a dict with: + - type: "builtin" or "inline" + - name: function name + - code: (inline only) Python function body + - params: (optional) dict of string params applied via functools.partial + """ + if isinstance(specs_json, str): + specs = json.loads(specs_json) + else: + specs = specs_json + + if not isinstance(specs, list): + raise ValueError("reward_funcs must be a JSON array of reward function specs") + + reward_funcs = [] + for spec in specs: + spec_type = spec.get("type", "builtin") + name = spec.get("name", "") + params = spec.get("params", {}) + + if spec_type == "builtin": + if name not in BUILTIN_REGISTRY: + available = ", ".join(sorted(BUILTIN_REGISTRY.keys())) + raise ValueError( + f"Unknown builtin reward function '{name}'. Available: {available}" + ) + func = BUILTIN_REGISTRY[name] + if params: + # Convert string params to appropriate types + typed_params = {} + for k, v in params.items(): + try: + typed_params[k] = int(v) + except (ValueError, TypeError): + try: + typed_params[k] = float(v) + except (ValueError, TypeError): + typed_params[k] = v + func = functools.partial(func, **typed_params) + reward_funcs.append(func) + + elif spec_type == "inline": + code = spec.get("code", "") + if not code.strip(): + raise ValueError(f"Inline reward function '{name}' has no code") + func = compile_inline_reward(name, code) + reward_funcs.append(func) + + else: + raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'") + + return reward_funcs diff --git a/backend/python/trl/run.sh b/backend/python/trl/run.sh new file mode 100644 index 000000000..bd17c6e1d --- /dev/null +++ b/backend/python/trl/run.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +startBackend $@ diff --git a/backend/python/trl/test.py b/backend/python/trl/test.py new file mode 100644 index 000000000..d77d4e9f0 --- /dev/null +++ b/backend/python/trl/test.py @@ -0,0 +1,58 @@ +""" +Test script for the TRL fine-tuning gRPC backend. +""" +import unittest +import subprocess +import time + +import grpc +import backend_pb2 +import backend_pb2_grpc + + +class TestBackendServicer(unittest.TestCase): + """Tests for the TRL fine-tuning gRPC service.""" + + def setUp(self): + self.service = subprocess.Popen( + ["python3", "backend.py", "--addr", "localhost:50051"] + ) + time.sleep(10) + + def tearDown(self): + self.service.kill() + self.service.wait() + + def test_server_startup(self): + """Test that the server starts and responds to health checks.""" + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.Health(backend_pb2.HealthMessage()) + self.assertEqual(response.message, b'OK') + except Exception as err: + print(err) + self.fail("Server failed to start") + finally: + self.tearDown() + + def test_list_checkpoints_empty(self): + """Test listing checkpoints on a non-existent directory.""" + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.ListCheckpoints( + backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent") + ) + self.assertEqual(len(response.checkpoints), 0) + except Exception as err: + print(err) + self.fail("ListCheckpoints service failed") + finally: + self.tearDown() + + +if __name__ == '__main__': + unittest.main() diff --git a/backend/python/trl/test.sh b/backend/python/trl/test.sh new file mode 100644 index 000000000..eb59f2aaf --- /dev/null +++ b/backend/python/trl/test.sh @@ -0,0 +1,11 @@ +#!/bin/bash +set -e + +backend_dir=$(dirname $0) +if [ -d $backend_dir/common ]; then + source $backend_dir/common/libbackend.sh +else + source $backend_dir/../common/libbackend.sh +fi + +runUnittests diff --git a/core/services/finetune.go b/core/services/finetune.go index 568922b10..c49577fbf 100644 --- a/core/services/finetune.go +++ b/core/services/finetune.go @@ -265,25 +265,16 @@ func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, sav } s.mu.Unlock() - backendModel, err := s.modelLoader.Load( - model.WithBackendString(job.Backend), - model.WithModel(job.Backend), - model.WithModelID(job.Backend+"-finetune"), - ) + // Kill the backend process directly — gRPC stop deadlocks on single-threaded Python backends + modelID := job.Backend + "-finetune" + err := s.modelLoader.ShutdownModel(modelID) if err != nil { - return fmt.Errorf("failed to load backend: %w", err) - } - - _, err = backendModel.StopFineTune(ctx, &pb.FineTuneStopRequest{ - JobId: jobID, - SaveCheckpoint: saveCheckpoint, - }) - if err != nil { - return fmt.Errorf("failed to stop job: %w", err) + return fmt.Errorf("failed to stop backend: %w", err) } s.mu.Lock() - job.Message = "Stop requested, waiting for training to halt..." + job.Status = "stopped" + job.Message = "Training stopped by user" s.saveJobState(job) s.mu.Unlock()