LocalAI/core/services/finetune/service.go
Richard Palethorpe 670259ce43 chore: Security hardening (#9719)
* fix(http): close 0.0.0.0/[::] SSRF bypass in /api/cors-proxy

The CORS proxy carried its own private-network blocklist (RFC 1918 + a
handful of IPv6 ranges) instead of using the same classification as
pkg/utils/urlfetch.go. The hand-rolled list missed 0.0.0.0/8 and ::/128,
both of which Linux routes to localhost — so any user with FeatureMCP
(default-on for new users) could reach LocalAI's own listener and any
other service bound to 0.0.0.0:port via:

  GET /api/cors-proxy?url=http://0.0.0.0:8080/...
  GET /api/cors-proxy?url=http://[::]:8080/...

Replace the custom check with utils.IsPublicIP (Go stdlib IsLoopback /
IsLinkLocalUnicast / IsPrivate / IsUnspecified, plus IPv4-mapped IPv6
unmasking) and add an upfront hostname rejection for localhost, *.local,
and the cloud metadata aliases so split-horizon DNS can't paper over the
IP check.

The IP-pinning DialContext is unchanged: the validated IP from the
single resolution is reused for the connection, so DNS rebinding still
cannot swap a public answer for a private one between validate and dial.

Regression tests cover 0.0.0.0, 0.0.0.0:PORT, [::], ::ffff:127.0.0.1,
::ffff:10.0.0.1, file://, gopher://, ftp://, localhost, 127.0.0.1,
10.0.0.1, 169.254.169.254, metadata.google.internal.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(downloader): verify SHA before promoting temp file to final path

DownloadFileWithContext renamed the .partial file to its final name
*before* checking the streamed SHA, so a hash mismatch returned an
error but left the tampered file at filePath. Subsequent code that
operated on filePath (a backend launcher, a YAML loader, a re-download
that finds the file already present and skips) would consume the
attacker-supplied bytes.

Reorder: verify the streamed hash first, remove the .partial on
mismatch, then rename. The streamed hash is computed during io.Copy
so no second read is needed.

While here, raise the empty-SHA case from a Debug log to a Warn so
"this download had no integrity check" is visible at the default log
level. Backend installs currently pass through with no digest; the
warning makes that footprint observable without changing behaviour.

Regression test asserts os.IsNotExist on the destination after a
deliberate SHA mismatch.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(auth): require email_verified for OIDC admin promotion

extractOIDCUserInfo read the ID token's "email" claim but never
inspected "email_verified". With LOCALAI_ADMIN_EMAIL set, an attacker
who could register on the configured OIDC IdP under that email (some
IdPs accept self-supplied unverified emails) inherited admin role:

  - first login:  AssignRole(tx, email, adminEmail) → RoleAdmin
  - re-login:     MaybePromote(db, user, adminEmail) → flip to RoleAdmin

Add EmailVerified to oauthUserInfo, parse email_verified from the OIDC
claims (default false on absence so an IdP that omits the claim cannot
short-circuit the gate), and substitute "" for the role-decision email
when verified=false via emailForRoleDecision. The user record still
stores the unverified email for display.
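
A sketch of the gate, assuming the claims arrive as a decoded JSON map.
Both function bodies are illustrative; the signature of the real
emailForRoleDecision is an assumption.

```go
package main

import "fmt"

// emailVerified returns true only when the claim is present and is the
// boolean true. A missing or non-boolean claim counts as unverified.
func emailVerified(claims map[string]any) bool {
	v, ok := claims["email_verified"].(bool)
	return ok && v
}

// emailForRoleDecision (assumed signature) hides an unverified email from
// the role logic so it can never match the configured admin email.
func emailForRoleDecision(email string, verified bool) string {
	if !verified {
		return ""
	}
	return email
}

func main() {
	claims := map[string]any{"email": "admin@example.com"} // no email_verified
	email, _ := claims["email"].(string)
	fmt.Printf("role-decision email: %q\n", emailForRoleDecision(email, emailVerified(claims)))
}
```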

GitHub's path defaults EmailVerified=true: GitHub only returns a public
profile email after verification, and fetchGitHubPrimaryEmail explicitly
filters to Verified=true.

Regression tests cover both the helper contract and integration with
AssignRole, including the bootstrap "first user" branch that would
otherwise mask the gate.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(cli): refuse public bind when no auth backend is configured

When neither an auth DB nor a static API key is set, the auth
middleware passes every request through. That is fine for a developer
laptop, a home LAN, or a Tailnet — the network itself is the trust
boundary. It is not fine on a public IP, where every model install,
settings change, and admin endpoint becomes reachable from the
internet.

Refuse to start in that exact configuration. Loopback, RFC 1918,
RFC 4193 ULA, link-local, and RFC 6598 CGNAT (Tailscale's default
range) all count as trusted; wildcard binds (`:port`, `0.0.0.0`,
`[::]`) are accepted only when every host interface is in one of those
ranges. Hostnames are resolved and treated as trusted only when every
answer is.

A new --allow-insecure-public-bind / LOCALAI_ALLOW_INSECURE_PUBLIC_BIND
flag opts out for deployments that gate access externally (a reverse
proxy enforcing auth, a mesh ACL, etc.). The error message lists this
plus the three constructive alternatives (bind a private interface,
enable --auth, set --api-keys).

The interface enumeration goes through a package-level interfaceAddrsFn
var so tests can simulate cloud-VM, home-LAN, Tailscale-only, and
enumeration-failure topologies without poking at the real network
stack.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* test(http): regression-test the localai_assistant admin gate

ChatEndpoint already rejects metadata.localai_assistant=true from a
non-admin caller, but the gate was open-coded inline with no direct
test coverage. The chat route is FeatureChat-gated (default-on), and
the assistant's in-process MCP server can install/delete models and
edit configs — the wrong handler change would silently turn the LLM
into a confused deputy.

Extract the gate into requireAssistantAccess(c, authEnabled) and pin
its behaviour: auth disabled is a no-op, unauthenticated is 403,
RoleUser is 403, RoleAdmin and the synthetic legacy-key admin are
admitted.

No behaviour change in the production path.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* test(http): assert every API route is auth-classified

The auth middleware classifies path prefixes (/api/, /v1/, /models/,
etc.) as protected and treats anything else as a static-asset
passthrough. A new endpoint shipped under a brand-new prefix — or a
new path that simply isn't on the prefix allowlist — would be
reachable anonymously.

Walk every route registered by API() with auth enabled and a fresh
in-memory database (no users, no keys), and assert each API-prefixed
route returns 401 / 404 / 405 to an anonymous request. Public surfaces
(/api/auth/*, /api/branding, /api/node/* token-authenticated routes,
/healthz, branding asset server, generated-content server, static
assets) are explicit allowlist entries with comments justifying them.

Build-tagged 'auth' so it runs against the SQLite-backed auth DB
(matches the existing auth suite).

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* test(http): pin agent endpoint per-user isolation contract

agents.go's getUserID / effectiveUserID / canImpersonateUser /
wantsAllUsers helpers are the single trust boundary for cross-user
access on agent, agent-jobs, collections, and skills routes. A
regression there is the difference between "regular user reads their
own data" and "regular user reads anyone's data via ?user_id=victim".

Lock in the contract:
  - effectiveUserID ignores ?user_id= for unauthenticated and RoleUser
  - effectiveUserID honours it for RoleAdmin and ProviderAgentWorker
  - wantsAllUsers requires admin AND the literal "true" string
  - canImpersonateUser is admin OR agent-worker, never plain RoleUser

No production change — this commit only adds tests.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(downloader): drop redundant stat in removePartialFile

The stat-then-remove pattern is a TOCTOU window and a wasted syscall —
os.Remove already returns ErrNotExist for the missing-file case, so trust
that and treat it as a no-op.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(http): redact secrets from trace buffer and distribution-token logs

The /api/traces buffer captured Authorization, Cookie, Set-Cookie, and
API-key headers verbatim from every request when tracing was enabled. The
endpoint is admin-only but the buffer is reachable via any heap-style
introspection and the captured tokens otherwise outlive the request.
Strip those header values at capture time. Body redaction is left to a
follow-up — the prompts are usually the operator's own and JSON-walking
is invasive.

Distribution tokens were also logged in plaintext from
core/explorer/discovery.go; logs forward to syslog/journald and outlive
the token. Redact those to a short prefix/suffix instead.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(auth): rate-limit OAuth callbacks separately from password endpoints

The shared 5/min/IP limit on auth endpoints is right for password-style
flows but too tight for OAuth callbacks: corporate SSO funnels many real
users through one outbound IP and would trip the limit. Add a separate
60/min/IP limiter for /api/auth/{github,oidc}/callback so callbacks are
bounded against floods without breaking shared-IP deployments.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(gallery): verify backend tarball sha256 when set in gallery entry

GalleryBackend gained an optional sha256 field; the install path now
threads it through to the existing downloader hash-verify (which already
streams, verifies, and rolls back on mismatch). Galleries without sha256
keep working; the empty-SHA path still emits the existing
"downloading without integrity check" warning.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* test(http): pin CSRF coverage on multipart endpoints

The CSRF middleware in app.go is global (e.Use) so it covers every
multipart upload route — branding assets, fine-tune datasets, audio
transforms, agent collections. Pin that contract: cross-site multipart
POSTs are rejected; same-origin / same-site / API-key clients are not.
Also pins the SameSite=Lax fallback path the skipper relies on when
Sec-Fetch-Site is absent.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(http): XSS hardening — CSP headers, safe href, base-href escape, SVG sandbox

Several closely related XSS-prevention changes spanning the SPA shell, the
React UI, and the branding asset server:

- New SecurityHeaders middleware sets CSP, X-Content-Type-Options,
  X-Frame-Options, and Referrer-Policy on every response. The CSP keeps
  script-src permissive because the Vite bundle relies on inline + eval'd
  scripts; tightening that requires moving to a nonce-based policy.

- The <base href> injection in the SPA shell interpolated attacker-controllable
  Host / X-Forwarded-Host headers without escaping, so a single quote in the
  host header broke out of the attribute. Pass the value through
  SecureBaseHref (html.EscapeString).

- Three React sinks rendering untrusted content via dangerouslySetInnerHTML
  switch to text-node rendering with whiteSpace: pre-wrap: user message
  bodies in Chat.jsx and AgentChat.jsx, and the agent activity log in
  AgentChat.jsx. The hand-rolled escape on the agent user-message variant
  is replaced by the same plain-text path.

- New safeHref util collapses non-allowlisted URI schemes (most
  importantly javascript:) to '#'. Applied to gallery `<a href={url}>`
  links in Models / Backends / Manage and to canvas artifact links —
  these come from gallery JSON or assistant tool calls and must be treated
  as untrusted.

- The branding asset server attaches a sandbox CSP plus same-origin CORP
  to .svg responses. The React UI loads logos via <img>, but the same URL
  is also reachable via direct navigation; this prevents script
  execution if a hostile SVG slipped past upload validation.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(http): bound HTTP server with read-header and idle timeouts

A net/http server with no timeouts is trivially Slowloris-able and leaks
idle keep-alive connections. Set ReadHeaderTimeout (30s) to plug the
slow-headers attack and IdleTimeout (120s) to cap keep-alive sockets.

ReadTimeout and WriteTimeout stay at 0 because request bodies can be
multi-GB model uploads and SSE / chat completions stream for many
minutes; operators who need tighter per-request bounds should terminate
slow clients at a reverse proxy.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* test(auth): pin PUT /api/auth/profile field-tampering contract

The handler uses an explicit local body struct (only name and avatar_url)
plus a gorm Updates(map) with a column allowlist, so an attacker posting
{"role":"admin","email":"...","password_hash":"..."} can't mass-assign
those fields. Lock that down with a regression test so a future
"let's just c.Bind(&user)" refactor breaks loudly.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(services): strip directory components from multipart upload filenames

UploadDataset and UploadToCollectionForUser took the raw multipart
file.Filename and joined it into a destination path. The fine-tune
upload was incidentally safe because of a UUID prefix that fused any
leading '..' to a literal segment, but the protection is fragile.
UploadToCollectionForUser handed the filename to a vendored backend
without sanitising at all.

Strip to filepath.Base at both boundaries and reject the trivial
unsafe values ("", ".", "..", "/").
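
A sketch of that boundary check. safeUploadName is a hypothetical name, and
the behaviour shown is for Unix path separators.

```go
package main

import (
	"fmt"
	"path/filepath"
)

// safeUploadName (hypothetical) strips directory components from a
// client-supplied filename and rejects the trivial unsafe results.
func safeUploadName(name string) (string, error) {
	base := filepath.Base(name)
	switch base {
	case "", ".", "..", "/":
		return "", fmt.Errorf("unsafe upload filename %q", name)
	}
	return base, nil
}

func main() {
	for _, n := range []string{"data.jsonl", "../../etc/passwd", "..", ""} {
		base, err := safeUploadName(n)
		fmt.Printf("%q -> %q, err=%v\n", n, base, err)
	}
}
```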

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(react-ui): validate persisted MCP server entries on load

localStorage is shared across same-origin pages; an XSS that lands once
can poison persisted MCP server config to attempt header injection or
to feed a non-http URL into the fetch path on subsequent loads.
Validate every entry: types must match, URL must parse with http(s)
scheme, header keys/values must be control-char-free. Drop anything
that doesn't fit.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(http): close X-Forwarded-Prefix open redirect

The reverse-proxy support concatenated X-Forwarded-Prefix into the
redirect target without validation, so a forged header value of
"//evil.com" turned the SPA-shell redirect helper at /, /browse, and
/browse/* into a 301 to //evil.com/app. The path-strip middleware had
the same shape on its prefix-trailing-slash redirect.

Add SafeForwardedPrefix at the middleware boundary: must start with
a single '/', no protocol-relative '//' opener, no scheme, no
backslash, no control characters. Apply at both consumers; misconfig
trips the validator and the header is dropped.
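
The listed rules can be sketched as a single predicate. The function below
is illustrative; the real SafeForwardedPrefix may have a different
signature and may normalise the value as well.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// safeForwardedPrefix returns the prefix and true when it is safe to
// concatenate into a redirect target, or "" and false otherwise.
func safeForwardedPrefix(p string) (string, bool) {
	// Must start with exactly one '/', never the protocol-relative "//".
	if p == "" || p[0] != '/' || strings.HasPrefix(p, "//") {
		return "", false
	}
	// No scheme and no backslash (some browsers treat '\' like '/').
	if strings.Contains(p, "://") || strings.ContainsRune(p, '\\') {
		return "", false
	}
	for _, r := range p {
		if unicode.IsControl(r) {
			return "", false
		}
	}
	return p, true
}

func main() {
	for _, p := range []string{"/localai", "//evil.com", "/\\evil.com", "https://evil.com"} {
		_, ok := safeForwardedPrefix(p)
		fmt.Printf("%q ok=%v\n", p, ok)
	}
}
```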

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(http): refuse wildcard CORS when LOCALAI_CORS=true with empty allowlist

When LOCALAI_CORS=true but LOCALAI_CORS_ALLOW_ORIGINS was empty, Echo's
CORSWithConfig saw an empty allow-list and fell back to its default
AllowOrigins=["*"]. An operator who flipped the strict-CORS feature
flag without populating the list got the opposite of what they asked
for. Echo never sets Allow-Credentials: true so this isn't directly
exploitable (cookies aren't sent under wildcard CORS), but the
misconfiguration trap is worth closing. Skip the registration and warn.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(auth): zxcvbn password strength check with user-acknowledged override

The previous policy was len < 8, which let through "Password1" and the
rest of the credential-stuffing corpus. LocalAI has no second factor
yet, so the bar needs to sit higher.

Add ValidatePasswordStrength using github.com/timbutler/zxcvbn (an
actively-maintained fork of the trustelem port; v1.0.4, April 2024):
- min 12 chars, max 72 (bcrypt's truncation point)
- reject NUL bytes (some bcrypt callers truncate at the first NUL)
- require zxcvbn score >= 3 ("safely unguessable, ~10^8 guesses to
  break"); the hint list ["localai", "local-ai", "admin"] penalises
  passwords built from the app's own branding

zxcvbn produces false positives sometimes (a strong-looking password
that happens to match a dictionary word) and operators occasionally
need to set a known-weak password (kiosk demos, CI rigs). Add an
acknowledgement path: PasswordPolicy{AllowWeak: true} skips the
entropy check while still enforcing the hard rules. The structured
PasswordErrorResponse marks weak-password rejections as Overridable
so the UI can surface a "use this anyway" checkbox.

Wired through register, self-service password change, and admin
password reset on both the server and the React UI.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(react-ui): drop HTML5 minLength on new-password inputs

minLength={12} on the new-password input let the browser block the
form submit silently before any JS or network call ran. The browser
focused the field, showed a brief native tooltip, and that was that —
no toast, no fetch, no clue. Reproducible by typing fewer than 12
chars on the second password change of a session.

The JS-level length check in handleSubmit already shows a toast and
the server rejects with a structured error, so the HTML5 attribute
was redundant defence anyway. Drop it.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(react-ui): bundle Geist fonts locally instead of fetching from Google

The new CSP correctly refused to apply styles from
fonts.googleapis.com because style-src is locked to 'self' and
'unsafe-inline'. Loosening the CSP would defeat its purpose; the
right fix is to stop reaching out to a third-party CDN for fonts on
every page load.

Add @fontsource-variable/geist and @fontsource-variable/geist-mono as
npm deps and import them once at boot. Drop the <link rel="preconnect">
and external stylesheet from index.html.

Side benefit: no third-party tracking via Referer / IP on every UI
load, no failure mode when offline / behind a captive portal.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* fix(react-ui): refresh i18n strings to reflect 12-char password minimum

The translations still said "at least 8 characters" everywhere — the
client-side toast on a too-short password change told the user the
wrong floor. Update tooShort and newPasswordPlaceholder /
newPasswordDescription across all five locales (en, es, it, de,
zh-CN) to match the real ValidatePasswordStrength rule.

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(auth): make password length-floor overridable like the entropy check

The 12-char minimum was a policy choice, not a technical invariant —
only "non-empty", "<= 72 bytes", and "no NUL bytes" are real bcrypt
constraints. Treating length-12 as a hard rule was inconsistent with
the entropy check (already overridable) and friction for use cases
where the account is just a name on a session, not a security
boundary (single-user kiosk, CI rig, lab demo).

Restructure ValidatePasswordStrength:
- Hard rules (always enforced): non-empty, <= MaxPasswordLength, no NUL byte
- Policy rules (skipped when AllowWeak=true): length >= 12, zxcvbn score >= 3

PasswordError now marks password_too_short as Overridable too. The
React forms generalised from `error_code === 'password_too_weak'` to
`overridable === true`, and the JS-side preflight length checks were
removed (server is source of truth, returns the same checkbox flow).

Assisted-by: Claude:claude-opus-4-7 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-08 16:25:45 +02:00

762 lines
22 KiB
Go

package finetune

import (
	"cmp"
	"context"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"regexp"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery/importers"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/services/distributed"
	"github.com/mudler/LocalAI/core/services/messaging"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/utils"
	"github.com/mudler/xlog"
	"gopkg.in/yaml.v3"
)
// FineTuneService manages fine-tuning jobs and their lifecycle.
type FineTuneService struct {
	appConfig    *config.ApplicationConfig
	modelLoader  *model.ModelLoader
	configLoader *config.ModelConfigLoader

	mu   sync.Mutex
	jobs map[string]*schema.FineTuneJob

	// Distributed mode (nil when not in distributed mode)
	natsClient    messaging.Publisher
	fineTuneStore *distributed.FineTuneStore
}

// SetNATSClient sets the NATS client for distributed progress publishing.
func (s *FineTuneService) SetNATSClient(nc messaging.Publisher) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.natsClient = nc
}

// SetFineTuneStore sets the PostgreSQL fine-tune store for distributed persistence.
func (s *FineTuneService) SetFineTuneStore(store *distributed.FineTuneStore) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.fineTuneStore = store
}

// NewFineTuneService creates a new FineTuneService.
func NewFineTuneService(
	appConfig *config.ApplicationConfig,
	modelLoader *model.ModelLoader,
	configLoader *config.ModelConfigLoader,
) *FineTuneService {
	s := &FineTuneService{
		appConfig:    appConfig,
		modelLoader:  modelLoader,
		configLoader: configLoader,
		jobs:         make(map[string]*schema.FineTuneJob),
	}
	s.loadAllJobs()
	return s
}

// fineTuneBaseDir returns the base directory for fine-tune job data.
func (s *FineTuneService) fineTuneBaseDir() string {
	return filepath.Join(s.appConfig.DataPath, "fine-tune")
}

// jobDir returns the directory for a specific job.
func (s *FineTuneService) jobDir(jobID string) string {
	return filepath.Join(s.fineTuneBaseDir(), jobID)
}
// saveJobState persists a job's state to disk as state.json.
func (s *FineTuneService) saveJobState(job *schema.FineTuneJob) {
	dir := s.jobDir(job.ID)
	if err := os.MkdirAll(dir, 0750); err != nil {
		xlog.Error("Failed to create job directory", "job_id", job.ID, "error", err)
		return
	}

	data, err := json.MarshalIndent(job, "", " ")
	if err != nil {
		xlog.Error("Failed to marshal job state", "job_id", job.ID, "error", err)
		return
	}

	statePath := filepath.Join(dir, "state.json")
	if err := os.WriteFile(statePath, data, 0640); err != nil {
		xlog.Error("Failed to write job state", "job_id", job.ID, "error", err)
	}
}

// loadAllJobs scans the fine-tune directory for persisted jobs and loads them.
func (s *FineTuneService) loadAllJobs() {
	baseDir := s.fineTuneBaseDir()
	entries, err := os.ReadDir(baseDir)
	if err != nil {
		// Directory doesn't exist yet — that's fine
		return
	}

	for _, entry := range entries {
		if !entry.IsDir() {
			continue
		}
		statePath := filepath.Join(baseDir, entry.Name(), "state.json")
		data, err := os.ReadFile(statePath)
		if err != nil {
			continue
		}

		var job schema.FineTuneJob
		if err := json.Unmarshal(data, &job); err != nil {
			xlog.Warn("Failed to parse job state", "path", statePath, "error", err)
			continue
		}

		// Jobs that were running when we shut down are now stale
		if job.Status == "queued" || job.Status == "loading_model" || job.Status == "loading_dataset" || job.Status == "training" || job.Status == "saving" {
			job.Status = "stopped"
			job.Message = "Server restarted while job was running"
		}

		// Exports that were in progress are now stale
		if job.ExportStatus == "exporting" {
			job.ExportStatus = "failed"
			job.ExportMessage = "Server restarted while export was running"
		}

		s.jobs[job.ID] = &job
	}

	if len(s.jobs) > 0 {
		xlog.Info("Loaded persisted fine-tune jobs", "count", len(s.jobs))
	}
}
// StartJob starts a new fine-tuning job.
func (s *FineTuneService) StartJob(ctx context.Context, userID string, req schema.FineTuneJobRequest) (*schema.FineTuneJobResponse, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	jobID := uuid.New().String()

	backendName := req.Backend
	if backendName == "" {
		backendName = "trl"
	}

	// Always use DataPath for output — not user-configurable
	outputDir := filepath.Join(s.fineTuneBaseDir(), jobID)

	// Build gRPC request
	grpcReq := &pb.FineTuneRequest{
		Model:                     req.Model,
		TrainingType:              req.TrainingType,
		TrainingMethod:            req.TrainingMethod,
		AdapterRank:               req.AdapterRank,
		AdapterAlpha:              req.AdapterAlpha,
		AdapterDropout:            req.AdapterDropout,
		TargetModules:             req.TargetModules,
		LearningRate:              req.LearningRate,
		NumEpochs:                 req.NumEpochs,
		BatchSize:                 req.BatchSize,
		GradientAccumulationSteps: req.GradientAccumulationSteps,
		WarmupSteps:               req.WarmupSteps,
		MaxSteps:                  req.MaxSteps,
		SaveSteps:                 req.SaveSteps,
		WeightDecay:               req.WeightDecay,
		GradientCheckpointing:     req.GradientCheckpointing,
		Optimizer:                 req.Optimizer,
		Seed:                      req.Seed,
		MixedPrecision:            req.MixedPrecision,
		DatasetSource:             req.DatasetSource,
		DatasetSplit:              req.DatasetSplit,
		OutputDir:                 outputDir,
		JobId:                     jobID,
		ResumeFromCheckpoint:      req.ResumeFromCheckpoint,
		ExtraOptions:              req.ExtraOptions,
	}

	// Serialize reward functions into extra_options for the backend
	if len(req.RewardFunctions) > 0 {
		rfJSON, err := json.Marshal(req.RewardFunctions)
		if err != nil {
			return nil, fmt.Errorf("failed to serialize reward functions: %w", err)
		}
		if grpcReq.ExtraOptions == nil {
			grpcReq.ExtraOptions = make(map[string]string)
		}
		grpcReq.ExtraOptions["reward_funcs"] = string(rfJSON)
	}

	// Load the fine-tuning backend (per-job model ID so multiple jobs can run concurrently)
	modelID := backendName + "-finetune-" + jobID
	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(backendName),
		model.WithModel(backendName),
		model.WithModelID(modelID),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to load backend %s: %w", backendName, err)
	}

	// Start fine-tuning via gRPC
	result, err := backendModel.StartFineTune(ctx, grpcReq)
	if err != nil {
		return nil, fmt.Errorf("failed to start fine-tuning: %w", err)
	}
	if !result.Success {
		return nil, fmt.Errorf("fine-tuning failed to start: %s", result.Message)
	}

	// Track the job
	job := &schema.FineTuneJob{
		ID:             jobID,
		UserID:         userID,
		Model:          req.Model,
		Backend:        backendName,
		ModelID:        modelID,
		TrainingType:   req.TrainingType,
		TrainingMethod: req.TrainingMethod,
		Status:         "queued",
		OutputDir:      outputDir,
		ExtraOptions:   req.ExtraOptions,
		CreatedAt:      time.Now().UTC().Format(time.RFC3339),
		Config:         &req,
	}
	s.jobs[jobID] = job
	s.saveJobState(job)

	// Persist to PostgreSQL in distributed mode
	if s.fineTuneStore != nil {
		configJSON, _ := json.Marshal(req)
		extraJSON, _ := json.Marshal(req.ExtraOptions)
		s.fineTuneStore.Create(&distributed.FineTuneJobRecord{
			ID:             jobID,
			UserID:         userID,
			Model:          req.Model,
			Backend:        backendName,
			ModelID:        modelID,
			TrainingType:   req.TrainingType,
			TrainingMethod: req.TrainingMethod,
			Status:         "queued",
			OutputDir:      outputDir,
			ConfigJSON:     string(configJSON),
			ExtraOptsJSON:  string(extraJSON),
		})
	}

	return &schema.FineTuneJobResponse{
		ID:      jobID,
		Status:  "queued",
		Message: result.Message,
	}, nil
}
// GetJob returns a fine-tuning job by ID.
func (s *FineTuneService) GetJob(userID, jobID string) (*schema.FineTuneJob, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	job, ok := s.jobs[jobID]
	if !ok {
		return nil, fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		return nil, fmt.Errorf("job not found: %s", jobID)
	}
	return job, nil
}

// ListJobs returns all jobs for a user, sorted by creation time (newest first).
func (s *FineTuneService) ListJobs(userID string) []*schema.FineTuneJob {
	s.mu.Lock()
	defer s.mu.Unlock()

	var result []*schema.FineTuneJob
	for _, job := range s.jobs {
		if userID == "" || job.UserID == userID {
			result = append(result, job)
		}
	}

	slices.SortFunc(result, func(a, b *schema.FineTuneJob) int {
		return cmp.Compare(b.CreatedAt, a.CreatedAt)
	})

	return result
}

// StopJob stops a running fine-tuning job.
func (s *FineTuneService) StopJob(ctx context.Context, userID, jobID string, saveCheckpoint bool) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	s.mu.Unlock()

	// Kill the backend process directly
	stopModelID := job.ModelID
	if stopModelID == "" {
		stopModelID = job.Backend + "-finetune"
	}
	s.modelLoader.ShutdownModel(stopModelID)

	s.mu.Lock()
	job.Status = "stopped"
	job.Message = "Training stopped by user"
	s.saveJobState(job)
	if s.fineTuneStore != nil {
		s.fineTuneStore.UpdateStatus(jobID, "stopped", "Training stopped by user")
	}
	s.mu.Unlock()

	return nil
}
// DeleteJob removes a fine-tuning job and its associated data from disk.
func (s *FineTuneService) DeleteJob(userID, jobID string) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}

	// Reject deletion of actively running jobs
	activeStatuses := map[string]bool{
		"queued": true, "loading_model": true, "loading_dataset": true,
		"training": true, "saving": true,
	}
	if activeStatuses[job.Status] {
		s.mu.Unlock()
		return fmt.Errorf("cannot delete job %s: currently %s (stop it first)", jobID, job.Status)
	}
	if job.ExportStatus == "exporting" {
		s.mu.Unlock()
		return fmt.Errorf("cannot delete job %s: export in progress", jobID)
	}

	exportModelName := job.ExportModelName
	delete(s.jobs, jobID)
	if s.fineTuneStore != nil {
		s.fineTuneStore.Delete(jobID)
	}
	s.mu.Unlock()

	// Remove job directory (state.json, checkpoints, output)
	jobDir := s.jobDir(jobID)
	if err := os.RemoveAll(jobDir); err != nil {
		xlog.Warn("Failed to remove job directory", "job_id", jobID, "path", jobDir, "error", err)
	}

	// If an exported model exists, clean it up too
	if exportModelName != "" {
		modelsPath := s.appConfig.SystemState.Model.ModelsPath
		modelDir := filepath.Join(modelsPath, exportModelName)
		configPath := filepath.Join(modelsPath, exportModelName+".yaml")

		if err := os.RemoveAll(modelDir); err != nil {
			xlog.Warn("Failed to remove exported model directory", "path", modelDir, "error", err)
		}
		if err := os.Remove(configPath); err != nil && !os.IsNotExist(err) {
			xlog.Warn("Failed to remove exported model config", "path", configPath, "error", err)
		}

		// Reload model configs
		if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
			xlog.Warn("Failed to reload configs after delete", "error", err)
		}
	}

	xlog.Info("Deleted fine-tune job", "job_id", jobID)
	return nil
}
// StreamProgress opens a gRPC progress stream and calls the callback for each update.
func (s *FineTuneService) StreamProgress(ctx context.Context, userID, jobID string, callback func(event *schema.FineTuneProgressEvent)) error {
	s.mu.Lock()
	job, ok := s.jobs[jobID]
	if !ok {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	if userID != "" && job.UserID != userID {
		s.mu.Unlock()
		return fmt.Errorf("job not found: %s", jobID)
	}
	s.mu.Unlock()

	streamModelID := job.ModelID
	if streamModelID == "" {
		streamModelID = job.Backend + "-finetune"
	}
	backendModel, err := s.modelLoader.Load(
		model.WithBackendString(job.Backend),
		model.WithModel(job.Backend),
		model.WithModelID(streamModelID),
	)
	if err != nil {
		return fmt.Errorf("failed to load backend: %w", err)
	}

	return backendModel.FineTuneProgress(ctx, &pb.FineTuneProgressRequest{
		JobId: jobID,
	}, func(update *pb.FineTuneProgressUpdate) {
		// Update job status and persist
		s.mu.Lock()
		if j, ok := s.jobs[jobID]; ok {
			// Don't let progress updates overwrite terminal states
			isTerminal := j.Status == "stopped" || j.Status == "completed" || j.Status == "failed"
			if !isTerminal {
				j.Status = update.Status
			}
			if update.Message != "" {
				j.Message = update.Message
			}
			s.saveJobState(j)
			if s.fineTuneStore != nil {
				s.fineTuneStore.UpdateStatus(jobID, j.Status, j.Message)
			}
		}
		s.mu.Unlock()

		// Convert extra metrics
		extraMetrics := make(map[string]float32)
		for k, v := range update.ExtraMetrics {
			extraMetrics[k] = v
		}

		event := &schema.FineTuneProgressEvent{
			JobID:           update.JobId,
			CurrentStep:     update.CurrentStep,
			TotalSteps:      update.TotalSteps,
			CurrentEpoch:    update.CurrentEpoch,
			TotalEpochs:     update.TotalEpochs,
			Loss:            update.Loss,
			LearningRate:    update.LearningRate,
			GradNorm:        update.GradNorm,
			EvalLoss:        update.EvalLoss,
			EtaSeconds:      update.EtaSeconds,
			ProgressPercent: update.ProgressPercent,
			Status:          update.Status,
			Message:         update.Message,
			CheckpointPath:  update.CheckpointPath,
			SamplePath:      update.SamplePath,
			ExtraMetrics:    extraMetrics,
		}
		callback(event)
	})
}
// ListCheckpoints returns the checkpoints the backend reports for the job's
// output directory. A job owned by a different user is reported as not found.
func (s *FineTuneService) ListCheckpoints(ctx context.Context, userID, jobID string) ([]*pb.CheckpointInfo, error) {
s.mu.Lock()
job, ok := s.jobs[jobID]
if !ok || (userID != "" && job.UserID != userID) {
s.mu.Unlock()
return nil, fmt.Errorf("job not found: %s", jobID)
}
// Copy the fields needed below while the lock is held, so they are not read
// concurrently with code that updates the job under s.mu.
backend := job.Backend
ckptModelID := job.ModelID
outputDir := job.OutputDir
s.mu.Unlock()
if ckptModelID == "" {
ckptModelID = backend + "-finetune"
}
backendModel, err := s.modelLoader.Load(
model.WithBackendString(backend),
model.WithModel(backend),
model.WithModelID(ckptModelID),
)
if err != nil {
return nil, fmt.Errorf("failed to load backend: %w", err)
}
resp, err := backendModel.ListCheckpoints(ctx, &pb.ListCheckpointsRequest{
OutputDir: outputDir,
})
if err != nil {
return nil, fmt.Errorf("failed to list checkpoints: %w", err)
}
return resp.Checkpoints, nil
}
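// Patterns used by sanitizeModelName, compiled once at package init instead of
// on every call.
var (
modelNameInvalidChars = regexp.MustCompile(`[^a-zA-Z0-9\-]`)
modelNameHyphenRuns = regexp.MustCompile(`-+`)
)
// sanitizeModelName replaces every character other than ASCII letters, digits,
// and hyphens with a hyphen, collapses runs of hyphens, trims leading and
// trailing hyphens, and lowercases the result.
// For example, "Org/My_Model v2" becomes "org-my-model-v2".
func sanitizeModelName(s string) string {
s = modelNameInvalidChars.ReplaceAllString(s, "-")
s = modelNameHyphenRuns.ReplaceAllString(s, "-")
s = strings.Trim(s, "-")
return strings.ToLower(s)
}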
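// ExportModel validates the request, then starts an asynchronous model export
// from a checkpoint and returns the intended model name immediately. Progress
// and failures of the background export are recorded in the job's ExportStatus
// and ExportMessage fields.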
func (s *FineTuneService) ExportModel(ctx context.Context, userID, jobID string, req schema.ExportRequest) (string, error) {
s.mu.Lock()
job, ok := s.jobs[jobID]
if !ok {
s.mu.Unlock()
return "", fmt.Errorf("job not found: %s", jobID)
}
if userID != "" && job.UserID != userID {
s.mu.Unlock()
return "", fmt.Errorf("job not found: %s", jobID)
}
if job.ExportStatus == "exporting" {
s.mu.Unlock()
return "", fmt.Errorf("export already in progress for job %s", jobID)
}
s.mu.Unlock()
// Compute model name
modelName := req.Name
if modelName == "" {
base := sanitizeModelName(job.Model)
if base == "" {
base = "model"
}
shortID := jobID
if len(shortID) > 8 {
shortID = shortID[:8]
}
modelName = base + "-ft-" + shortID
}
// Compute output path in models directory
modelsPath := s.appConfig.SystemState.Model.ModelsPath
outputPath := filepath.Join(modelsPath, modelName)
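// Validate the name synchronously so the caller gets naming errors right away:
// the config path must pass utils.VerifyPath against modelsPath, and no config
// with this name may already exist.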
configPath := filepath.Join(modelsPath, modelName+".yaml")
if err := utils.VerifyPath(modelName+".yaml", modelsPath); err != nil {
return "", fmt.Errorf("invalid model name: %w", err)
}
if _, err := os.Stat(configPath); err == nil {
return "", fmt.Errorf("model %q already exists, choose a different name", modelName)
}
// Create output directory
if err := os.MkdirAll(outputPath, 0750); err != nil {
return "", fmt.Errorf("failed to create output directory: %w", err)
}
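// Set export status to "exporting" and persist. Re-check the status first: the
// lock was released during validation, so a concurrent ExportModel call for the
// same job may have started an export in the meantime.
s.mu.Lock()
if job.ExportStatus == "exporting" {
s.mu.Unlock()
return "", fmt.Errorf("export already in progress for job %s", jobID)
}
job.ExportStatus = "exporting"
job.ExportMessage = ""
job.ExportModelName = ""
s.saveJobState(job)
s.mu.Unlock()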
// Launch the export in a background goroutine
go func() {
s.setExportMessage(job, "Loading export backend...")
exportModelID := job.ModelID
if exportModelID == "" {
exportModelID = job.Backend + "-finetune"
}
backendModel, err := s.modelLoader.Load(
model.WithBackendString(job.Backend),
model.WithModel(job.Backend),
model.WithModelID(exportModelID),
)
if err != nil {
s.setExportFailed(job, fmt.Sprintf("failed to load backend: %v", err))
return
}
// Merge job's extra_options (contains hf_token from training) with request's
mergedOpts := make(map[string]string)
for k, v := range job.ExtraOptions {
mergedOpts[k] = v
}
for k, v := range req.ExtraOptions {
mergedOpts[k] = v // request overrides job
}
grpcReq := &pb.ExportModelRequest{
CheckpointPath: req.CheckpointPath,
OutputPath: outputPath,
ExportFormat: req.ExportFormat,
QuantizationMethod: req.QuantizationMethod,
Model: req.Model,
ExtraOptions: mergedOpts,
}
s.setExportMessage(job, "Running model export (merging and converting — this may take a while)...")
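// Use a background context rather than ctx: the export outlives the call that
// started it, and ctx is typically cancelled once that call returns.
result, err := backendModel.ExportModel(context.Background(), grpcReq)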
if err != nil {
s.setExportFailed(job, fmt.Sprintf("export failed: %v", err))
return
}
if !result.Success {
s.setExportFailed(job, fmt.Sprintf("export failed: %s", result.Message))
return
}
s.setExportMessage(job, "Export complete, generating model configuration...")
// Auto-import: detect format and generate config
cfg, err := importers.ImportLocalPath(outputPath, modelName)
if err != nil {
s.setExportFailed(job, fmt.Sprintf("model exported to %s but config generation failed: %v", outputPath, err))
return
}
cfg.Name = modelName
// If base model not detected from files, use the job's model field
if cfg.Model == "" && job.Model != "" {
cfg.Model = job.Model
}
// Write YAML config
yamlData, err := yaml.Marshal(cfg)
if err != nil {
s.setExportFailed(job, fmt.Sprintf("failed to marshal config: %v", err))
return
}
if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
s.setExportFailed(job, fmt.Sprintf("failed to write config file: %v", err))
return
}
s.setExportMessage(job, "Registering model with LocalAI...")
// Reload configs so the model is immediately available
if err := s.configLoader.LoadModelConfigsFromPath(modelsPath, s.appConfig.ToConfigLoaderOptions()...); err != nil {
xlog.Warn("Failed to reload configs after export", "error", err)
}
if err := s.configLoader.Preload(modelsPath); err != nil {
xlog.Warn("Failed to preload after export", "error", err)
}
xlog.Info("Model exported and registered", "job_id", jobID, "model_name", modelName, "format", req.ExportFormat)
s.mu.Lock()
job.ExportStatus = "completed"
job.ExportModelName = modelName
job.ExportMessage = ""
s.saveJobState(job)
if s.fineTuneStore != nil {
s.fineTuneStore.UpdateExportStatus(jobID, "completed", "", modelName)
}
s.mu.Unlock()
}()
return modelName, nil
}
// setExportMessage updates the export message and persists the job state.
func (s *FineTuneService) setExportMessage(job *schema.FineTuneJob, msg string) {
s.mu.Lock()
job.ExportMessage = msg
s.saveJobState(job)
s.mu.Unlock()
}
// GetExportedModelPath returns the path to the exported model directory and its name.
func (s *FineTuneService) GetExportedModelPath(userID, jobID string) (string, string, error) {
s.mu.Lock()
job, ok := s.jobs[jobID]
if !ok {
s.mu.Unlock()
return "", "", fmt.Errorf("job not found: %s", jobID)
}
if userID != "" && job.UserID != userID {
s.mu.Unlock()
return "", "", fmt.Errorf("job not found: %s", jobID)
}
if job.ExportStatus != "completed" {
s.mu.Unlock()
return "", "", fmt.Errorf("export not completed for job %s (status: %s)", jobID, job.ExportStatus)
}
exportModelName := job.ExportModelName
s.mu.Unlock()
if exportModelName == "" {
return "", "", fmt.Errorf("no exported model name for job %s", jobID)
}
modelsPath := s.appConfig.SystemState.Model.ModelsPath
modelDir := filepath.Join(modelsPath, exportModelName)
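if _, err := os.Stat(modelDir); err != nil {
if os.IsNotExist(err) {
return "", "", fmt.Errorf("exported model directory not found: %s", modelDir)
}
// Surface other stat failures (for example permission errors) instead of
// returning a path that cannot be read.
return "", "", fmt.Errorf("failed to stat exported model directory %s: %w", modelDir, err)
}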
return modelDir, exportModelName, nil
}
// setExportFailed sets the export status to failed with a message.
func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message string) {
xlog.Error("Export failed", "job_id", job.ID, "error", message)
s.mu.Lock()
job.ExportStatus = "failed"
job.ExportMessage = message
s.saveJobState(job)
if s.fineTuneStore != nil {
s.fineTuneStore.UpdateExportStatus(job.ID, "failed", message, "")
}
s.mu.Unlock()
}
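// UploadDataset writes an uploaded dataset file into the fine-tune datasets
// directory and returns its local path. The filename comes straight from a
// multipart upload, so only its base name is used; keeping any directory
// components risks a write under a sibling path. A short random prefix is
// prepended to the stored name.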
func (s *FineTuneService) UploadDataset(filename string, data []byte) (string, error) {
uploadDir := filepath.Join(s.fineTuneBaseDir(), "datasets")
if err := os.MkdirAll(uploadDir, 0750); err != nil {
return "", fmt.Errorf("failed to create dataset directory: %w", err)
}
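base := filepath.Base(filename)
// filepath.Base returns "." for an empty path and a single separator for a
// path made only of separators; the separator differs by platform.
if base == "." || base == ".." || base == string(filepath.Separator) || base == "" {
return "", fmt.Errorf("invalid filename: %q", filename)
}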
filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+base)
if err := os.WriteFile(filePath, data, 0640); err != nil {
return "", fmt.Errorf("failed to write dataset: %w", err)
}
return filePath, nil
}
// MarshalProgressEvent converts a progress event to JSON for SSE.
func MarshalProgressEvent(event *schema.FineTuneProgressEvent) (string, error) {
data, err := json.Marshal(event)
if err != nil {
return "", fmt.Errorf("failed to marshal progress event: %w", err)
}
return string(data), nil
}