diff --git a/backend/python/trl/Makefile b/backend/python/trl/Makefile
new file mode 100644
index 000000000..ababb961c
--- /dev/null
+++ b/backend/python/trl/Makefile
@@ -0,0 +1,26 @@
+# Version of llama.cpp to fetch convert_hf_to_gguf.py from (for GGUF export)
+LLAMA_CPP_CONVERT_VERSION ?= master
+
+.PHONY: trl
+trl:
+ LLAMA_CPP_CONVERT_VERSION=$(LLAMA_CPP_CONVERT_VERSION) bash install.sh
+
+.PHONY: run
+run: trl
+ @echo "Running trl..."
+ bash run.sh
+ @echo "trl run."
+
+.PHONY: test
+test: trl
+ @echo "Testing trl..."
+ bash test.sh
+ @echo "trl tested."
+
+.PHONY: protogen-clean
+protogen-clean:
+ $(RM) backend_pb2_grpc.py backend_pb2.py
+
+.PHONY: clean
+clean: protogen-clean
+ $(RM) -r venv __pycache__
diff --git a/backend/python/trl/backend.py b/backend/python/trl/backend.py
new file mode 100644
index 000000000..6594c91a5
--- /dev/null
+++ b/backend/python/trl/backend.py
@@ -0,0 +1,759 @@
+#!/usr/bin/env python3
+"""
+TRL fine-tuning backend for LocalAI.
+
+Supports all TRL training methods (SFT, DPO, GRPO, RLOO, Reward, KTO, ORPO)
+using standard HuggingFace transformers + PEFT. Works on both CPU and GPU.
+"""
+import argparse
+import json
+import os
+import queue
+import signal
+import sys
+import threading
+import time
+import uuid
+from concurrent import futures
+
+import grpc
+import backend_pb2
+import backend_pb2_grpc
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
+
+
+class ProgressCallback:
+ """HuggingFace TrainerCallback that pushes progress updates to a queue."""
+
+ def __init__(self, job_id, progress_queue, total_epochs):
+ self.job_id = job_id
+ self.progress_queue = progress_queue
+ self.total_epochs = total_epochs
+
+ def get_callback(self):
+ from transformers import TrainerCallback
+
+ parent = self
+
+ class _Callback(TrainerCallback):
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ if logs is None:
+ return
+ total_steps = state.max_steps if state.max_steps > 0 else 0
+ progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
+ eta = 0.0
+ if state.global_step > 0 and total_steps > 0:
+ elapsed = time.time() - state.logging_steps # approximate
+ remaining_steps = total_steps - state.global_step
+ if state.global_step > 1:
+ eta = remaining_steps * (elapsed / state.global_step)
+
+ extra_metrics = {}
+ for k, v in logs.items():
+ if isinstance(v, (int, float)) and k not in ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss'):
+ extra_metrics[k] = float(v)
+
+ update = backend_pb2.FineTuneProgressUpdate(
+ job_id=parent.job_id,
+ current_step=state.global_step,
+ total_steps=total_steps,
+ current_epoch=float(logs.get('epoch', 0)),
+ total_epochs=float(parent.total_epochs),
+ loss=float(logs.get('loss', 0)),
+ learning_rate=float(logs.get('learning_rate', 0)),
+ grad_norm=float(logs.get('grad_norm', 0)),
+ eval_loss=float(logs.get('eval_loss', 0)),
+ eta_seconds=float(eta),
+ progress_percent=float(progress),
+ status="training",
+ extra_metrics=extra_metrics,
+ )
+ parent.progress_queue.put(update)
+
+ def on_save(self, args, state, control, **kwargs):
+ checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
+ update = backend_pb2.FineTuneProgressUpdate(
+ job_id=parent.job_id,
+ current_step=state.global_step,
+ status="saving",
+ message=f"Checkpoint saved at step {state.global_step}",
+ checkpoint_path=checkpoint_path,
+ )
+ parent.progress_queue.put(update)
+
+ def on_train_end(self, args, state, control, **kwargs):
+ update = backend_pb2.FineTuneProgressUpdate(
+ job_id=parent.job_id,
+ current_step=state.global_step,
+ total_steps=state.max_steps,
+ progress_percent=100.0,
+ status="completed",
+ message="Training completed",
+ )
+ parent.progress_queue.put(update)
+
+ return _Callback()
+
+
+class ActiveJob:
+    """Represents an active fine-tuning job.
+
+    Shared between the gRPC servicer thread and the background training
+    thread; progress_queue is the only synchronized channel between them
+    (a None item is the end-of-stream sentinel).
+    """
+
+    def __init__(self, job_id):
+        self.job_id = job_id
+        self.progress_queue = queue.Queue()  # FineTuneProgressUpdate items; None = done
+        self.trainer = None      # set once the TRL trainer has been constructed
+        self.thread = None       # background training thread
+        self.model = None
+        self.tokenizer = None
+        self.error = None        # failure message when training raised
+        self.completed = False
+        self.stopped = False
+
+
+def _is_gated_repo_error(exc):
+ """Check if an exception is caused by a gated HuggingFace repo requiring authentication."""
+ try:
+ from huggingface_hub.utils import GatedRepoError
+ if isinstance(exc, GatedRepoError):
+ return True
+ except ImportError:
+ pass
+ msg = str(exc).lower()
+ if "gated repo" in msg or "access to model" in msg:
+ return True
+ if hasattr(exc, 'response') and hasattr(exc.response, 'status_code'):
+ if exc.response.status_code in (401, 403):
+ return True
+ return False
+
+
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+ def __init__(self):
+ self.active_job = None
+
+ def Health(self, request, context):
+ return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
+
+ def LoadModel(self, request, context):
+ """Accept LoadModel — actual model loading happens in StartFineTune."""
+ return backend_pb2.Result(success=True, message="OK")
+
+ def StartFineTune(self, request, context):
+ if self.active_job is not None and not self.active_job.completed:
+ return backend_pb2.FineTuneJobResult(
+ job_id="",
+ success=False,
+ message="A fine-tuning job is already running",
+ )
+
+ job_id = request.job_id if request.job_id else str(uuid.uuid4())
+ job = ActiveJob(job_id)
+ self.active_job = job
+
+ # Start training in background thread
+ thread = threading.Thread(target=self._run_training, args=(request, job), daemon=True)
+ job.thread = thread
+ thread.start()
+
+ return backend_pb2.FineTuneJobResult(
+ job_id=job_id,
+ success=True,
+ message="Fine-tuning job started",
+ )
+
+ def _run_training(self, request, job):
+ try:
+ self._do_training(request, job)
+ except Exception as e:
+ if _is_gated_repo_error(e):
+ msg = (f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
+ "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
+ else:
+ msg = f"Training failed: {e}"
+ job.error = msg
+ job.completed = True
+ update = backend_pb2.FineTuneProgressUpdate(
+ job_id=job.job_id,
+ status="failed",
+ message=msg,
+ )
+ job.progress_queue.put(update)
+ # Send sentinel
+ job.progress_queue.put(None)
+
+ def _do_training(self, request, job):
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from datasets import load_dataset, Dataset
+
+ extra = dict(request.extra_options)
+ training_method = request.training_method or "sft"
+ training_type = request.training_type or "lora"
+
+ # Send loading status
+ job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
+ job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
+ ))
+
+ # Determine device and dtype
+ device_map = "auto" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
+
+ # HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
+ hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
+
+ # Load model
+ model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
+ if hf_token:
+ model_kwargs["token"] = hf_token
+ if extra.get("trust_remote_code", "false").lower() == "true":
+ model_kwargs["trust_remote_code"] = True
+ if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
+ from transformers import BitsAndBytesConfig
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
+
+ model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
+ tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ job.model = model
+ job.tokenizer = tokenizer
+
+ # Apply LoRA if requested
+ if training_type == "lora":
+ from peft import LoraConfig, get_peft_model
+ lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
+ lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
+ lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
+
+ target_modules = list(request.target_modules) if request.target_modules else None
+ peft_config = LoraConfig(
+ r=lora_r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ target_modules=target_modules or "all-linear",
+ bias="none",
+ task_type="CAUSAL_LM",
+ )
+ model = get_peft_model(model, peft_config)
+
+ # Load dataset
+ job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
+ job_id=job.job_id, status="loading_dataset", message="Loading dataset",
+ ))
+
+ dataset_split = request.dataset_split or "train"
+ if os.path.exists(request.dataset_source):
+ if request.dataset_source.endswith('.json') or request.dataset_source.endswith('.jsonl'):
+ dataset = load_dataset("json", data_files=request.dataset_source, split=dataset_split)
+ elif request.dataset_source.endswith('.csv'):
+ dataset = load_dataset("csv", data_files=request.dataset_source, split=dataset_split)
+ else:
+ dataset = load_dataset(request.dataset_source, split=dataset_split)
+ else:
+ dataset = load_dataset(request.dataset_source, split=dataset_split)
+
+ # Training config
+ output_dir = request.output_dir or f"./output-{job.job_id}"
+ num_epochs = request.num_epochs if request.num_epochs > 0 else 3
+ batch_size = request.batch_size if request.batch_size > 0 else 2
+ lr = request.learning_rate if request.learning_rate > 0 else 2e-4
+ grad_accum = request.gradient_accumulation_steps if request.gradient_accumulation_steps > 0 else 4
+ warmup_steps = request.warmup_steps if request.warmup_steps > 0 else 5
+ weight_decay = request.weight_decay if request.weight_decay > 0 else 0.01
+ max_steps = request.max_steps if request.max_steps > 0 else -1
+ save_steps = request.save_steps if request.save_steps > 0 else 500
+ seed = request.seed if request.seed > 0 else 3407
+ optimizer = request.optimizer or "adamw_torch"
+
+ # Checkpoint save controls
+ save_total_limit = int(extra.get("save_total_limit", "0")) or None # 0 = unlimited
+ save_strategy = extra.get("save_strategy", "steps") # steps, epoch, no
+
+ # CPU vs GPU training args (can be overridden via extra_options)
+ use_cpu = not torch.cuda.is_available()
+ common_train_kwargs = {}
+ if use_cpu:
+ common_train_kwargs["use_cpu"] = True
+ common_train_kwargs["fp16"] = False
+ common_train_kwargs["bf16"] = False
+ common_train_kwargs["gradient_checkpointing"] = False
+ else:
+ common_train_kwargs["bf16"] = True
+ common_train_kwargs["gradient_checkpointing"] = request.gradient_checkpointing
+
+ # Allow extra_options to override training kwargs
+ for flag in ("use_cpu", "bf16", "fp16", "gradient_checkpointing"):
+ if flag in extra:
+ common_train_kwargs[flag] = extra[flag].lower() == "true"
+
+ # Create progress callback
+ progress_cb = ProgressCallback(job.job_id, job.progress_queue, num_epochs)
+
+ # Build save kwargs (shared across all methods)
+ _save_kwargs = {}
+ if save_strategy == "steps" and save_steps > 0:
+ _save_kwargs["save_steps"] = save_steps
+ _save_kwargs["save_strategy"] = "steps"
+ elif save_strategy == "epoch":
+ _save_kwargs["save_strategy"] = "epoch"
+ elif save_strategy == "no":
+ _save_kwargs["save_strategy"] = "no"
+ else:
+ _save_kwargs["save_steps"] = save_steps
+ _save_kwargs["save_strategy"] = "steps"
+ if save_total_limit:
+ _save_kwargs["save_total_limit"] = save_total_limit
+
+ # Common training arguments shared by all methods
+ _common_args = dict(
+ output_dir=output_dir,
+ num_train_epochs=num_epochs,
+ per_device_train_batch_size=batch_size,
+ learning_rate=lr,
+ gradient_accumulation_steps=grad_accum,
+ warmup_steps=warmup_steps,
+ weight_decay=weight_decay,
+ max_steps=max_steps,
+ seed=seed,
+ optim=optimizer,
+ logging_steps=1,
+ report_to="none",
+ **_save_kwargs,
+ **common_train_kwargs,
+ )
+
+ # Select trainer based on training method
+ if training_method == "sft":
+ from trl import SFTTrainer, SFTConfig
+
+ max_length = int(extra.get("max_seq_length", "512"))
+ packing = extra.get("packing", "false").lower() == "true"
+
+ training_args = SFTConfig(
+ max_length=max_length,
+ packing=packing,
+ **_common_args,
+ )
+
+ trainer = SFTTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset,
+ processing_class=tokenizer,
+ callbacks=[progress_cb.get_callback()],
+ )
+
+ elif training_method == "dpo":
+ from trl import DPOTrainer, DPOConfig
+
+ beta = float(extra.get("beta", "0.1"))
+ loss_type = extra.get("loss_type", "sigmoid")
+ max_length = int(extra.get("max_length", "512"))
+
+ training_args = DPOConfig(
+ beta=beta,
+ loss_type=loss_type,
+ max_length=max_length,
+ **_common_args,
+ )
+
+ trainer = DPOTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset,
+ processing_class=tokenizer,
+ callbacks=[progress_cb.get_callback()],
+ )
+
+ elif training_method == "grpo":
+ from trl import GRPOTrainer, GRPOConfig
+
+ num_generations = int(extra.get("num_generations", "4"))
+ max_completion_length = int(extra.get("max_completion_length", "256"))
+
+ training_args = GRPOConfig(
+ num_generations=num_generations,
+ max_completion_length=max_completion_length,
+ **_common_args,
+ )
+
+ # GRPO requires reward functions passed via extra_options as a JSON list
+ from reward_functions import build_reward_functions
+
+ reward_funcs = []
+ if extra.get("reward_funcs"):
+ reward_funcs = build_reward_functions(extra["reward_funcs"])
+
+ if not reward_funcs:
+ raise ValueError(
+ "GRPO requires at least one reward function. "
+ "Specify reward_functions in the request or "
+ "reward_funcs in extra_options."
+ )
+
+ trainer = GRPOTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset,
+ processing_class=tokenizer,
+ reward_funcs=reward_funcs,
+ callbacks=[progress_cb.get_callback()],
+ )
+
+ elif training_method == "orpo":
+ from trl import ORPOTrainer, ORPOConfig
+
+ beta = float(extra.get("beta", "0.1"))
+ max_length = int(extra.get("max_length", "512"))
+
+ training_args = ORPOConfig(
+ beta=beta,
+ max_length=max_length,
+ **_common_args,
+ )
+
+ trainer = ORPOTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset,
+ processing_class=tokenizer,
+ callbacks=[progress_cb.get_callback()],
+ )
+
+ elif training_method == "kto":
+ from trl import KTOTrainer, KTOConfig
+
+ beta = float(extra.get("beta", "0.1"))
+ max_length = int(extra.get("max_length", "512"))
+
+ training_args = KTOConfig(
+ beta=beta,
+ max_length=max_length,
+ **_common_args,
+ )
+
+ trainer = KTOTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset,
+ processing_class=tokenizer,
+ callbacks=[progress_cb.get_callback()],
+ )
+
+ elif training_method == "rloo":
+ from trl import RLOOTrainer, RLOOConfig
+
+ num_generations = int(extra.get("num_generations", "4"))
+ max_completion_length = int(extra.get("max_completion_length", "256"))
+
+ training_args = RLOOConfig(
+ num_generations=num_generations,
+ max_new_tokens=max_completion_length,
+ **_common_args,
+ )
+
+ trainer = RLOOTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset,
+ processing_class=tokenizer,
+ callbacks=[progress_cb.get_callback()],
+ )
+
+ elif training_method == "reward":
+ from trl import RewardTrainer, RewardConfig
+
+ max_length = int(extra.get("max_length", "512"))
+
+ training_args = RewardConfig(
+ max_length=max_length,
+ **_common_args,
+ )
+
+ trainer = RewardTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset,
+ processing_class=tokenizer,
+ callbacks=[progress_cb.get_callback()],
+ )
+
+ else:
+ raise ValueError(f"Unsupported training method: {training_method}. "
+ "Supported: sft, dpo, grpo, orpo, kto, rloo, reward")
+
+ job.trainer = trainer
+
+ # Start training
+ job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
+ job_id=job.job_id, status="training", message="Training started",
+ ))
+
+ resume_ckpt = request.resume_from_checkpoint if request.resume_from_checkpoint else None
+ trainer.train(resume_from_checkpoint=resume_ckpt)
+
+ # Save final model
+ trainer.save_model(output_dir)
+ if tokenizer:
+ tokenizer.save_pretrained(output_dir)
+
+ job.completed = True
+ # Sentinel to signal stream end
+ job.progress_queue.put(None)
+
+ def FineTuneProgress(self, request, context):
+ if self.active_job is None or self.active_job.job_id != request.job_id:
+ context.set_code(grpc.StatusCode.NOT_FOUND)
+ context.set_details(f"Job {request.job_id} not found")
+ return
+
+ job = self.active_job
+ while True:
+ try:
+ update = job.progress_queue.get(timeout=1.0)
+ if update is None:
+ break
+ yield update
+ if update.status in ("completed", "failed", "stopped"):
+ break
+ except queue.Empty:
+ if job.completed or job.stopped:
+ break
+ if not context.is_active():
+ break
+ continue
+
+ def StopFineTune(self, request, context):
+ # No-op: stopping is handled by killing the backend process from Go.
+ # This stub remains to satisfy the proto-generated gRPC interface.
+ return backend_pb2.Result(success=True, message="No-op (process kill used instead)")
+
+ def ListCheckpoints(self, request, context):
+ output_dir = request.output_dir
+ if not os.path.isdir(output_dir):
+ return backend_pb2.ListCheckpointsResponse(checkpoints=[])
+
+ checkpoints = []
+ for entry in sorted(os.listdir(output_dir)):
+ if entry.startswith("checkpoint-"):
+ ckpt_path = os.path.join(output_dir, entry)
+ if not os.path.isdir(ckpt_path):
+ continue
+ step = 0
+ try:
+ step = int(entry.split("-")[1])
+ except (IndexError, ValueError):
+ pass
+
+ # Try to read trainer_state.json for metadata
+ loss = 0.0
+ epoch = 0.0
+ state_file = os.path.join(ckpt_path, "trainer_state.json")
+ if os.path.exists(state_file):
+ try:
+ with open(state_file) as f:
+ state = json.load(f)
+ if state.get("log_history"):
+ last_log = state["log_history"][-1]
+ loss = last_log.get("loss", 0.0)
+ epoch = last_log.get("epoch", 0.0)
+ except Exception:
+ pass
+
+ created_at = time.strftime(
+ "%Y-%m-%dT%H:%M:%SZ",
+ time.gmtime(os.path.getmtime(ckpt_path)),
+ )
+
+ checkpoints.append(backend_pb2.CheckpointInfo(
+ path=ckpt_path,
+ step=step,
+ epoch=float(epoch),
+ loss=float(loss),
+ created_at=created_at,
+ ))
+
+ return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
+
+ def ExportModel(self, request, context):
+ export_format = request.export_format or "lora"
+ output_path = request.output_path
+ checkpoint_path = request.checkpoint_path
+
+ # Extract HF token for gated model access
+ extra = dict(request.extra_options) if request.extra_options else {}
+ hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
+
+ if not checkpoint_path or not os.path.isdir(checkpoint_path):
+ return backend_pb2.Result(success=False, message=f"Checkpoint not found: {checkpoint_path}")
+
+ os.makedirs(output_path, exist_ok=True)
+
+ try:
+ if export_format == "lora":
+ # Just copy the adapter files
+ import shutil
+ for f in os.listdir(checkpoint_path):
+ src = os.path.join(checkpoint_path, f)
+ dst = os.path.join(output_path, f)
+ if os.path.isfile(src):
+ shutil.copy2(src, dst)
+
+ elif export_format in ("merged_16bit", "merged_4bit"):
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+
+ base_model_name = request.model
+ if not base_model_name:
+ return backend_pb2.Result(success=False, message="Base model name required for merge export")
+
+ dtype = torch.float16 if export_format == "merged_16bit" else torch.float32
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=dtype, token=hf_token)
+ model = PeftModel.from_pretrained(base_model, checkpoint_path)
+ merged = model.merge_and_unload()
+ merged.save_pretrained(output_path)
+
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
+ tokenizer.save_pretrained(output_path)
+
+ elif export_format == "gguf":
+ import torch
+ import subprocess
+ import shutil
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+
+ base_model_name = request.model
+ if not base_model_name:
+ return backend_pb2.Result(success=False, message="Base model name required for GGUF export")
+
+ # Step 1: Merge LoRA into base model
+ merge_dir = os.path.join(output_path, "_hf_merged")
+ os.makedirs(merge_dir, exist_ok=True)
+
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16, token=hf_token)
+ model = PeftModel.from_pretrained(base_model, checkpoint_path)
+ merged = model.merge_and_unload()
+ merged.save_pretrained(merge_dir)
+
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name, token=hf_token)
+ tokenizer.save_pretrained(merge_dir)
+
+ # Ensure tokenizer.model (SentencePiece) is present in merge_dir.
+ # Gemma models need this file for GGUF conversion to use the
+ # SentencePiece path; without it, the script falls back to BPE
+ # handling which fails on unrecognized pre-tokenizer hashes.
+ sp_model_path = os.path.join(merge_dir, "tokenizer.model")
+ if not os.path.exists(sp_model_path):
+ sp_copied = False
+ # Method 1: Load the slow tokenizer which keeps the SP model file
+ try:
+ slow_tok = AutoTokenizer.from_pretrained(base_model_name, use_fast=False, token=hf_token)
+ if hasattr(slow_tok, 'vocab_file') and slow_tok.vocab_file and os.path.exists(slow_tok.vocab_file):
+ import shutil as _shutil
+ _shutil.copy2(slow_tok.vocab_file, sp_model_path)
+ sp_copied = True
+ print(f"Copied tokenizer.model from slow tokenizer cache")
+ except Exception as e:
+ print(f"Slow tokenizer method failed: {e}")
+ # Method 2: Download from HF hub
+ if not sp_copied:
+ try:
+ from huggingface_hub import hf_hub_download
+ cached_sp = hf_hub_download(repo_id=base_model_name, filename="tokenizer.model", token=hf_token)
+ import shutil as _shutil
+ _shutil.copy2(cached_sp, sp_model_path)
+ sp_copied = True
+ print(f"Copied tokenizer.model from HF hub")
+ except Exception as e:
+ print(f"HF hub download method failed: {e}")
+ if not sp_copied:
+ print(f"WARNING: Could not obtain tokenizer.model for {base_model_name}. "
+ "GGUF conversion may fail for SentencePiece models.")
+
+ # Free GPU memory before conversion
+ del merged, model, base_model
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Step 2: Convert to GGUF using convert_hf_to_gguf.py
+ quant = request.quantization_method or "auto"
+ outtype_map = {"f16": "f16", "f32": "f32", "bf16": "bf16", "q8_0": "q8_0", "auto": "auto"}
+ outtype = outtype_map.get(quant, "f16")
+
+ gguf_filename = f"{os.path.basename(output_path)}-{outtype}.gguf"
+ gguf_path = os.path.join(output_path, gguf_filename)
+
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ convert_script = os.path.join(script_dir, "convert_hf_to_gguf.py")
+ if not os.path.exists(convert_script):
+ return backend_pb2.Result(success=False,
+ message="convert_hf_to_gguf.py not found. Install the GGUF conversion tools.")
+
+ # Log merge_dir contents for debugging conversion issues
+ merge_files = os.listdir(merge_dir) if os.path.isdir(merge_dir) else []
+ print(f"Merge dir contents: {merge_files}", flush=True)
+
+ env = os.environ.copy()
+ env["NO_LOCAL_GGUF"] = "1"
+ cmd = [sys.executable, convert_script, merge_dir, "--outtype", outtype, "--outfile", gguf_path]
+ conv_result = subprocess.run(cmd, capture_output=True, text=True, timeout=3600, env=env)
+ if conv_result.returncode != 0:
+ diag = f"stdout: {conv_result.stdout[-300:]}\nstderr: {conv_result.stderr[-500:]}"
+ return backend_pb2.Result(success=False,
+ message=f"GGUF conversion failed: {diag}")
+
+ # Clean up intermediate merged model
+ shutil.rmtree(merge_dir, ignore_errors=True)
+ else:
+ return backend_pb2.Result(success=False, message=f"Unsupported export format: {export_format}")
+
+ except Exception as e:
+ if _is_gated_repo_error(e):
+ return backend_pb2.Result(success=False,
+ message=f"Model '{request.model}' is a gated HuggingFace repo and requires authentication. "
+ "Pass 'hf_token' in extra_options or set the HF_TOKEN environment variable.")
+ return backend_pb2.Result(success=False, message=f"Export failed: {e}")
+
+ return backend_pb2.Result(success=True, message=f"Model exported to {output_path}")
+
+
+def serve(address):
+    """Start the gRPC server on *address* and block until SIGTERM/SIGINT."""
+    server = grpc.server(
+        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
+        options=[
+            # 50 MB limits to accommodate large fine-tune requests.
+            ('grpc.max_message_length', 50 * 1024 * 1024),
+            ('grpc.max_send_message_length', 50 * 1024 * 1024),
+            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
+        ],
+    )
+    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+    server.add_insecure_port(address)
+    server.start()
+    print(f"TRL fine-tuning backend listening on {address}", file=sys.stderr, flush=True)
+
+    # Handle graceful shutdown
+    def stop(signum, frame):
+        server.stop(0)
+        sys.exit(0)
+
+    signal.signal(signal.SIGTERM, stop)
+    signal.signal(signal.SIGINT, stop)
+
+    # Park the main thread; gRPC workers run in the executor.
+    try:
+        while True:
+            time.sleep(_ONE_DAY_IN_SECONDS)
+    except KeyboardInterrupt:
+        server.stop(0)
+
+
+if __name__ == "__main__":
+    # CLI entry point: the Go side passes the listen address via --addr.
+    parser = argparse.ArgumentParser(description="TRL fine-tuning gRPC backend")
+    parser.add_argument("--addr", default="localhost:50051", help="gRPC server address")
+    args = parser.parse_args()
+    serve(args.addr)
diff --git a/backend/python/trl/install.sh b/backend/python/trl/install.sh
new file mode 100644
index 000000000..6963e60ed
--- /dev/null
+++ b/backend/python/trl/install.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+set -e
+
+# Quote paths so the script works when the backend dir contains spaces.
+backend_dir=$(dirname "$0")
+if [ -d "$backend_dir/common" ]; then
+    source "$backend_dir/common/libbackend.sh"
+else
+    source "$backend_dir/../common/libbackend.sh"
+fi
+
+EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
+installRequirements
+
+# Fetch convert_hf_to_gguf.py and gguf package from the same llama.cpp version
+LLAMA_CPP_CONVERT_VERSION="${LLAMA_CPP_CONVERT_VERSION:-master}"
+CONVERT_SCRIPT="${EDIR}/convert_hf_to_gguf.py"
+if [ ! -f "${CONVERT_SCRIPT}" ]; then
+    echo "Downloading convert_hf_to_gguf.py from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
+    # Remove any partial download on failure so the next run retries instead
+    # of skipping because the file exists.
+    curl -L --fail --retry 3 \
+        "https://raw.githubusercontent.com/ggml-org/llama.cpp/${LLAMA_CPP_CONVERT_VERSION}/convert_hf_to_gguf.py" \
+        -o "${CONVERT_SCRIPT}" || {
+            rm -f "${CONVERT_SCRIPT}"
+            echo "Warning: Failed to download convert_hf_to_gguf.py. GGUF export will not be available."
+        }
+fi
+
+# Install gguf package from the same llama.cpp commit to keep them in sync
+GGUF_PIP_SPEC="gguf @ git+https://github.com/ggml-org/llama.cpp@${LLAMA_CPP_CONVERT_VERSION}#subdirectory=gguf-py"
+echo "Installing gguf package from llama.cpp (${LLAMA_CPP_CONVERT_VERSION})..."
+if [ "x${USE_PIP:-}" == "xtrue" ]; then
+    pip install "${GGUF_PIP_SPEC}" || {
+        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
+        pip install "gguf>=0.16.0"
+    }
+else
+    uv pip install "${GGUF_PIP_SPEC}" || {
+        echo "Warning: Failed to install gguf from llama.cpp commit, falling back to PyPI..."
+        uv pip install "gguf>=0.16.0"
+    }
+fi
diff --git a/backend/python/trl/requirements-cpu.txt b/backend/python/trl/requirements-cpu.txt
new file mode 100644
index 000000000..c67858542
--- /dev/null
+++ b/backend/python/trl/requirements-cpu.txt
@@ -0,0 +1,9 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.10.0
+trl
+peft
+datasets>=3.0.0
+transformers>=4.56.2
+accelerate>=1.4.0
+huggingface-hub>=1.3.0
+sentencepiece
diff --git a/backend/python/trl/requirements-cublas12.txt b/backend/python/trl/requirements-cublas12.txt
new file mode 100644
index 000000000..05f29591c
--- /dev/null
+++ b/backend/python/trl/requirements-cublas12.txt
@@ -0,0 +1,9 @@
+torch==2.10.0
+trl
+peft
+datasets>=3.0.0
+transformers>=4.56.2
+accelerate>=1.4.0
+huggingface-hub>=1.3.0
+sentencepiece
+bitsandbytes
diff --git a/backend/python/trl/requirements-cublas13.txt b/backend/python/trl/requirements-cublas13.txt
new file mode 100644
index 000000000..05f29591c
--- /dev/null
+++ b/backend/python/trl/requirements-cublas13.txt
@@ -0,0 +1,9 @@
+torch==2.10.0
+trl
+peft
+datasets>=3.0.0
+transformers>=4.56.2
+accelerate>=1.4.0
+huggingface-hub>=1.3.0
+sentencepiece
+bitsandbytes
diff --git a/backend/python/trl/requirements.txt b/backend/python/trl/requirements.txt
new file mode 100644
index 000000000..0834a8fcd
--- /dev/null
+++ b/backend/python/trl/requirements.txt
@@ -0,0 +1,3 @@
+grpcio==1.78.1
+protobuf
+certifi
diff --git a/backend/python/trl/reward_functions.py b/backend/python/trl/reward_functions.py
new file mode 100644
index 000000000..12074f80c
--- /dev/null
+++ b/backend/python/trl/reward_functions.py
@@ -0,0 +1,236 @@
+"""
+Built-in reward functions and inline function compiler for GRPO training.
+
+All reward functions follow TRL's signature: (completions, **kwargs) -> list[float]
+"""
+
+import json
+import re
+import math
+import string
+import functools
+
+
+# ---------------------------------------------------------------------------
+# Built-in reward functions
+# ---------------------------------------------------------------------------
+
+def format_reward(completions, **kwargs):
+ """Checks for ... followed by an answer. Returns 1.0 or 0.0."""
+ pattern = re.compile(r".*?\s*\S", re.DOTALL)
+ return [1.0 if pattern.search(c) else 0.0 for c in completions]
+
+
+def reasoning_accuracy_reward(completions, **kwargs):
+ """Extracts ... content and compares to the expected answer."""
+ answers = kwargs.get("answer", [])
+ if not answers:
+ return [0.0] * len(completions)
+
+ pattern = re.compile(r"(.*?)", re.DOTALL)
+ scores = []
+ for i, c in enumerate(completions):
+ expected = answers[i] if i < len(answers) else ""
+ match = pattern.search(c)
+ if match:
+ extracted = match.group(1).strip()
+ scores.append(1.0 if extracted.lower() == str(expected).strip().lower() else 0.0)
+ else:
+ scores.append(0.0)
+ return scores
+
+
+def length_reward(completions, target_length=200, **kwargs):
+ """Score based on proximity to target_length. Returns [0, 1]."""
+ scores = []
+ for c in completions:
+ length = len(c)
+ if target_length <= 0:
+ scores.append(0.0)
+ else:
+ diff = abs(length - target_length) / target_length
+ scores.append(max(0.0, 1.0 - diff))
+ return scores
+
+
+def xml_tag_reward(completions, **kwargs):
+ """Scores properly opened/closed XML tags (, )."""
+ tags = ["think", "answer"]
+ scores = []
+ for c in completions:
+ tag_score = 0.0
+ for tag in tags:
+ if f"<{tag}>" in c and f"{tag}>" in c:
+ tag_score += 0.5
+ scores.append(min(tag_score, 1.0))
+ return scores
+
+
+def no_repetition_reward(completions, n=4, **kwargs):
+ """Penalizes n-gram repetition. Returns [0, 1]."""
+ scores = []
+ for c in completions:
+ words = c.split()
+ if len(words) < n:
+ scores.append(1.0)
+ continue
+ ngrams = [tuple(words[i:i+n]) for i in range(len(words) - n + 1)]
+ unique = len(set(ngrams))
+ total = len(ngrams)
+ scores.append(unique / total if total > 0 else 1.0)
+ return scores
+
+
+def code_execution_reward(completions, **kwargs):
+ """Checks Python code block syntax validity via compile(). Returns 1.0 or 0.0."""
+ pattern = re.compile(r"```python\s*\n(.*?)```", re.DOTALL)
+ scores = []
+ for c in completions:
+ match = pattern.search(c)
+ if not match:
+ scores.append(0.0)
+ continue
+ code = match.group(1)
+ try:
+ compile(code, "", "exec")
+ scores.append(1.0)
+ except SyntaxError:
+ scores.append(0.0)
+ return scores
+
+
+# ---------------------------------------------------------------------------
+# Registry
+# ---------------------------------------------------------------------------
+
+# Name -> callable map of built-in reward functions selectable via a
+# {"type": "builtin", "name": ...} spec in build_reward_functions.
+BUILTIN_REGISTRY = {
+    "format_reward": format_reward,
+    "reasoning_accuracy_reward": reasoning_accuracy_reward,
+    "length_reward": length_reward,
+    "xml_tag_reward": xml_tag_reward,
+    "no_repetition_reward": no_repetition_reward,
+    "code_execution_reward": code_execution_reward,
+}
+
+
+# ---------------------------------------------------------------------------
+# Inline function compiler
+# ---------------------------------------------------------------------------
+
+# Whitelisted builtins exposed to inline (user-supplied) reward functions.
+# NOTE(review): a restricted __builtins__ dict limits, but does not fully
+# sandbox, exec'd code — inline rewards should still come from trusted input.
+_SAFE_BUILTINS = {
+    "len": len, "int": int, "float": float, "str": str, "bool": bool,
+    "list": list, "dict": dict, "tuple": tuple, "set": set,
+    "range": range, "enumerate": enumerate, "zip": zip,
+    "map": map, "filter": filter, "sorted": sorted,
+    "min": min, "max": max, "sum": sum, "abs": abs, "round": round,
+    "any": any, "all": all, "isinstance": isinstance, "type": type,
+    "print": print, "True": True, "False": False, "None": None,
+    "ValueError": ValueError, "TypeError": TypeError,
+    "KeyError": KeyError, "IndexError": IndexError,
+}
+
+
+def compile_inline_reward(name, code):
+ """Compile user-provided code into a reward function.
+
+ The code should be the body of a function that receives
+ `completions` (list[str]) and `**kwargs`, and returns list[float].
+
+ Available modules: re, math, json, string.
+ """
+ func_source = (
+ f"def _user_reward_{name}(completions, **kwargs):\n"
+ + "\n".join(f" {line}" for line in code.splitlines())
+ )
+
+ restricted_globals = {
+ "__builtins__": _SAFE_BUILTINS,
+ "re": re,
+ "math": math,
+ "json": json,
+ "string": string,
+ }
+
+ try:
+ compiled = compile(func_source, f"", "exec")
+ except SyntaxError as e:
+ raise ValueError(f"Syntax error in inline reward function '{name}': {e}")
+
+ exec(compiled, restricted_globals)
+ func = restricted_globals[f"_user_reward_{name}"]
+
+ # Validate with a quick smoke test
+ try:
+ result = func(["test"], answer=["test"])
+ if not isinstance(result, list):
+ raise ValueError(
+ f"Inline reward function '{name}' must return a list, got {type(result).__name__}"
+ )
+ except Exception as e:
+ if "must return a list" in str(e):
+ raise
+ # Other errors during smoke test are acceptable (e.g. missing kwargs)
+ pass
+
+ return func
+
+
+# ---------------------------------------------------------------------------
+# Dispatcher
+# ---------------------------------------------------------------------------
+
+def build_reward_functions(specs_json):
+    """Parse a JSON list of reward function specs and return a list of callables.
+
+    Each spec is a dict with:
+      - type: "builtin" or "inline"
+      - name: function name
+      - code: (inline only) Python function body
+      - params: (optional) dict of string params applied via functools.partial
+
+    Raises ValueError for malformed specs or unknown names/types.
+    """
+    # Accept either a JSON string or an already-parsed list.
+    if isinstance(specs_json, str):
+        specs = json.loads(specs_json)
+    else:
+        specs = specs_json
+
+    if not isinstance(specs, list):
+        raise ValueError("reward_funcs must be a JSON array of reward function specs")
+
+    reward_funcs = []
+    for spec in specs:
+        spec_type = spec.get("type", "builtin")
+        name = spec.get("name", "")
+        params = spec.get("params", {})
+
+        if spec_type == "builtin":
+            if name not in BUILTIN_REGISTRY:
+                available = ", ".join(sorted(BUILTIN_REGISTRY.keys()))
+                raise ValueError(
+                    f"Unknown builtin reward function '{name}'. Available: {available}"
+                )
+            func = BUILTIN_REGISTRY[name]
+            if params:
+                # Convert string params to appropriate types
+                # (try int first, then float, else keep the string).
+                typed_params = {}
+                for k, v in params.items():
+                    try:
+                        typed_params[k] = int(v)
+                    except (ValueError, TypeError):
+                        try:
+                            typed_params[k] = float(v)
+                        except (ValueError, TypeError):
+                            typed_params[k] = v
+                func = functools.partial(func, **typed_params)
+            reward_funcs.append(func)
+
+        elif spec_type == "inline":
+            code = spec.get("code", "")
+            if not code.strip():
+                raise ValueError(f"Inline reward function '{name}' has no code")
+            func = compile_inline_reward(name, code)
+            reward_funcs.append(func)
+
+        else:
+            raise ValueError(f"Unknown reward function type '{spec_type}'. Use 'builtin' or 'inline'")
+
+    return reward_funcs
diff --git a/backend/python/trl/run.sh b/backend/python/trl/run.sh
new file mode 100644
index 000000000..bd17c6e1d
--- /dev/null
+++ b/backend/python/trl/run.sh
@@ -0,0 +1,10 @@
#!/bin/bash

# Locate the shared backend helper library: prefer a local common/ copy
# bundled with this backend, otherwise fall back to the sibling directory.
# All expansions are quoted so paths containing spaces do not word-split.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# Forward all CLI arguments (e.g. --addr) verbatim; "$@" preserves each
# argument as a separate word, unlike the original unquoted $@.
startBackend "$@"
diff --git a/backend/python/trl/test.py b/backend/python/trl/test.py
new file mode 100644
index 000000000..d77d4e9f0
--- /dev/null
+++ b/backend/python/trl/test.py
@@ -0,0 +1,58 @@
+"""
+Test script for the TRL fine-tuning gRPC backend.
+"""
+import unittest
+import subprocess
+import time
+
+import grpc
+import backend_pb2
+import backend_pb2_grpc
+
+
class TestBackendServicer(unittest.TestCase):
    """Tests for the TRL fine-tuning gRPC service.

    unittest invokes setUp/tearDown automatically around every test method,
    so each test gets its own fresh backend process. The original code also
    called self.setUp()/self.tearDown() manually inside the test bodies,
    which spawned a second backend on the same port and leaked the first
    process; those redundant calls are removed.
    """

    def setUp(self):
        """Start the gRPC backend as a subprocess and give it time to bind."""
        self.service = subprocess.Popen(
            ["python3", "backend.py", "--addr", "localhost:50051"]
        )
        # The backend imports heavy ML libraries; allow time for startup.
        time.sleep(10)

    def tearDown(self):
        """Kill the backend process and reap it to avoid zombies."""
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """The server starts and responds to health checks."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_list_checkpoints_empty(self):
        """Listing checkpoints on a non-existent directory returns nothing."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.ListCheckpoints(
                    backend_pb2.ListCheckpointsRequest(output_dir="/nonexistent")
                )
                self.assertEqual(len(response.checkpoints), 0)
        except Exception as err:
            print(err)
            self.fail("ListCheckpoints service failed")
+
+
# Allow running the suite directly: `python3 test.py` (test.sh uses pytest-style
# discovery via runUnittests, but this keeps standalone invocation working).
if __name__ == '__main__':
    unittest.main()
diff --git a/backend/python/trl/test.sh b/backend/python/trl/test.sh
new file mode 100644
index 000000000..eb59f2aaf
--- /dev/null
+++ b/backend/python/trl/test.sh
@@ -0,0 +1,11 @@
#!/bin/bash
# Abort on the first failing command so test failures propagate to the caller.
set -e

# Locate the shared backend helper library: prefer a local common/ copy
# bundled with this backend, otherwise fall back to the sibling directory.
# All expansions are quoted so paths containing spaces do not word-split.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

runUnittests
diff --git a/core/services/finetune.go b/core/services/finetune.go
index 568922b10..c49577fbf 100644
--- a/core/services/finetune.go
+++ b/core/services/finetune.go
@@ -265,25 +265,16 @@ func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, sav
}
s.mu.Unlock()
- backendModel, err := s.modelLoader.Load(
- model.WithBackendString(job.Backend),
- model.WithModel(job.Backend),
- model.WithModelID(job.Backend+"-finetune"),
- )
+ // Kill the backend process directly — gRPC stop deadlocks on single-threaded Python backends
+ modelID := job.Backend + "-finetune"
+ err := s.modelLoader.ShutdownModel(modelID)
if err != nil {
- return fmt.Errorf("failed to load backend: %w", err)
- }
-
- _, err = backendModel.StopFineTune(ctx, &pb.FineTuneStopRequest{
- JobId: jobID,
- SaveCheckpoint: saveCheckpoint,
- })
- if err != nil {
- return fmt.Errorf("failed to stop job: %w", err)
+ return fmt.Errorf("failed to stop backend: %w", err)
}
s.mu.Lock()
- job.Message = "Stop requested, waiting for training to halt..."
+ job.Status = "stopped"
+ job.Message = "Training stopped by user"
s.saveJobState(job)
s.mu.Unlock()