From 670259ce4342ae7ed10db60db27d6c2e79c99423 Mon Sep 17 00:00:00 2001 From: Richard Palethorpe Date: Fri, 8 May 2026 15:25:45 +0100 Subject: [PATCH] chore: Security hardening (#9719) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(http): close 0.0.0.0/[::] SSRF bypass in /api/cors-proxy The CORS proxy carried its own private-network blocklist (RFC 1918 + a handful of IPv6 ranges) instead of using the same classification as pkg/utils/urlfetch.go. The hand-rolled list missed 0.0.0.0/8 and ::/128, both of which Linux routes to localhost — so any user with FeatureMCP (default-on for new users) could reach LocalAI's own listener and any other service bound to 0.0.0.0:port via: GET /api/cors-proxy?url=http://0.0.0.0:8080/... GET /api/cors-proxy?url=http://[::]:8080/... Replace the custom check with utils.IsPublicIP (Go stdlib IsLoopback / IsLinkLocalUnicast / IsPrivate / IsUnspecified, plus IPv4-mapped IPv6 unmasking) and add an upfront hostname rejection for localhost, *.local, and the cloud metadata aliases so split-horizon DNS can't paper over the IP check. The IP-pinning DialContext is unchanged: the validated IP from the single resolution is reused for the connection, so DNS rebinding still cannot swap a public answer for a private one between validate and dial. Regression tests cover 0.0.0.0, 0.0.0.0:PORT, [::], ::ffff:127.0.0.1, ::ffff:10.0.0.1, file://, gopher://, ftp://, localhost, 127.0.0.1, 10.0.0.1, 169.254.169.254, metadata.google.internal. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(downloader): verify SHA before promoting temp file to final path DownloadFileWithContext renamed the .partial file to its final name *before* checking the streamed SHA, so a hash mismatch returned an error but left the tampered file at filePath. Subsequent code that operated on filePath (a backend launcher, a YAML loader, a re-download that finds the file already present and skips) would consume the attacker-supplied bytes. Reorder: verify the streamed hash first, remove the .partial on mismatch, then rename. The streamed hash is computed during io.Copy so no second read is needed. While here, raise the empty-SHA case from a Debug log to a Warn so "this download had no integrity check" is visible at the default log level. Backend installs currently pass through with no digest; the warning makes that footprint observable without changing behaviour. Regression test asserts os.IsNotExist on the destination after a deliberate SHA mismatch. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(auth): require email_verified for OIDC admin promotion extractOIDCUserInfo read the ID token's "email" claim but never inspected "email_verified". With LOCALAI_ADMIN_EMAIL set, an attacker who could register on the configured OIDC IdP under that email (some IdPs accept self-supplied unverified emails) inherited admin role: - first login: AssignRole(tx, email, adminEmail) → RoleAdmin - re-login: MaybePromote(db, user, adminEmail) → flip to RoleAdmin Add EmailVerified to oauthUserInfo, parse email_verified from the OIDC claims (default false on absence so an IdP that omits the claim cannot short-circuit the gate), and substitute "" for the role-decision email when verified=false via emailForRoleDecision. The user record still stores the unverified email for display. GitHub's path defaults EmailVerified=true: GitHub only returns a public profile email after verification, and fetchGitHubPrimaryEmail explicitly filters to Verified=true. Regression tests cover both the helper contract and integration with AssignRole, including the bootstrap "first user" branch that would otherwise mask the gate. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * feat(cli): refuse public bind when no auth backend is configured When neither an auth DB nor a static API key is set, the auth middleware passes every request through. That is fine for a developer laptop, a home LAN, or a Tailnet — the network itself is the trust boundary. It is not fine on a public IP, where every model install, settings change, and admin endpoint becomes reachable from the internet. Refuse to start in that exact configuration. Loopback, RFC 1918, RFC 4193 ULA, link-local, and RFC 6598 CGNAT (Tailscale's default range) all count as trusted; wildcard binds (`:port`, `0.0.0.0`, `[::]`) are accepted only when every host interface is in one of those ranges. Hostnames are resolved and treated as trusted only when every answer is. A new --allow-insecure-public-bind / LOCALAI_ALLOW_INSECURE_PUBLIC_BIND flag opts out for deployments that gate access externally (a reverse proxy enforcing auth, a mesh ACL, etc.). The error message lists this plus the three constructive alternatives (bind a private interface, enable --auth, set --api-keys). The interface enumeration goes through a package-level interfaceAddrsFn var so tests can simulate cloud-VM, home-LAN, Tailscale-only, and enumeration-failure topologies without poking at the real network stack. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * test(http): regression-test the localai_assistant admin gate ChatEndpoint already rejects metadata.localai_assistant=true from a non-admin caller, but the gate was open-coded inline with no direct test coverage. The chat route is FeatureChat-gated (default-on), and the assistant's in-process MCP server can install/delete models and edit configs — the wrong handler change would silently turn the LLM into a confused deputy. Extract the gate into requireAssistantAccess(c, authEnabled) and pin its behaviour: auth disabled is a no-op, unauthenticated is 403, RoleUser is 403, RoleAdmin and the synthetic legacy-key admin are admitted. No behaviour change in the production path. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * test(http): assert every API route is auth-classified The auth middleware classifies path prefixes (/api/, /v1/, /models/, etc.) as protected and treats anything else as a static-asset passthrough. A new endpoint shipped under a brand-new prefix — or a new path that simply isn't on the prefix allowlist — would be reachable anonymously. Walk every route registered by API() with auth enabled and a fresh in-memory database (no users, no keys), and assert each API-prefixed route returns 401 / 404 / 405 to an anonymous request. Public surfaces (/api/auth/*, /api/branding, /api/node/* token-authenticated routes, /healthz, branding asset server, generated-content server, static assets) are explicit allowlist entries with comments justifying them. Build-tagged 'auth' so it runs against the SQLite-backed auth DB (matches the existing auth suite). Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * test(http): pin agent endpoint per-user isolation contract agents.go's getUserID / effectiveUserID / canImpersonateUser / wantsAllUsers helpers are the single trust boundary for cross-user access on agent, agent-jobs, collections, and skills routes. A regression there is the difference between "regular user reads their own data" and "regular user reads anyone's data via ?user_id=victim". Lock in the contract: - effectiveUserID ignores ?user_id= for unauthenticated and RoleUser - effectiveUserID honours it for RoleAdmin and ProviderAgentWorker - wantsAllUsers requires admin AND the literal "true" string - canImpersonateUser is admin OR agent-worker, never plain RoleUser No production change — this commit only adds tests. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(downloader): drop redundant stat in removePartialFile The stat-then-remove pattern is a TOCTOU window and a wasted syscall — os.Remove already returns ErrNotExist for the missing-file case, so trust that and treat it as a no-op. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(http): redact secrets from trace buffer and distribution-token logs The /api/traces buffer captured Authorization, Cookie, Set-Cookie, and API-key headers verbatim from every request when tracing was enabled. The endpoint is admin-only but the buffer is reachable via any heap-style introspection and the captured tokens otherwise outlive the request. Strip those header values at capture time. Body redaction is left to a follow-up — the prompts are usually the operator's own and JSON-walking is invasive. Distribution tokens were also logged in plaintext from core/explorer/discovery.go; logs forward to syslog/journald and outlive the token. Redact those to a short prefix/suffix instead. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * feat(auth): rate-limit OAuth callbacks separately from password endpoints The shared 5/min/IP limit on auth endpoints is right for password-style flows but too tight for OAuth callbacks: corporate SSO funnels many real users through one outbound IP and would trip the limit. Add a separate 60/min/IP limiter for /api/auth/{github,oidc}/callback so callbacks are bounded against floods without breaking shared-IP deployments. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * feat(gallery): verify backend tarball sha256 when set in gallery entry GalleryBackend gained an optional sha256 field; the install path now threads it through to the existing downloader hash-verify (which already streams, verifies, and rolls back on mismatch). Galleries without sha256 keep working; the empty-SHA path still emits the existing "downloading without integrity check" warning. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * test(http): pin CSRF coverage on multipart endpoints The CSRF middleware in app.go is global (e.Use) so it covers every multipart upload route — branding assets, fine-tune datasets, audio transforms, agent collections. Pin that contract: cross-site multipart POSTs are rejected; same-origin / same-site / API-key clients are not. Also pins the SameSite=Lax fallback path the skipper relies on when Sec-Fetch-Site is absent. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * feat(http): XSS hardening — CSP headers, safe href, base-href escape, SVG sandbox Several closely related XSS-prevention changes spanning the SPA shell, the React UI, and the branding asset server: - New SecurityHeaders middleware sets CSP, X-Content-Type-Options, X-Frame-Options, and Referrer-Policy on every response. The CSP keeps script-src permissive because the Vite bundle relies on inline + eval'd scripts; tightening that requires moving to a nonce-based policy. - The injection in the SPA shell escaped attacker-controllable Host / X-Forwarded-Host headers — a single quote in the host header broke out of the attribute. Pass through SecureBaseHref (html.EscapeString). - Three React sinks rendering untrusted content via dangerouslySetInnerHTML switch to text-node rendering with whiteSpace: pre-wrap: user message bodies in Chat.jsx and AgentChat.jsx, and the agent activity log in AgentChat.jsx. The hand-rolled escape on the agent user-message variant is replaced by the same plain-text path. - New safeHref util collapses non-allowlisted URI schemes (most importantly javascript:) to '#'. Applied to gallery `` links in Models / Backends / Manage and to canvas artifact links — these come from gallery JSON or assistant tool calls and must be treated as untrusted. - The branding asset server attaches a sandbox CSP plus same-origin CORP to .svg responses. The React UI loads logos via , but the same URL is also reachable via direct navigation; this prevents script execution if a hostile SVG slipped past upload validation. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * feat(http): bound HTTP server with read-header and idle timeouts A net/http server with no timeouts is trivially Slowloris-able and leaks idle keep-alive connections. Set ReadHeaderTimeout (30s) to plug the slow-headers attack and IdleTimeout (120s) to cap keep-alive sockets. ReadTimeout and WriteTimeout stay at 0 because request bodies can be multi-GB model uploads and SSE / chat completions stream for many minutes; operators who need tighter per-request bounds should terminate slow clients at a reverse proxy. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * test(auth): pin PUT /api/auth/profile field-tampering contract The handler uses an explicit local body struct (only name and avatar_url) plus a gorm Updates(map) with a column allowlist, so an attacker posting {"role":"admin","email":"...","password_hash":"..."} can't mass-assign those fields. Lock that down with a regression test so a future "let's just c.Bind(&user)" refactor breaks loudly. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(services): strip directory components from multipart upload filenames UploadDataset and UploadToCollectionForUser took the raw multipart file.Filename and joined it into a destination path. The fine-tune upload was incidentally safe because of a UUID prefix that fused any leading '..' to a literal segment, but the protection is fragile. UploadToCollectionForUser handed the filename to a vendored backend without sanitising at all. Strip to filepath.Base at both boundaries and reject the trivial unsafe values ("", ".", "..", "/"). Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(react-ui): validate persisted MCP server entries on load localStorage is shared across same-origin pages; an XSS that lands once can poison persisted MCP server config to attempt header injection or to feed a non-http URL into the fetch path on subsequent loads. Validate every entry: types must match, URL must parse with http(s) scheme, header keys/values must be control-char-free. Drop anything that doesn't fit. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(http): close X-Forwarded-Prefix open redirect The reverse-proxy support concatenated X-Forwarded-Prefix into the redirect target without validation, so a forged header value of "//evil.com" turned the SPA-shell redirect helper at /, /browse, and /browse/* into a 301 to //evil.com/app. The path-strip middleware had the same shape on its prefix-trailing-slash redirect. Add SafeForwardedPrefix at the middleware boundary: must start with a single '/', no protocol-relative '//' opener, no scheme, no backslash, no control characters. Apply at both consumers; misconfig trips the validator and the header is dropped. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(http): refuse wildcard CORS when LOCALAI_CORS=true with empty allowlist When LOCALAI_CORS=true but LOCALAI_CORS_ALLOW_ORIGINS was empty, Echo's CORSWithConfig saw an empty allow-list and fell back to its default AllowOrigins=["*"]. An operator who flipped the strict-CORS feature flag without populating the list got the opposite of what they asked for. Echo never sets Allow-Credentials: true so this isn't directly exploitable (cookies aren't sent under wildcard CORS), but the misconfiguration trap is worth closing. Skip the registration and warn. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * feat(auth): zxcvbn password strength check with user-acknowledged override The previous policy was len < 8, which let through "Password1" and the rest of the credential-stuffing corpus. LocalAI has no second factor yet, so the bar needs to sit higher. Add ValidatePasswordStrength using github.com/timbutler/zxcvbn (an actively-maintained fork of the trustelem port; v1.0.4, April 2024): - min 12 chars, max 72 (bcrypt's truncation point) - reject NUL bytes (some bcrypt callers truncate at the first NUL) - require zxcvbn score >= 3 ("safely unguessable, ~10^8 guesses to break"); the hint list ["localai", "local-ai", "admin"] penalises passwords built from the app's own branding zxcvbn produces false positives sometimes (a strong-looking password that happens to match a dictionary word) and operators occasionally need to set a known-weak password (kiosk demos, CI rigs). Add an acknowledgement path: PasswordPolicy{AllowWeak: true} skips the entropy check while still enforcing the hard rules. The structured PasswordErrorResponse marks weak-password rejections as Overridable so the UI can surface a "use this anyway" checkbox. Wired through register, self-service password change, and admin password reset on both the server and the React UI. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(react-ui): drop HTML5 minLength on new-password inputs minLength={12} on the new-password input let the browser block the form submit silently before any JS or network call ran. The browser focused the field, showed a brief native tooltip, and that was that — no toast, no fetch, no clue. Reproducible by typing fewer than 12 chars on the second password change of a session. The JS-level length check in handleSubmit already shows a toast and the server rejects with a structured error, so the HTML5 attribute was redundant defence anyway. Drop it. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(react-ui): bundle Geist fonts locally instead of fetching from Google The new CSP correctly refused to apply styles from fonts.googleapis.com because style-src is locked to 'self' and 'unsafe-inline'. Loosening the CSP would defeat its purpose; the right fix is to stop reaching out to a third-party CDN for fonts on every page load. Add @fontsource-variable/geist and @fontsource-variable/geist-mono as npm deps and import them once at boot. Drop the and external stylesheet from index.html. Side benefit: no third-party tracking via Referer / IP on every UI load, no failure mode when offline / behind a captive portal. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * fix(react-ui): refresh i18n strings to reflect 12-char password minimum The translations still said "at least 8 characters" everywhere — the client-side toast on a too-short password change told the user the wrong floor. Update tooShort and newPasswordPlaceholder / newPasswordDescription across all five locales (en, es, it, de, zh-CN) to match the real ValidatePasswordStrength rule. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe * feat(auth): make password length-floor overridable like the entropy check The 12-char minimum was a policy choice, not a technical invariant — only "non-empty", "<= 72 bytes", and "no NUL bytes" are real bcrypt constraints. Treating length-12 as a hard rule was inconsistent with the entropy check (already overridable) and friction for use cases where the account is just a name on a session, not a security boundary (single-user kiosk, CI rig, lab demo). Restructure ValidatePasswordStrength: - Hard rules (always enforced): non-empty, <= MaxPasswordLength, no NUL byte - Policy rules (skipped when AllowWeak=true): length >= 12, zxcvbn score >= 3 PasswordError now marks password_too_short as Overridable too. The React forms generalised from `error_code === 'password_too_weak'` to `overridable === true`, and the JS-side preflight length checks were removed (server is source of truth, returns the same checkbox flow). Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe --------- Signed-off-by: Richard Palethorpe --- core/cli/run.go | 12 + core/cli/run_safety.go | 126 ++++++++++ core/cli/run_safety_test.go | 154 ++++++++++++ core/explorer/discovery.go | 15 +- core/gallery/backend_sha256_test.go | 55 +++++ core/gallery/backend_types.go | 3 + core/gallery/backends.go | 8 +- core/http/app.go | 48 +++- core/http/auth/auth_suite_test.go | 2 - core/http/auth/oauth.go | 62 +++-- core/http/auth/oauth_email_decision.go | 17 ++ core/http/auth/oauth_email_decision_test.go | 61 +++++ core/http/auth/password.go | 126 +++++++++- core/http/auth/password_test.go | 100 ++++++++ core/http/csrf_multipart_test.go | 145 +++++++++++ .../localai/agents_isolation_test.go | 178 ++++++++++++++ core/http/endpoints/localai/branding.go | 11 + core/http/endpoints/localai/branding_test.go | 88 +++++++ core/http/endpoints/localai/cors_proxy.go | 49 ++-- .../http/endpoints/localai/cors_proxy_test.go | 110 +++++++++ core/http/endpoints/openai/chat.go | 12 +- .../endpoints/openai/chat_assistant_gate.go | 25 ++ .../openai/chat_assistant_gate_test.go | 82 +++++++ core/http/middleware/forwarded_prefix.go | 34 +++ core/http/middleware/forwarded_prefix_test.go | 50 ++++ core/http/middleware/security_headers.go | 53 +++++ core/http/middleware/security_headers_test.go | 113 +++++++++ core/http/middleware/strippathprefix.go | 5 + core/http/middleware/trace.go | 33 ++- core/http/middleware/trace_redact_test.go | 66 +++++ core/http/react-ui/bun.lock | 6 + core/http/react-ui/index.html | 3 - core/http/react-ui/package.json | 2 + .../http/react-ui/public/locales/de/auth.json | 6 +- .../http/react-ui/public/locales/en/auth.json | 6 +- .../http/react-ui/public/locales/es/auth.json | 6 +- .../http/react-ui/public/locales/it/auth.json | 6 +- .../react-ui/public/locales/zh-CN/auth.json | 6 +- .../react-ui/src/components/CanvasPanel.jsx | 5 +- core/http/react-ui/src/main.jsx | 2 + core/http/react-ui/src/pages/Account.jsx | 39 ++- core/http/react-ui/src/pages/AgentChat.jsx | 5 +- core/http/react-ui/src/pages/Backends.jsx | 3 +- core/http/react-ui/src/pages/Chat.jsx | 2 +- core/http/react-ui/src/pages/Login.jsx | 40 +++- core/http/react-ui/src/pages/Manage.jsx | 5 +- core/http/react-ui/src/pages/Models.jsx | 3 +- core/http/react-ui/src/pages/Users.jsx | 41 +++- core/http/react-ui/src/utils/api.js | 11 +- .../react-ui/src/utils/mcpClientStorage.js | 42 +++- core/http/react-ui/src/utils/url.js | 31 +++ core/http/route_coverage_test.go | 225 ++++++++++++++++++ core/http/routes/auth.go | 41 ++-- core/http/routes/auth_profile_test.go | 165 +++++++++++++ core/services/agentpool/agent_pool.go | 8 +- core/services/finetune/service.go | 8 +- go.mod | 1 + go.sum | 2 + pkg/downloader/uri.go | 36 +-- pkg/downloader/uri_test.go | 43 ++++ pkg/utils/urlfetch.go | 9 +- 61 files changed, 2482 insertions(+), 169 deletions(-) create mode 100644 core/cli/run_safety.go create mode 100644 core/cli/run_safety_test.go create mode 100644 core/gallery/backend_sha256_test.go create mode 100644 core/http/auth/oauth_email_decision.go create mode 100644 core/http/auth/oauth_email_decision_test.go create mode 100644 core/http/auth/password_test.go create mode 100644 core/http/csrf_multipart_test.go create mode 100644 core/http/endpoints/localai/agents_isolation_test.go create mode 100644 core/http/endpoints/localai/branding_test.go create mode 100644 core/http/endpoints/localai/cors_proxy_test.go create mode 100644 core/http/endpoints/openai/chat_assistant_gate.go create mode 100644 core/http/endpoints/openai/chat_assistant_gate_test.go create mode 100644 core/http/middleware/forwarded_prefix.go create mode 100644 core/http/middleware/forwarded_prefix_test.go create mode 100644 core/http/middleware/security_headers.go create mode 100644 core/http/middleware/security_headers_test.go create mode 100644 core/http/middleware/trace_redact_test.go create mode 100644 core/http/react-ui/src/utils/url.js create mode 100644 core/http/route_coverage_test.go create mode 100644 core/http/routes/auth_profile_test.go diff --git a/core/cli/run.go b/core/cli/run.go index 077ef8b23..079cc8ffd 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -70,6 +70,7 @@ type RunCMD struct { OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"` UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"` DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"` + AllowInsecurePublicBind bool `env:"LOCALAI_ALLOW_INSECURE_PUBLIC_BIND" default:"false" help:"Allow binding the API to a public-internet address without any authentication configured. Without this flag the server refuses to start when the bind address is public (or a wildcard on a host with a public interface) and no auth backend or static API key is set. Loopback, RFC 1918 LAN, ULA, link-local, and CGNAT (Tailscale) ranges are accepted regardless." group:"hardening"` DisableMetricsEndpoint bool `env:"LOCALAI_DISABLE_METRICS_ENDPOINT,DISABLE_METRICS_ENDPOINT" default:"false" help:"Disable the /metrics endpoint" group:"api"` HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/app(/.*)?$,^/browse(/.*)?$,^/login/?$,^/explorer/?$,^/assets/.*$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"` Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"` @@ -516,6 +517,17 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { return fmt.Errorf("LocalAI failed to start: %w.\nTroubleshooting steps:\n 1. Check that your models directory exists and is accessible: %s\n 2. Verify model config files are valid YAML: 'local-ai util usecase-heuristic '\n 3. Check available disk space and file permissions\n 4. Run with --log-level=debug for more details\nSee https://localai.io/basics/troubleshooting/ for more help", err, r.ModelsPath) } + // Refuse to bind a public-internet address without authentication unless + // the operator has explicitly opted in. The auth middleware degrades to + // pass-through when there is no auth DB and no legacy keys; on a loopback, + // LAN, or VPN that's the historical "trusted network" deployment, but on + // a public IP it makes every model, gallery install, settings change, and + // admin endpoint reachable by anyone who can connect to the port. + authConfigured := app.AuthDB() != nil || len(r.APIKeys) > 0 + if err := requireAuthOrTrustedBind(r.Address, authConfigured, r.AllowInsecurePublicBind); err != nil { + return err + } + appHTTP, err := http.API(app) if err != nil { xlog.Error("error during HTTP App construction", "error", err) diff --git a/core/cli/run_safety.go b/core/cli/run_safety.go new file mode 100644 index 000000000..61804e0c3 --- /dev/null +++ b/core/cli/run_safety.go @@ -0,0 +1,126 @@ +package cli + +import ( + "fmt" + "net" + "strings" + + "github.com/mudler/LocalAI/pkg/utils" +) + +// interfaceAddrsFn is the host-interface enumeration call. Tests swap it to +// simulate cloud-VM, home-LAN, and Tailscale-only topologies without poking +// the real network stack. +var interfaceAddrsFn = net.InterfaceAddrs + +// requireAuthOrTrustedBind fails closed when the server would otherwise bind a +// public-internet-reachable address with no authentication configured. Loopback, +// RFC 1918, ULA, link-local, and CGNAT (Tailscale's default range) are all +// trusted. Wildcard binds are trusted only when every host interface is. +// +// Operators with an external gating layer (e.g. a reverse proxy that enforces +// auth) can opt out via --allow-insecure-public-bind. +func requireAuthOrTrustedBind(address string, authConfigured, allowInsecurePublicBind bool) error { + if authConfigured || allowInsecurePublicBind { + return nil + } + if isTrustedBind(address) { + return nil + } + return fmt.Errorf(`refusing to start: API bound to public address %q with no authentication configured. + +When auth is disabled, the server has no idea who is calling it — every model, +gallery install, settings change, and admin endpoint is reachable by anyone +who can connect to the port. That is acceptable on a loopback, LAN, or VPN +address but not on a public IP. + +Pick one: + 1. Bind to a private/LAN/VPN interface only (e.g. --address 10.0.0.5:8080) + 2. Enable user authentication: --auth (or LOCALAI_AUTH=true), then sign in + 3. Set a static API key: --api-keys (LOCALAI_API_KEY=) + 4. Allow the public bind anyway: --allow-insecure-public-bind (only when an + external system is gating access to this + listener)`, address) +} + +// isTrustedBind reports whether `address` binds only to addresses that are +// local, on a private LAN, or on a VPN. Hostnames it can't classify cleanly +// are rejected. +func isTrustedBind(address string) bool { + if address == "" { + return false + } + host, _, err := net.SplitHostPort(address) + if err != nil { + return false + } + host = strings.TrimSpace(host) + if host == "" { + return allInterfacesTrusted() + } + if ip := net.ParseIP(host); ip != nil { + if ip.IsUnspecified() { + return allInterfacesTrusted() + } + return isPrivateOrLocalIP(ip) + } + // Hostname — every resolved address must be trusted. A name resolving to + // a mix of public and private addresses fails closed. + addrs, err := net.LookupHost(host) + if err != nil || len(addrs) == 0 { + return false + } + for _, a := range addrs { + ip := net.ParseIP(a) + if ip == nil || !isPrivateOrLocalIP(ip) { + return false + } + } + return true +} + +// isPrivateOrLocalIP returns true for loopback, RFC 1918 / RFC 4193 private, +// link-local, and RFC 6598 CGNAT addresses. CGNAT (100.64/10) gets the +// special case because the Go stdlib doesn't classify it as private but +// Tailscale and similar overlay VPNs hand them out. +func isPrivateOrLocalIP(ip net.IP) bool { + if ip.IsUnspecified() { + return false + } + if !utils.IsPublicIP(ip) { + return true + } + ip4 := ip.To4() + return ip4 != nil && ip4[0] == 100 && (ip4[1]&0xc0) == 64 +} + +// allInterfacesTrusted reports whether every IP assigned to a local interface +// is private/local. A wildcard bind on a host with even one public interface +// is genuinely exposing that public interface. +// +// Returns false on enumeration failure or when the host has no addresses +// at all — we can't prove the bind is safe. +func allInterfacesTrusted() bool { + addrs, err := interfaceAddrsFn() + if err != nil { + return false + } + sawAny := false + for _, a := range addrs { + var ip net.IP + switch v := a.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + } + if ip == nil || ip.IsUnspecified() { + continue + } + sawAny = true + if !isPrivateOrLocalIP(ip) { + return false + } + } + return sawAny +} diff --git a/core/cli/run_safety_test.go b/core/cli/run_safety_test.go new file mode 100644 index 000000000..21bcd3d87 --- /dev/null +++ b/core/cli/run_safety_test.go @@ -0,0 +1,154 @@ +package cli + +import ( + "errors" + "net" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// withInterfaceAddrs swaps interfaceAddrsFn for the duration of one spec. +// Ginkgo's DeferCleanup restores the original after the spec finishes, so +// concurrent specs can each pretend to be running on a different host. +func withInterfaceAddrs(cidrs ...string) { + original := interfaceAddrsFn + interfaceAddrsFn = func() ([]net.Addr, error) { + out := make([]net.Addr, 0, len(cidrs)) + for _, c := range cidrs { + ip, ipnet, err := net.ParseCIDR(c) + Expect(err).ToNot(HaveOccurred()) + out = append(out, &net.IPNet{IP: ip, Mask: ipnet.Mask}) + } + return out, nil + } + DeferCleanup(func() { interfaceAddrsFn = original }) +} + +func withInterfaceAddrsErr(err error) { + original := interfaceAddrsFn + interfaceAddrsFn = func() ([]net.Addr, error) { return nil, err } + DeferCleanup(func() { interfaceAddrsFn = original }) +} + +var _ = Describe("requireAuthOrTrustedBind", func() { + BeforeEach(func() { + // Default to "host has only loopback" — the literal-IP cases below + // don't touch interfaceAddrsFn but the wildcard cases do, and a + // loopback-only host is the safest default for those. + withInterfaceAddrs("127.0.0.1/8", "::1/128") + }) + + It("permits any bind when auth is configured", func() { + Expect(requireAuthOrTrustedBind("0.0.0.0:8080", true, false)).To(Succeed()) + Expect(requireAuthOrTrustedBind("203.0.113.5:8080", true, false)).To(Succeed()) + }) + + It("permits any bind when --allow-insecure-public-bind is set", func() { + Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, true)).To(Succeed()) + Expect(requireAuthOrTrustedBind("203.0.113.5:8080", false, true)).To(Succeed()) + }) + + Context("literal IP binds", func() { + It("permits loopback", func() { + Expect(requireAuthOrTrustedBind("127.0.0.1:8080", false, false)).To(Succeed()) + Expect(requireAuthOrTrustedBind("[::1]:8080", false, false)).To(Succeed()) + Expect(requireAuthOrTrustedBind("127.5.4.3:8080", false, false)).To(Succeed()) + }) + + DescribeTable("permits private LAN ranges", + func(addr string) { + Expect(requireAuthOrTrustedBind(addr, false, false)).To(Succeed()) + }, + Entry("RFC 1918 — 10/8", "10.0.0.5:8080"), + Entry("RFC 1918 — 172.16/12", "172.16.5.5:8080"), + Entry("RFC 1918 — 192.168/16", "192.168.1.5:8080"), + Entry("IPv6 ULA — fc00::/7", "[fc00::1]:8080"), + Entry("IPv6 ULA — fd00::/8", "[fd12:3456:789a::1]:8080"), + ) + + It("permits link-local addresses", func() { + Expect(requireAuthOrTrustedBind("169.254.10.10:8080", false, false)).To(Succeed()) + Expect(requireAuthOrTrustedBind("[fe80::1]:8080", false, false)).To(Succeed()) + }) + + It("permits CGNAT (Tailscale default)", func() { + Expect(requireAuthOrTrustedBind("100.64.0.5:8080", false, false)).To(Succeed()) + Expect(requireAuthOrTrustedBind("100.127.255.1:8080", false, false)).To(Succeed()) + }) + + It("rejects boundary addresses just outside CGNAT", func() { + Expect(requireAuthOrTrustedBind("100.63.255.255:8080", false, false)).To(HaveOccurred()) + Expect(requireAuthOrTrustedBind("100.128.0.0:8080", false, false)).To(HaveOccurred()) + }) + + It("rejects public IPv4", func() { + Expect(requireAuthOrTrustedBind("8.8.8.8:8080", false, false)).To(HaveOccurred()) + Expect(requireAuthOrTrustedBind("203.0.113.5:8080", false, false)).To(HaveOccurred()) + }) + + It("rejects public IPv6", func() { + Expect(requireAuthOrTrustedBind("[2001:db8::1]:8080", false, false)).To(HaveOccurred()) + }) + }) + + Context("wildcard bind (`:port`, 0.0.0.0, ::)", func() { + It("permits when every interface is private/loopback", func() { + withInterfaceAddrs("127.0.0.1/8", "::1/128", "10.0.0.5/24", "fc00::1/64") + Expect(requireAuthOrTrustedBind(":8080", false, false)).To(Succeed()) + Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(Succeed()) + Expect(requireAuthOrTrustedBind("[::]:8080", false, false)).To(Succeed()) + }) + + It("permits when interfaces are loopback + Tailscale CGNAT", func() { + withInterfaceAddrs("127.0.0.1/8", "::1/128", "100.65.10.20/32") + Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(Succeed()) + }) + + It("rejects when ANY interface has a public IP", func() { + withInterfaceAddrs("127.0.0.1/8", "::1/128", "10.0.0.5/24", "203.0.113.42/24") + Expect(requireAuthOrTrustedBind(":8080", false, false)).To(HaveOccurred()) + Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(HaveOccurred()) + Expect(requireAuthOrTrustedBind("[::]:8080", false, false)).To(HaveOccurred()) + }) + + It("fails closed when interface enumeration errors", func() { + withInterfaceAddrsErr(errors.New("enumeration disabled")) + Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(HaveOccurred()) + }) + + It("fails closed when the host has no addresses at all", func() { + withInterfaceAddrs() + Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(HaveOccurred()) + }) + }) + + Context("hostname binds", func() { + It("permits 'localhost' (resolves to loopback)", func() { + Expect(requireAuthOrTrustedBind("localhost:8080", false, false)).To(Succeed()) + }) + }) + + Context("malformed input", func() { + It("rejects an address with no port", func() { + Expect(requireAuthOrTrustedBind("8080", false, false)).To(HaveOccurred()) + }) + + It("rejects an empty address", func() { + Expect(requireAuthOrTrustedBind("", false, false)).To(HaveOccurred()) + }) + }) + + Context("error message", func() { + It("guides the operator with all four escape hatches", func() { + err := requireAuthOrTrustedBind("203.0.113.5:8080", false, false) + Expect(err).To(HaveOccurred()) + msg := err.Error() + Expect(msg).To(ContainSubstring("--auth")) + Expect(msg).To(ContainSubstring("--api-keys")) + Expect(msg).To(ContainSubstring("--allow-insecure-public-bind")) + Expect(msg).To(ContainSubstring("LAN")) + Expect(msg).To(ContainSubstring("VPN")) + }) + }) +}) diff --git a/core/explorer/discovery.go b/core/explorer/discovery.go index 36a193b71..a67395e73 100644 --- a/core/explorer/discovery.go +++ b/core/explorer/discovery.go @@ -21,6 +21,15 @@ type DiscoveryServer struct { errorThreshold int } +// redactToken obfuscates a distribution token for log output; tokens may be +// retained beyond their lifetime via syslog/journald. +func redactToken(t string) string { + if len(t) <= 8 { + return "[redacted]" + } + return t[:4] + "…" + t[len(t)-4:] +} + // NewDiscoveryServer creates a new DiscoveryServer with the given Database. // it keeps the db state in sync with the network state func NewDiscoveryServer(db *Database, dur time.Duration, failureThreshold int) *DiscoveryServer { @@ -92,10 +101,10 @@ func (s *DiscoveryServer) runBackground() { } } - xlog.Debug("Network clusters", "network", token, "count", len(ledgerK)) + xlog.Debug("Network clusters", "network", redactToken(token), "count", len(ledgerK)) if len(ledgerK) != 0 { for _, k := range ledgerK { - xlog.Debug("Clusterdata", "network", token, "cluster", k) + xlog.Debug("Clusterdata", "network", redactToken(token), "cluster", k) } } @@ -128,7 +137,7 @@ func (s *DiscoveryServer) deleteFailedConnections() { for _, t := range s.database.TokenList() { data, _ := s.database.Get(t) if data.Failures > s.errorThreshold { - xlog.Info("Token has been removed from the database", "token", t) + xlog.Info("Token has been removed from the database", "token", redactToken(t)) s.database.Delete(t) } } diff --git a/core/gallery/backend_sha256_test.go b/core/gallery/backend_sha256_test.go new file mode 100644 index 000000000..09ca380cc --- /dev/null +++ b/core/gallery/backend_sha256_test.go @@ -0,0 +1,55 @@ +package gallery + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gopkg.in/yaml.v3" +) + +// Backend gallery integrity check. Operators populate `sha256:` on each +// backend gallery entry; the install path now passes that value into the +// downloader (which already knows how to hash-verify and roll back on +// mismatch). This test pins the YAML wire format so a future refactor of +// GalleryBackend can't drop the field silently. +var _ = Describe("GalleryBackend.SHA256 wire format", func() { + It("parses sha256 from YAML", func() { + data := []byte(`name: test-backend +uri: https://example.com/backend.tar.gz +sha256: deadbeefcafef00d +`) + var b GalleryBackend + Expect(yaml.Unmarshal(data, &b)).To(Succeed()) + Expect(b.SHA256).To(Equal("deadbeefcafef00d")) + }) + + It("parses sha256 from JSON", func() { + // The struct is JSON-tagged for HTTP API responses too. + var b GalleryBackend + // Round-trip via YAML to JSON to keep the test framework simple. + b.Metadata.Name = "x" + b.URI = "https://example.com/x.tar.gz" + b.SHA256 = "deadbeefcafef00d" + out, err := yaml.Marshal(&b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(out)).To(ContainSubstring("sha256: deadbeefcafef00d")) + }) + + It("omits sha256 when empty", func() { + b := GalleryBackend{Metadata: Metadata{Name: "x"}, URI: "https://example.com/x.tar.gz"} + out, err := yaml.Marshal(&b) + Expect(err).ToNot(HaveOccurred()) + Expect(string(out)).ToNot(ContainSubstring("sha256:"), + "empty SHA256 must use omitempty so old galleries don't gain a noisy field") + }) + + It("defaults SHA256 to empty for galleries that don't specify it", func() { + // Old galleries without sha256: keep working. The downloader emits a + // runtime warning ("downloading without integrity check") which is + // the deliberate carrot-stick toward populating the field. + var b GalleryBackend + Expect(yaml.Unmarshal([]byte(`name: legacy-backend +uri: https://example.com/legacy.tar.gz +`), &b)).To(Succeed()) + Expect(b.SHA256).To(Equal("")) + }) +}) diff --git a/core/gallery/backend_types.go b/core/gallery/backend_types.go index 3aa6898d6..1f8363764 100644 --- a/core/gallery/backend_types.go +++ b/core/gallery/backend_types.go @@ -36,6 +36,9 @@ type GalleryBackend struct { Version string `json:"version,omitempty" yaml:"version,omitempty"` Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"` CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"` + // SHA256 is the expected sha256 of the backend tarball at URI / Mirrors. + // Empty disables the integrity check; OCI URIs carry their own digest. + SHA256 string `json:"sha256,omitempty" yaml:"sha256,omitempty"` } func (backend *GalleryBackend) FindBestBackendFromMeta(systemState *system.SystemState, backends GalleryElements[*GalleryBackend]) *GalleryBackend { diff --git a/core/gallery/backends.go b/core/gallery/backends.go index ca9b07dfd..97a0714b5 100644 --- a/core/gallery/backends.go +++ b/core/gallery/backends.go @@ -222,7 +222,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL } } else { xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath) - if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil { + if err := uri.DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err != nil { xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err) // resetBackendPath cleans up partial state from a failed OCI extraction @@ -243,7 +243,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL default: } resetBackendPath() - if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { + if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err == nil { success = true xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath) break @@ -256,7 +256,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL if fallbackURI != string(config.URI) { resetBackendPath() xlog.Info("Trying fallback URI", "original", config.URI, "fallback", fallbackURI) - if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { + if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err == nil { xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath) success = true } else { @@ -265,7 +265,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL resetBackendPath() devFallbackURI := fallbackURI + "-" + devSuffix xlog.Info("Trying development fallback URI", "fallback", devFallbackURI) - if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil { + if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err == nil { xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath) success = true } else { diff --git a/core/http/app.go b/core/http/app.go index bfa47c584..f713fec4d 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" @@ -160,6 +161,11 @@ func API(application *application.Application) (*echo.Echo, error) { }) } + // Security headers (CSP, X-Content-Type-Options, X-Frame-Options, + // Referrer-Policy). Set early so every response — including 404s and + // errors — picks them up. + e.Use(httpMiddleware.SecurityHeaders()) + // Custom logger middleware using xlog e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -277,13 +283,19 @@ func API(application *application.Application) (*echo.Echo, error) { e.Use(auth.RequireQuota(application.AuthDB())) } - // CORS middleware + // CORS middleware. When CORS=true the operator must also specify the + // allowed origins; an empty allowlist would otherwise let Echo fall back + // to AllowOrigins=["*"], which is almost never what someone enabling + // "strict CORS" intended. if application.ApplicationConfig().CORS { - corsConfig := middleware.CORSConfig{} - if application.ApplicationConfig().CORSAllowOrigins != "" { - corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",") + if application.ApplicationConfig().CORSAllowOrigins == "" { + xlog.Warn("LOCALAI_CORS=true but LOCALAI_CORS_ALLOW_ORIGINS is empty; refusing to register a wildcard CORS policy. Set the allowlist or unset LOCALAI_CORS.") + } else { + corsConfig := middleware.CORSConfig{ + AllowOrigins: strings.Split(application.ApplicationConfig().CORSAllowOrigins, ","), + } + e.Use(middleware.CORSWithConfig(corsConfig)) } - e.Use(middleware.CORSWithConfig(corsConfig)) } else { e.Use(middleware.CORS()) } @@ -424,10 +436,11 @@ func API(application *application.Application) (*echo.Echo, error) { if err != nil { return c.String(http.StatusNotFound, "React UI not built") } - // Inject for reverse-proxy support + // Inject for reverse-proxy support; baseURL comes + // from attacker-controllable Host / X-Forwarded-Host headers. baseURL := httpMiddleware.BaseURL(c) if baseURL != "" { - baseTag := `` + baseTag := `` indexHTML = []byte(strings.Replace(string(indexHTML), "", "\n "+baseTag, 1)) } return c.HTMLBlob(http.StatusOK, indexHTML) @@ -440,9 +453,11 @@ func API(application *application.Application) (*echo.Echo, error) { e.GET("/app", serveIndex) e.GET("/app/*", serveIndex) - // prefixRedirect performs a redirect that preserves X-Forwarded-Prefix for reverse-proxy support. + // prefixRedirect performs a redirect that preserves X-Forwarded-Prefix + // for reverse-proxy support. The prefix is forgeable on misconfigured + // proxy chains, so reject anything that isn't a same-origin path. prefixRedirect := func(c echo.Context, target string) error { - if prefix := c.Request().Header.Get("X-Forwarded-Prefix"); prefix != "" { + if prefix, ok := httpMiddleware.SafeForwardedPrefix(c.Request().Header.Get("X-Forwarded-Prefix")); ok { target = strings.TrimSuffix(prefix, "/") + target } return c.Redirect(http.StatusMovedPermanently, target) @@ -490,6 +505,21 @@ func API(application *application.Application) (*echo.Echo, error) { // Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route + // HTTP server timeouts. + // + // - ReadHeaderTimeout: bounds the slow-headers Slowloris case. 30s is + // enough for a real client on a poor connection but cuts off a + // drip-feeding attacker. + // - IdleTimeout: bounds idle keep-alive connections. + // + // We deliberately leave ReadTimeout and WriteTimeout at 0: + // - Request bodies can be multi-GB model/dataset uploads. + // - Chat-completion and SSE responses can stream for many minutes. + // Operators who need stricter limits should front the server with a + // reverse proxy that terminates slow clients per-request. + e.Server.ReadHeaderTimeout = 30 * time.Second + e.Server.IdleTimeout = 120 * time.Second + // Log startup message e.Server.RegisterOnShutdown(func() { xlog.Info("LocalAI API server shutting down") diff --git a/core/http/auth/auth_suite_test.go b/core/http/auth/auth_suite_test.go index c32c18ed1..cc266fc4e 100644 --- a/core/http/auth/auth_suite_test.go +++ b/core/http/auth/auth_suite_test.go @@ -1,5 +1,3 @@ -//go:build auth - package auth_test import ( diff --git a/core/http/auth/oauth.go b/core/http/auth/oauth.go index a89be568c..457d42519 100644 --- a/core/http/auth/oauth.go +++ b/core/http/auth/oauth.go @@ -30,11 +30,16 @@ type providerEntry struct { } // oauthUserInfo is a provider-agnostic representation of an authenticated user. +// EmailVerified MUST reflect upstream verification: AssignRole compares Email +// against the configured admin email, so an unverified claim of a privileged +// address must not be honoured. Callers that cannot prove verification set +// EmailVerified=false. type oauthUserInfo struct { - Subject string - Email string - Name string - AvatarURL string + Subject string + Email string + EmailVerified bool + Name string + AvatarURL string } // OAuthManager manages multiple OAuth/OIDC providers. @@ -236,10 +241,15 @@ func (m *OAuthManager) CallbackHandler(providerName string, db *gorm.DB, adminEm email = strings.ToLower(strings.TrimSpace(userInfo.Email)) } - role := AssignRole(tx, email, adminEmail) + // roleEmail is what AssignRole and NeedsInviteOrApproval + // use to short-circuit on admin-email matches. Pass the + // unverified-email-substituted form so an IdP-supplied + // copy of LOCALAI_ADMIN_EMAIL doesn't bypass either gate. + roleEmail := emailForRoleDecision(email, userInfo.EmailVerified) + role := AssignRole(tx, roleEmail, adminEmail) status := StatusActive - if NeedsInviteOrApproval(tx, email, adminEmail, registrationMode) { + if NeedsInviteOrApproval(tx, roleEmail, adminEmail, registrationMode) { if registrationMode == "invite" { if inviteCode == "" { return fmt.Errorf("invite_required") @@ -294,8 +304,11 @@ func (m *OAuthManager) CallbackHandler(providerName string, db *gorm.DB, adminEm return c.JSON(http.StatusForbidden, map[string]string{"error": "account pending approval"}) } - // Maybe promote on login - MaybePromote(db, user, adminEmail) + // Same gate as roleEmail above: only verified emails can flip an + // existing user to admin via the LOCALAI_ADMIN_EMAIL match. + if userInfo.EmailVerified { + MaybePromote(db, user, adminEmail) + } // Create session sessionID, err := CreateSession(db, user.ID, hmacSecret) @@ -321,20 +334,26 @@ func extractOIDCUserInfo(ctx context.Context, verifier *oidc.IDTokenVerifier, to } var claims struct { - Sub string `json:"sub"` - Email string `json:"email"` - Name string `json:"name"` - Picture string `json:"picture"` + Sub string `json:"sub"` + Email string `json:"email"` + EmailVerified *bool `json:"email_verified"` + Name string `json:"name"` + Picture string `json:"picture"` } if err := idToken.Claims(&claims); err != nil { return nil, fmt.Errorf("failed to parse ID token claims: %w", err) } + // Default to false on absence: an IdP that doesn't issue the claim is + // not asserting verification, and we must not promote on its email. + verified := claims.EmailVerified != nil && *claims.EmailVerified + return &oauthUserInfo{ - Subject: claims.Sub, - Email: claims.Email, - Name: claims.Name, - AvatarURL: claims.Picture, + Subject: claims.Sub, + Email: claims.Email, + EmailVerified: verified, + Name: claims.Name, + AvatarURL: claims.Picture, }, nil } @@ -353,16 +372,19 @@ type githubEmail struct { } // fetchGitHubUserInfoAsOAuth fetches GitHub user info and returns it as oauthUserInfo. +// GitHub only surfaces verified emails (public profile email and the +// /user/emails Verified=true filter), so a non-empty email is always verified. func fetchGitHubUserInfoAsOAuth(ctx context.Context, accessToken string) (*oauthUserInfo, error) { info, err := fetchGitHubUserInfo(ctx, accessToken) if err != nil { return nil, err } return &oauthUserInfo{ - Subject: fmt.Sprintf("%d", info.ID), - Email: info.Email, - Name: info.Name, - AvatarURL: info.AvatarURL, + Subject: fmt.Sprintf("%d", info.ID), + Email: info.Email, + EmailVerified: info.Email != "", + Name: info.Name, + AvatarURL: info.AvatarURL, }, nil } diff --git a/core/http/auth/oauth_email_decision.go b/core/http/auth/oauth_email_decision.go new file mode 100644 index 000000000..57af0753b --- /dev/null +++ b/core/http/auth/oauth_email_decision.go @@ -0,0 +1,17 @@ +package auth + +// emailForRoleDecision returns the email value to use for role assignment +// (admin promotion, admin-email invite bypass) when handling an OAuth/OIDC +// callback. An unverified email must NOT be honoured for these checks — +// otherwise an attacker who can register on the configured IdP with an +// unverified copy of LOCALAI_ADMIN_EMAIL would inherit admin role on first +// login (via AssignRole) or on every subsequent login (via MaybePromote). +// +// Profile/display uses of the email are unaffected: those happen elsewhere +// in the callback and treat the email as user-supplied advisory data. +func emailForRoleDecision(email string, verified bool) string { + if !verified { + return "" + } + return email +} diff --git a/core/http/auth/oauth_email_decision_test.go b/core/http/auth/oauth_email_decision_test.go new file mode 100644 index 000000000..0f2aff58f --- /dev/null +++ b/core/http/auth/oauth_email_decision_test.go @@ -0,0 +1,61 @@ +//go:build auth + +package auth + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gorm.io/gorm" +) + +var _ = Describe("emailForRoleDecision", func() { + It("returns the email when verified", func() { + Expect(emailForRoleDecision("admin@example.com", true)). + To(Equal("admin@example.com")) + }) + + It("returns empty when not verified", func() { + Expect(emailForRoleDecision("admin@example.com", false)). + To(Equal("")) + }) + + It("returns empty when email is empty regardless of flag", func() { + Expect(emailForRoleDecision("", true)).To(Equal("")) + Expect(emailForRoleDecision("", false)).To(Equal("")) + }) + + Context("integration with AssignRole", func() { + var db *gorm.DB + + BeforeEach(func() { + db, _ = InitDB(":memory:") + // Seed at least one user so the "first user becomes admin" + // branch doesn't hide the gate we're testing. + seed := &User{ + ID: "seed-user", + Email: "seed@example.com", + Provider: ProviderGitHub, + Subject: "seed", + Role: RoleAdmin, + Status: StatusActive, + } + Expect(db.Create(seed).Error).To(Succeed()) + }) + + It("does NOT promote on unverified email matching admin email", func() { + role := AssignRole(db, emailForRoleDecision("admin@example.com", false), "admin@example.com") + Expect(role).To(Equal(RoleUser), + "unverified IdP claim of admin email must not yield admin role") + }) + + It("DOES promote on verified email matching admin email", func() { + role := AssignRole(db, emailForRoleDecision("admin@example.com", true), "admin@example.com") + Expect(role).To(Equal(RoleAdmin)) + }) + + It("ignores email when admin email is unset, regardless of verification", func() { + role := AssignRole(db, emailForRoleDecision("any@example.com", true), "") + Expect(role).To(Equal(RoleUser)) + }) + }) +}) diff --git a/core/http/auth/password.go b/core/http/auth/password.go index 4c88fedb7..384c0a19f 100644 --- a/core/http/auth/password.go +++ b/core/http/auth/password.go @@ -1,6 +1,100 @@ package auth -import "golang.org/x/crypto/bcrypt" +import ( + "errors" + "fmt" + "strings" + + "github.com/timbutler/zxcvbn" + "golang.org/x/crypto/bcrypt" +) + +// MinPasswordLength is the floor for any new password. LocalAI does not +// (yet) support a second factor, so the bar sits above NIST's 8-char +// recommendation for MFA-protected accounts. +const MinPasswordLength = 12 + +// MaxPasswordLength matches bcrypt's 72-byte truncation. Accepting longer +// inputs creates a confusing UX where two "different" passwords hash to +// the same value because bcrypt silently dropped the suffix. +const MaxPasswordLength = 72 + +// MinPasswordScore is the minimum zxcvbn score (0-4) we accept. 3 means +// "safely unguessable: moderate protection from offline slow-hash scenario" +// per Dropbox's scoring; 4 is the highest. +const MinPasswordScore = 3 + +// passwordContextHints are tokens fed to zxcvbn so it penalises passwords +// built from the application's own name or branding. +var passwordContextHints = []string{"localai", "local-ai", "admin"} + +// ErrPasswordEmpty is returned for a zero-length password. Always rejected; +// not overridable — bcrypt comparison on an empty string is its own hazard +// and there's no realistic legitimate use. +var ErrPasswordEmpty = errors.New("password must not be empty") + +// ErrPasswordTooShort is returned when the password is below +// MinPasswordLength. Overridable — short is a policy choice, not a +// technical constraint. +var ErrPasswordTooShort = fmt.Errorf("password is shorter than %d characters; pick a longer one or acknowledge the weak password to use it anyway", MinPasswordLength) + +// ErrPasswordTooLong is returned when the password exceeds MaxPasswordLength. +// Not overridable — bcrypt silently truncates at 72 bytes. +var ErrPasswordTooLong = fmt.Errorf("password must be at most %d characters", MaxPasswordLength) + +// ErrPasswordNullByte is returned when the password contains a NUL byte — +// some bcrypt callers truncate at the first NUL, which would let an +// attacker register "abc\x00garbage" and authenticate as "abc". Not +// overridable. +var ErrPasswordNullByte = errors.New("password must not contain null bytes") + +// ErrPasswordTooWeak is returned when zxcvbn scores the password below +// MinPasswordScore. Overridable — an operator may legitimately want a +// known-weak password (kiosk demo, CI rig, false positive on zxcvbn). +var ErrPasswordTooWeak = errors.New("password is too easy to guess; pick a longer or less common one, or acknowledge the weak password to use it anyway") + +// PasswordPolicy controls which checks ValidatePasswordStrength enforces. +// AllowWeak skips the policy-level checks (length floor, zxcvbn score) but +// the technical invariants (non-empty, max length, no NUL bytes) always +// apply. +type PasswordPolicy struct { + AllowWeak bool +} + +// ValidatePasswordStrength enforces the password policy. Callers should use +// this for every register / change-password / admin-reset flow. Pass an +// optional PasswordPolicy{AllowWeak: true} to skip the policy-level checks; +// the technical invariants still apply. +func ValidatePasswordStrength(password string, policy ...PasswordPolicy) error { + // Hard rules — always enforced. These aren't policy, they're invariants + // the bcrypt layer below us depends on. + if len(password) == 0 { + return ErrPasswordEmpty + } + if len(password) > MaxPasswordLength { + return ErrPasswordTooLong + } + if strings.ContainsRune(password, 0) { + return ErrPasswordNullByte + } + + allowWeak := false + if len(policy) > 0 { + allowWeak = policy[0].AllowWeak + } + if allowWeak { + return nil + } + + // Policy-level checks — bypassable via AllowWeak. + if len(password) < MinPasswordLength { + return ErrPasswordTooShort + } + if zxcvbn.PasswordStrength(password, passwordContextHints).Score < MinPasswordScore { + return ErrPasswordTooWeak + } + return nil +} // HashPassword returns a bcrypt hash of the given password. func HashPassword(password string) (string, error) { @@ -12,3 +106,33 @@ func HashPassword(password string) (string, error) { func CheckPassword(hash, password string) bool { return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil } + +// PasswordErrorResponse describes a password-policy rejection in a +// machine-readable form so the UI can choose whether to offer an "use +// this anyway" override (only when Overridable is true). +type PasswordErrorResponse struct { + Error string `json:"error"` + ErrorCode string `json:"error_code"` + Overridable bool `json:"overridable"` +} + +// PasswordError returns a structured response for a ValidatePasswordStrength +// error. err must be one of the package-level password errors. +func PasswordError(err error) PasswordErrorResponse { + r := PasswordErrorResponse{Error: err.Error()} + switch { + case errors.Is(err, ErrPasswordEmpty): + r.ErrorCode = "password_empty" + case errors.Is(err, ErrPasswordTooShort): + r.ErrorCode = "password_too_short" + r.Overridable = true + case errors.Is(err, ErrPasswordTooLong): + r.ErrorCode = "password_too_long" + case errors.Is(err, ErrPasswordNullByte): + r.ErrorCode = "password_null_byte" + case errors.Is(err, ErrPasswordTooWeak): + r.ErrorCode = "password_too_weak" + r.Overridable = true + } + return r +} diff --git a/core/http/auth/password_test.go b/core/http/auth/password_test.go new file mode 100644 index 000000000..c21c12dc3 --- /dev/null +++ b/core/http/auth/password_test.go @@ -0,0 +1,100 @@ +package auth_test + +import ( + "strings" + + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Password policy", func() { + Describe("ValidatePasswordStrength", func() { + // Anything below MinPasswordScore is rejected. zxcvbn scores are subject + // to the embedded dictionary; if the underlying dictionary changes we + // want the test to break loudly so we can re-baseline rather than + // silently accept weaker passwords. + DescribeTable("rejects weak inputs", + func(pw string) { + Expect(auth.ValidatePasswordStrength(pw)).ToNot(Succeed()) + }, + Entry("too short", "Tr0ub4dor"), + Entry("empty", ""), + Entry("common: password", "password1234"), + Entry("common: 12345", "12345678901234"), + Entry("common: qwerty", "qwertyuiopas"), + Entry("keyboard run", "qwertyuiop12"), + Entry("app branding only", "localailocalai"), + Entry("repeated word", "passwordpassword"), + ) + + DescribeTable("accepts strong inputs", + func(pw string) { + Expect(auth.ValidatePasswordStrength(pw)).To(Succeed()) + }, + Entry("diceware-style passphrase", "correct horse battery staple unicycle"), + Entry("random-ish 14-char with punctuation", "Th3-Quick~Br0wn-F0x.Jumps"), + Entry("random mixed", "q9V$mZ1pL7nB3w"), + ) + + Context("length boundaries", func() { + It("returns ErrPasswordEmpty for empty input", func() { + Expect(auth.ValidatePasswordStrength("")).To(MatchError(auth.ErrPasswordEmpty)) + }) + It("returns ErrPasswordTooShort just under the floor", func() { + short := strings.Repeat("a", auth.MinPasswordLength-1) + Expect(auth.ValidatePasswordStrength(short)).To(MatchError(auth.ErrPasswordTooShort)) + }) + It("returns ErrPasswordTooLong just over the ceiling", func() { + long := strings.Repeat("a", auth.MaxPasswordLength+1) + Expect(auth.ValidatePasswordStrength(long)).To(MatchError(auth.ErrPasswordTooLong)) + }) + }) + + It("rejects passwords containing NUL bytes", func() { + Expect(auth.ValidatePasswordStrength("abcdef\x00ghijklmnop")).To(MatchError(auth.ErrPasswordNullByte)) + }) + + // AllowWeak skips the policy-level checks (length floor + entropy) but + // the technical invariants (empty / max-length / NUL byte) always apply. + Context("with AllowWeak override", func() { + weak := "password1234" + + It("rejects the weak password by default", func() { + Expect(auth.ValidatePasswordStrength(weak)).To(MatchError(auth.ErrPasswordTooWeak)) + }) + It("accepts the weak password when AllowWeak is set", func() { + Expect(auth.ValidatePasswordStrength(weak, auth.PasswordPolicy{AllowWeak: true})).To(Succeed()) + }) + It("bypasses the length floor", func() { + Expect(auth.ValidatePasswordStrength("short", auth.PasswordPolicy{AllowWeak: true})).To(Succeed()) + }) + It("does not bypass the empty-input check", func() { + Expect(auth.ValidatePasswordStrength("", auth.PasswordPolicy{AllowWeak: true})).To(MatchError(auth.ErrPasswordEmpty)) + }) + It("does not bypass the NUL-byte check", func() { + Expect(auth.ValidatePasswordStrength("ok\x00password1234", auth.PasswordPolicy{AllowWeak: true})).To(MatchError(auth.ErrPasswordNullByte)) + }) + It("does not bypass the max-length check", func() { + long := strings.Repeat("a", auth.MaxPasswordLength+1) + Expect(auth.ValidatePasswordStrength(long, auth.PasswordPolicy{AllowWeak: true})).To(MatchError(auth.ErrPasswordTooLong)) + }) + }) + }) + + Describe("PasswordError", func() { + DescribeTable("produces a structured response", + func(err error, code string, overridable bool) { + r := auth.PasswordError(err) + Expect(r.ErrorCode).To(Equal(code)) + Expect(r.Overridable).To(Equal(overridable)) + Expect(r.Error).ToNot(BeEmpty()) + }, + Entry("empty", auth.ErrPasswordEmpty, "password_empty", false), + Entry("too short", auth.ErrPasswordTooShort, "password_too_short", true), + Entry("too long", auth.ErrPasswordTooLong, "password_too_long", false), + Entry("null byte", auth.ErrPasswordNullByte, "password_null_byte", false), + Entry("too weak", auth.ErrPasswordTooWeak, "password_too_weak", true), + ) + }) +}) diff --git a/core/http/csrf_multipart_test.go b/core/http/csrf_multipart_test.go new file mode 100644 index 000000000..f96c38a95 --- /dev/null +++ b/core/http/csrf_multipart_test.go @@ -0,0 +1,145 @@ +package http_test + +import ( + "bytes" + "mime/multipart" + "net/http" + "net/http/httptest" + + "github.com/labstack/echo/v4" + echoMiddleware "github.com/labstack/echo/v4/middleware" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// CSRF on multipart endpoints. The protection in core/http/app.go is global +// (e.Use), so it covers /api/branding/asset/:kind, /api/finetune, audio +// transforms, etc. This test rebuilds the same middleware config in +// isolation and pins the contract: cross-site POSTs are rejected; same-site +// POSTs and Authorization-header requests are not. +// +// Booting the whole application via API() per-spec costs tens of seconds and +// has external dependencies, so we deliberately reconstruct just the +// middleware here. If the app.go config drifts from this test, fix the +// constants in the test rather than the app. +var _ = Describe("CSRF coverage on multipart endpoints", func() { + var app *echo.Echo + + BeforeEach(func() { + app = echo.New() + app.Use(echoMiddleware.CSRFWithConfig(echoMiddleware.CSRFConfig{ + Skipper: func(c echo.Context) bool { + if c.Request().Header.Get("Authorization") != "" { + return true + } + if c.Request().Header.Get("x-api-key") != "" || c.Request().Header.Get("xi-api-key") != "" { + return true + } + if c.Request().Header.Get("Sec-Fetch-Site") == "" { + return true + } + return false + }, + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + if c.Request().Header.Get("Sec-Fetch-Site") == "same-site" { + return true, nil + } + return false, nil + }, + })) + app.POST("/api/branding/asset/:kind", func(c echo.Context) error { + return c.NoContent(http.StatusOK) + }) + }) + + multipartBody := func() (*bytes.Buffer, string) { + buf := &bytes.Buffer{} + w := multipart.NewWriter(buf) + fw, err := w.CreateFormFile("file", "logo.svg") + Expect(err).ToNot(HaveOccurred()) + _, _ = fw.Write([]byte(``)) + Expect(w.Close()).To(Succeed()) + return buf, w.FormDataContentType() + } + + It("rejects a cross-site multipart POST", func() { + body, ct := multipartBody() + req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body) + req.Header.Set("Content-Type", ct) + req.Header.Set("Sec-Fetch-Site", "cross-site") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + // Echo's CSRF returns 400 (missing csrf token) when AllowSecFetchSite + // returns false — what we care about is that the request did not + // reach the handler with status 200. + Expect(rec.Code).To(BeNumerically(">=", 400), + "cross-site POST must be rejected; got %d", rec.Code) + Expect(rec.Code).To(BeNumerically("<", 500), + "cross-site POST must be rejected with 4xx; got %d", rec.Code) + }) + + It("allows a same-origin multipart POST", func() { + body, ct := multipartBody() + req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body) + req.Header.Set("Content-Type", ct) + req.Header.Set("Sec-Fetch-Site", "same-origin") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK), + "same-origin POST must reach the handler; got %d body=%s", rec.Code, rec.Body.String()) + }) + + It("allows a same-site multipart POST", func() { + body, ct := multipartBody() + req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body) + req.Header.Set("Content-Type", ct) + req.Header.Set("Sec-Fetch-Site", "same-site") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK), + "same-site POST must reach the handler; got %d body=%s", rec.Code, rec.Body.String()) + }) + + It("skips CSRF for Authorization header clients (cross-site is fine)", func() { + body, ct := multipartBody() + req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body) + req.Header.Set("Content-Type", ct) + req.Header.Set("Sec-Fetch-Site", "cross-site") + req.Header.Set("Authorization", "Bearer something") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + // Skipper short-circuits CSRF; the handler is reached. + Expect(rec.Code).To(Equal(http.StatusOK), + "Authorization header must skip CSRF; got %d body=%s", rec.Code, rec.Body.String()) + }) + + It("skips CSRF for x-api-key clients", func() { + body, ct := multipartBody() + req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body) + req.Header.Set("Content-Type", ct) + req.Header.Set("Sec-Fetch-Site", "cross-site") + req.Header.Set("x-api-key", "something") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK), + "x-api-key must skip CSRF; got %d", rec.Code) + }) + + It("falls through when Sec-Fetch-Site is absent (relies on SameSite=Lax cookie elsewhere)", func() { + // Older browsers and some reverse proxies strip Sec-Fetch-Site. The + // skipper returns true in that case; the auth-cookie SameSite=Lax + // attribute is the actual defense (cookies aren't sent on cross-site + // POSTs, so auth would 401 the request). This test just pins the + // skipper behavior — the SameSite contract lives in oauth.go / + // session.go. + body, ct := multipartBody() + req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body) + req.Header.Set("Content-Type", ct) + // no Sec-Fetch-Site + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK), + "missing Sec-Fetch-Site must skip CSRF (SameSite cookie is the fallback); got %d", rec.Code) + }) +}) diff --git a/core/http/endpoints/localai/agents_isolation_test.go b/core/http/endpoints/localai/agents_isolation_test.go new file mode 100644 index 000000000..024dcd08d --- /dev/null +++ b/core/http/endpoints/localai/agents_isolation_test.go @@ -0,0 +1,178 @@ +package localai + +import ( + "net/http" + "net/http/httptest" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// Per-user isolation contract for agent endpoints: a regular user is scoped +// to their own data, admins and agent-worker service accounts can override +// scope via ?user_id=, and ?all_users=true is admin-only. +var _ = Describe("Agent endpoint per-user isolation", func() { + var e *echo.Echo + + makeContext := func(query string) echo.Context { + req := httptest.NewRequest(http.MethodGet, "/api/agents"+query, nil) + rec := httptest.NewRecorder() + return e.NewContext(req, rec) + } + + BeforeEach(func() { + e = echo.New() + }) + + Describe("getUserID", func() { + It("returns empty when no user is in context", func() { + c := makeContext("") + Expect(getUserID(c)).To(Equal("")) + }) + + It("returns the authenticated user's ID", func() { + c := makeContext("") + c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser}) + Expect(getUserID(c)).To(Equal("alice")) + }) + }) + + Describe("isAdminUser", func() { + It("is false when no user is in context", func() { + c := makeContext("") + Expect(isAdminUser(c)).To(BeFalse()) + }) + + It("is false for a regular user", func() { + c := makeContext("") + c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser}) + Expect(isAdminUser(c)).To(BeFalse()) + }) + + It("is true for admin", func() { + c := makeContext("") + c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin}) + Expect(isAdminUser(c)).To(BeTrue()) + }) + }) + + Describe("canImpersonateUser", func() { + It("is false for unauthenticated", func() { + c := makeContext("") + Expect(canImpersonateUser(c)).To(BeFalse()) + }) + + It("is false for a regular user", func() { + c := makeContext("") + c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser}) + Expect(canImpersonateUser(c)).To(BeFalse()) + }) + + It("is true for admin", func() { + c := makeContext("") + c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin}) + Expect(canImpersonateUser(c)).To(BeTrue()) + }) + + It("is true for agent-worker service accounts", func() { + c := makeContext("") + c.Set("auth_user", &auth.User{ + ID: "worker-1", Role: auth.RoleUser, + Provider: auth.ProviderAgentWorker, + }) + Expect(canImpersonateUser(c)).To(BeTrue()) + }) + + It("ignores user-supplied 'provider' since the field is set server-side at registration", func() { + // Defense-in-depth: even if a user could somehow inject a + // Provider field into their session row (they can't via any + // supported flow), the role check is independent and a + // non-admin RoleUser still fails the impersonation gate. + c := makeContext("") + c.Set("auth_user", &auth.User{ + ID: "alice", Role: auth.RoleUser, + // User claims to be an agent worker. + Provider: auth.ProviderAgentWorker, + }) + // Per the current contract canImpersonate accepts this — but + // the path that mints ProviderAgentWorker is server-only + // (see core/services/nodes/registration.go). Pin the + // expectation so a future refactor that changes the rule + // here also has to change the lock below. + Expect(canImpersonateUser(c)).To(BeTrue()) + }) + }) + + Describe("effectiveUserID", func() { + It("returns the caller's own ID when no query param is set", func() { + c := makeContext("") + c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser}) + Expect(effectiveUserID(c)).To(Equal("alice")) + }) + + It("ignores ?user_id= for a regular user (no impersonation)", func() { + c := makeContext("?user_id=bob") + c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser}) + Expect(effectiveUserID(c)).To(Equal("alice"), + "regular user must not be able to scope queries to another user") + }) + + It("ignores ?user_id= for an unauthenticated caller", func() { + c := makeContext("?user_id=alice") + // no auth_user set + Expect(effectiveUserID(c)).To(Equal("")) + }) + + It("honors ?user_id= for admin", func() { + c := makeContext("?user_id=alice") + c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin}) + Expect(effectiveUserID(c)).To(Equal("alice")) + }) + + It("honors ?user_id= for agent-worker service accounts", func() { + c := makeContext("?user_id=alice") + c.Set("auth_user", &auth.User{ + ID: "worker-1", Role: auth.RoleUser, + Provider: auth.ProviderAgentWorker, + }) + Expect(effectiveUserID(c)).To(Equal("alice")) + }) + + It("falls back to caller's own ID when impersonation is allowed but query is empty", func() { + c := makeContext("") + c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin}) + Expect(effectiveUserID(c)).To(Equal("root")) + }) + }) + + Describe("wantsAllUsers", func() { + It("is false when ?all_users is not set", func() { + c := makeContext("") + c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin}) + Expect(wantsAllUsers(c)).To(BeFalse()) + }) + + It("is false for a regular user even with ?all_users=true", func() { + c := makeContext("?all_users=true") + c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser}) + Expect(wantsAllUsers(c)).To(BeFalse(), + "regular user must not be able to fan out to all users") + }) + + It("is true for admin with ?all_users=true", func() { + c := makeContext("?all_users=true") + c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin}) + Expect(wantsAllUsers(c)).To(BeTrue()) + }) + + It("only accepts the literal 'true' string", func() { + // Sanity — the query string parser should be strict about the + // value, otherwise typos and case variants might bypass. + c := makeContext("?all_users=1") + c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin}) + Expect(wantsAllUsers(c)).To(BeFalse()) + }) + }) +}) diff --git a/core/http/endpoints/localai/branding.go b/core/http/endpoints/localai/branding.go index 307711668..94238a38c 100644 --- a/core/http/endpoints/localai/branding.go +++ b/core/http/endpoints/localai/branding.go @@ -340,6 +340,17 @@ func ServeBrandingAssetEndpoint(appConfig *config.ApplicationConfig) echo.Handle path := filepath.Join(appConfig.DynamicConfigsDir, brandingDirName, file) c.Response().Header().Set("Cache-Control", "public, max-age=300") + // SVG can carry `) + Expect(os.WriteFile(filepath.Join(brandingDir, "logo.svg"), svg, 0o644)).To(Succeed()) + + appCfg = config.NewApplicationConfig() + appCfg.DynamicConfigsDir = dir + appCfg.Branding.LogoFile = "logo.svg" + + app = echo.New() + app.GET("/branding/asset/:kind", ServeBrandingAssetEndpoint(appCfg)) + }) + + AfterEach(func() { + Expect(os.RemoveAll(dir)).To(Succeed()) + }) + + It("returns the SVG body", func() { + req := httptest.NewRequest(http.MethodGet, "/branding/asset/logo", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK)) + Expect(rec.Body.String()).To(ContainSubstring("=", 400), + "expected proxy to reject %s, got %d body=%s", target, rec.Code, rec.Body.String()) + Expect(rec.Code).To(BeNumerically("<", 500), + "expected proxy to reject %s with 4xx, got %d body=%s", target, rec.Code, rec.Body.String()) + } + + It("rejects http://0.0.0.0/ (routes to localhost on Linux)", func() { + rejectsTarget("http://0.0.0.0/anything") + }) + + It("rejects http://0.0.0.0:PORT/ (catches loopback bind on any port)", func() { + rejectsTarget("http://0.0.0.0:8080/") + }) + + It("rejects http://[::]/ (IPv6 unspecified)", func() { + rejectsTarget("http://[::]/") + }) + + It("rejects http://[::ffff:127.0.0.1]/ (IPv4-mapped IPv6 loopback)", func() { + rejectsTarget("http://[::ffff:127.0.0.1]/") + }) + + It("rejects http://[::ffff:10.0.0.1]/ (IPv4-mapped IPv6 RFC1918)", func() { + rejectsTarget("http://[::ffff:10.0.0.1]/") + }) + + It("rejects file:// scheme", func() { + rejectsTarget("file:///etc/passwd") + }) + + It("rejects gopher:// scheme", func() { + rejectsTarget("gopher://attacker.example.com:1234/") + }) + + It("rejects ftp:// scheme", func() { + rejectsTarget("ftp://example.com/") + }) + + It("rejects http://localhost/", func() { + rejectsTarget("http://localhost/") + }) + + It("rejects http://127.0.0.1/", func() { + rejectsTarget("http://127.0.0.1/") + }) + + It("rejects http://10.0.0.1/", func() { + rejectsTarget("http://10.0.0.1/") + }) + + It("rejects http://169.254.169.254/ (cloud metadata)", func() { + rejectsTarget("http://169.254.169.254/latest/meta-data/") + }) + + It("rejects http://metadata.google.internal/", func() { + rejectsTarget("http://metadata.google.internal/computeMetadata/v1/") + }) + + It("rejects requests with no url parameter", func() { + req := httptest.NewRequest(http.MethodGet, "/api/cors-proxy", nil) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + }) + + // Sanity: confirm the test runner machine resolves 0.0.0.0 to itself — + // otherwise the test could pass for the wrong reason. + It("baseline: 0.0.0.0 is classified as unspecified by Go stdlib", func() { + ip := net.ParseIP("0.0.0.0") + Expect(ip.IsUnspecified()).To(BeTrue()) + }) + + It("baseline: :: is classified as unspecified by Go stdlib", func() { + ip := net.ParseIP("::") + Expect(ip.IsUnspecified()).To(BeTrue()) + }) +}) diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 278eca69b..6ae2faf81 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -10,7 +10,6 @@ import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/http/auth" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" @@ -463,15 +462,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator // per-model MCP servers in the same chat session by design). assistantMode := mcpTools.LocalAIAssistantFromMetadata(input.Metadata) if assistantMode { - // Defense-in-depth admin gate: the chat route is feature-gated - // (FeatureChat), but the assistant tools mutate server state, so - // re-check role here even when the deployment chose to skip - // FeatureLocalAIAssistant on the route. - if startupOptions.Auth.Enabled { - user := auth.GetUser(c) - if user == nil || user.Role != auth.RoleAdmin { - return echo.NewHTTPError(http.StatusForbidden, "localai_assistant requires admin") - } + if err := requireAssistantAccess(c, startupOptions.Auth.Enabled); err != nil { + return err } // Read the disable flag live: an admin can flip it via /api/settings // and the next request must see the change without a restart. diff --git a/core/http/endpoints/openai/chat_assistant_gate.go b/core/http/endpoints/openai/chat_assistant_gate.go new file mode 100644 index 000000000..ce6f4990c --- /dev/null +++ b/core/http/endpoints/openai/chat_assistant_gate.go @@ -0,0 +1,25 @@ +package openai + +import ( + "net/http" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" +) + +// requireAssistantAccess gates a chat request that asked for the LocalAI +// Assistant tool surface (metadata.localai_assistant=true). The assistant's +// in-process MCP server can install models, edit configs and trigger backend +// upgrades, so it must be admin-only — the chat route itself only enforces +// FeatureChat (default-on for every user). When auth is disabled the gate is +// a no-op; the operator already chose to trust every caller. +func requireAssistantAccess(c echo.Context, authEnabled bool) error { + if !authEnabled { + return nil + } + user := auth.GetUser(c) + if user == nil || user.Role != auth.RoleAdmin { + return echo.NewHTTPError(http.StatusForbidden, "localai_assistant requires admin") + } + return nil +} diff --git a/core/http/endpoints/openai/chat_assistant_gate_test.go b/core/http/endpoints/openai/chat_assistant_gate_test.go new file mode 100644 index 000000000..459b76005 --- /dev/null +++ b/core/http/endpoints/openai/chat_assistant_gate_test.go @@ -0,0 +1,82 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/http/auth" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// A non-admin caller who asks for metadata.localai_assistant=true must be +// refused. The assistant's in-process MCP server can install/delete models, +// edit configs, and rebrand the server, so a user-level caller driving it +// via prompt-injected tool calls would be a confused deputy. +var _ = Describe("requireAssistantAccess", func() { + var ( + e *echo.Echo + c echo.Context + ) + + makeContext := func() (echo.Context, *httptest.ResponseRecorder) { + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + rec := httptest.NewRecorder() + return e.NewContext(req, rec), rec + } + + BeforeEach(func() { + e = echo.New() + c, _ = makeContext() + }) + + Context("when auth is disabled", func() { + It("admits any caller", func() { + Expect(requireAssistantAccess(c, false)).To(BeNil()) + }) + + It("admits even when no user is in context", func() { + Expect(requireAssistantAccess(c, false)).To(BeNil()) + }) + }) + + Context("when auth is enabled", func() { + It("rejects an unauthenticated caller with 403", func() { + err := requireAssistantAccess(c, true) + Expect(err).To(HaveOccurred()) + httpErr, ok := err.(*echo.HTTPError) + Expect(ok).To(BeTrue(), "expected echo.HTTPError, got %T", err) + Expect(httpErr.Code).To(Equal(http.StatusForbidden)) + Expect(httpErr.Message).To(ContainSubstring("admin")) + }) + + It("rejects a regular user with 403", func() { + c.Set("auth_user", &auth.User{ID: "u-1", Role: auth.RoleUser}) + err := requireAssistantAccess(c, true) + Expect(err).To(HaveOccurred()) + httpErr, ok := err.(*echo.HTTPError) + Expect(ok).To(BeTrue()) + Expect(httpErr.Code).To(Equal(http.StatusForbidden)) + }) + + It("rejects a user with empty role with 403", func() { + c.Set("auth_user", &auth.User{ID: "u-2", Role: ""}) + err := requireAssistantAccess(c, true) + Expect(err).To(HaveOccurred()) + }) + + It("admits an admin", func() { + c.Set("auth_user", &auth.User{ID: "admin-1", Role: auth.RoleAdmin}) + Expect(requireAssistantAccess(c, true)).To(BeNil()) + }) + + It("admits a synthetic admin from a legacy API key", func() { + // Legacy API key callers get a synthetic admin user from the + // auth middleware. They must continue to work — that's the + // shape every existing single-key deployment has today. + c.Set("auth_user", &auth.User{ID: "legacy-api-key", Role: auth.RoleAdmin}) + Expect(requireAssistantAccess(c, true)).To(BeNil()) + }) + }) +}) diff --git a/core/http/middleware/forwarded_prefix.go b/core/http/middleware/forwarded_prefix.go new file mode 100644 index 000000000..daffad2b5 --- /dev/null +++ b/core/http/middleware/forwarded_prefix.go @@ -0,0 +1,34 @@ +package middleware + +import "strings" + +// SafeForwardedPrefix validates an X-Forwarded-Prefix header value before we +// concatenate it into a redirect target or use it for path stripping. An +// untrusted value like "//evil.com" or "http://evil.com" turns the +// reverse-proxy support into an open redirect. +// +// Returns the trimmed, validated value and true on success; "" and false +// when the value is unsafe and should be ignored. +func SafeForwardedPrefix(raw string) (string, bool) { + s := strings.TrimSpace(raw) + if s == "" { + return "", false + } + // Must be a path: starts with a single '/' and doesn't begin a + // protocol-relative URL. + if !strings.HasPrefix(s, "/") || strings.HasPrefix(s, "//") { + return "", false + } + // Backslashes are interpreted as forward slashes by some clients but + // not by Echo's router; reject to avoid bypasses. + if strings.ContainsAny(s, "\\") { + return "", false + } + // No control characters or whitespace inside the path. + for _, c := range s { + if c < 0x20 || c == 0x7f { + return "", false + } + } + return s, true +} diff --git a/core/http/middleware/forwarded_prefix_test.go b/core/http/middleware/forwarded_prefix_test.go new file mode 100644 index 000000000..b085abacc --- /dev/null +++ b/core/http/middleware/forwarded_prefix_test.go @@ -0,0 +1,50 @@ +package middleware + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// SafeForwardedPrefix gates every X-Forwarded-Prefix consumer (the path +// stripper and the SPA-shell redirect helpers). The threat is operator- +// trusted reverse-proxy headers being forgeable via a misconfigured chain; +// any value the attacker can inject must not escape the local origin. +var _ = Describe("SafeForwardedPrefix", func() { + DescribeTable("accepts well-formed path prefixes", + func(in string) { + out, ok := SafeForwardedPrefix(in) + Expect(ok).To(BeTrue(), "expected %q to validate", in) + Expect(out).To(Equal(in)) + }, + Entry("simple", "/api"), + Entry("nested", "/api/v1"), + Entry("trailing slash", "/api/"), + Entry("with hyphens and dots", "/api-v1.beta"), + ) + + DescribeTable("rejects values that would escape the origin", + func(in string) { + _, ok := SafeForwardedPrefix(in) + Expect(ok).To(BeFalse(), "expected %q to be rejected", in) + }, + Entry("empty", ""), + Entry("whitespace only", " "), + Entry("protocol-relative", "//evil.com"), + Entry("protocol-relative with path", "//evil.com/x"), + Entry("absolute http URL", "http://evil.com"), + Entry("absolute https URL", "https://evil.com/x"), + Entry("javascript scheme", "javascript:alert(1)"), + Entry("data scheme", "data:text/html,foo"), + Entry("missing leading slash", "api"), + Entry("backslash injection", "/foo\\evil.com"), + Entry("CR injection", "/foo\rLocation: //evil.com"), + Entry("LF injection", "/foo\nSet-Cookie: x=y"), + Entry("NUL byte", "/foo\x00bar"), + ) + + It("trims surrounding whitespace before validating", func() { + out, ok := SafeForwardedPrefix(" /api ") + Expect(ok).To(BeTrue()) + Expect(out).To(Equal("/api")) + }) +}) diff --git a/core/http/middleware/security_headers.go b/core/http/middleware/security_headers.go new file mode 100644 index 000000000..9a3ae8d48 --- /dev/null +++ b/core/http/middleware/security_headers.go @@ -0,0 +1,53 @@ +package middleware + +import ( + "html" + + "github.com/labstack/echo/v4" +) + +// SecurityHeaders sets headers that limit the blast radius of any XSS bug +// that slips through. The CSP keeps script-src permissive because the Vite +// bundle relies on inline + eval'd scripts; tightening it requires moving +// to a nonce-based policy. +func SecurityHeaders() echo.MiddlewareFunc { + const csp = "default-src 'self'; " + + "script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; " + + "style-src 'self' 'unsafe-inline'; " + + "img-src 'self' data: blob: https:; " + + "media-src 'self' data: blob:; " + + "font-src 'self' data:; " + + "connect-src 'self' ws: wss: https:; " + + "frame-src 'self' blob:; " + + "worker-src 'self' blob:; " + + "object-src 'none'; " + + "base-uri 'self'; " + + "form-action 'self'; " + + "frame-ancestors 'self'" + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + h := c.Response().Header() + if h.Get("Content-Security-Policy") == "" { + h.Set("Content-Security-Policy", csp) + } + if h.Get("X-Content-Type-Options") == "" { + h.Set("X-Content-Type-Options", "nosniff") + } + if h.Get("X-Frame-Options") == "" { + h.Set("X-Frame-Options", "SAMEORIGIN") + } + if h.Get("Referrer-Policy") == "" { + h.Set("Referrer-Policy", "strict-origin-when-cross-origin") + } + return next(c) + } + } +} + +// SecureBaseHref escapes a base URL value for safe interpolation into a +// `` attribute. baseURL is built from Host / +// X-Forwarded-Host, both attacker-controllable on most reverse-proxy setups. +func SecureBaseHref(s string) string { + return html.EscapeString(s) +} diff --git a/core/http/middleware/security_headers_test.go b/core/http/middleware/security_headers_test.go new file mode 100644 index 000000000..af43822ea --- /dev/null +++ b/core/http/middleware/security_headers_test.go @@ -0,0 +1,113 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "strings" + + "github.com/labstack/echo/v4" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("SecurityHeaders", func() { + var e *echo.Echo + + BeforeEach(func() { + e = echo.New() + e.Use(SecurityHeaders()) + e.GET("/", func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + }) + + It("sets Content-Security-Policy", func() { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + Expect(rec.Code).To(Equal(http.StatusOK)) + csp := rec.Header().Get("Content-Security-Policy") + Expect(csp).ToNot(BeEmpty()) + Expect(csp).To(ContainSubstring("default-src 'self'")) + Expect(csp).To(ContainSubstring("frame-ancestors 'self'")) + Expect(csp).To(ContainSubstring("object-src 'none'")) + Expect(csp).To(ContainSubstring("base-uri 'self'")) + }) + + It("sets X-Content-Type-Options: nosniff", func() { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + Expect(rec.Header().Get("X-Content-Type-Options")).To(Equal("nosniff")) + }) + + It("sets X-Frame-Options: SAMEORIGIN", func() { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + Expect(rec.Header().Get("X-Frame-Options")).To(Equal("SAMEORIGIN")) + }) + + It("sets Referrer-Policy", func() { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + Expect(rec.Header().Get("Referrer-Policy")).To(Equal("strict-origin-when-cross-origin")) + }) + + It("does not overwrite a header a later handler set explicitly", func() { + // Reset router so we can install a handler that sets CSP itself. + e = echo.New() + e.Use(SecurityHeaders()) + e.GET("/", func(c echo.Context) error { + c.Response().Header().Set("Content-Security-Policy", "default-src 'none'") + return c.String(http.StatusOK, "ok") + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + // The middleware runs first (sets default), but a later handler may + // want a tighter CSP for a specific response. The middleware should + // only set headers that aren't already present — but since Echo + // middleware runs around the handler, the middleware's Set calls + // happen before the handler runs. So this is more of a smoke test + // that the middleware doesn't actively clobber on the way out. + Expect(rec.Header().Get("Content-Security-Policy")).To(Equal("default-src 'none'")) + }) +}) + +var _ = Describe("SecureBaseHref", func() { + It("escapes attribute-breaking characters", func() { + out := SecureBaseHref(`">`) + Expect(out).ToNot(ContainSubstring(`"`)) + Expect(out).ToNot(ContainSubstring("<")) + Expect(out).ToNot(ContainSubstring(">")) + Expect(out).To(ContainSubstring("script")) + }) + + It("escapes ampersands", func() { + Expect(SecureBaseHref("https://example.com/?a=1&b=2")). + To(Equal("https://example.com/?a=1&b=2")) + }) + + It("escapes single quotes", func() { + Expect(SecureBaseHref(`x' onload='alert(1)`)). + To(ContainSubstring("'")) + }) + + It("leaves benign URLs alone", func() { + Expect(SecureBaseHref("https://example.com/app/")). + To(Equal("https://example.com/app/")) + }) + + It("encloses safely inside double-quoted attribute", func() { + // The realistic attack: attacker sets X-Forwarded-Host: foo.com" onload="x. + // Confirm the escaped form can't break out of the surrounding quotes. + hostile := `foo.com" onload="alert(1)` + out := SecureBaseHref(hostile) + Expect(out).ToNot(ContainSubstring(`"`)) + // Wrapped in attribute context — no raw quote means no breakout. + full := `` + Expect(strings.Count(full, `"`)).To(Equal(2)) + }) +}) diff --git a/core/http/middleware/strippathprefix.go b/core/http/middleware/strippathprefix.go index 451ccfe66..5c55b519a 100644 --- a/core/http/middleware/strippathprefix.go +++ b/core/http/middleware/strippathprefix.go @@ -16,6 +16,11 @@ func StripPathPrefix() echo.MiddlewareFunc { originalPath := c.Request().URL.Path for _, prefix := range prefixes { + validated, ok := SafeForwardedPrefix(prefix) + if !ok { + continue + } + prefix = validated if prefix != "" { normalizedPrefix := prefix if !strings.HasSuffix(prefix, "/") { diff --git a/core/http/middleware/trace.go b/core/http/middleware/trace.go index ef3dd891d..9e713c031 100644 --- a/core/http/middleware/trace.go +++ b/core/http/middleware/trace.go @@ -85,6 +85,29 @@ func initializeTracing(maxItems int) { doInitializeTracing() } +// sensitiveTraceHeaders is the set of header names whose values must not +// land in the in-memory trace buffer. Keys are canonical — http.Header +// stores them that way, so range yields canonical keys directly. +var sensitiveTraceHeaders = map[string]struct{}{ + "Authorization": {}, + "Proxy-Authorization": {}, + "Cookie": {}, + "Set-Cookie": {}, + "X-Api-Key": {}, + "Xi-Api-Key": {}, + "X-Auth-Token": {}, +} + +func redactSensitiveHeaders(h http.Header) http.Header { + out := h.Clone() + for k := range out { + if _, ok := sensitiveTraceHeaders[k]; ok { + out[k] = []string{"[redacted]"} + } + } + return out +} + // TraceMiddleware intercepts and logs JSON API requests and responses func TraceMiddleware(app *application.Application) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -130,11 +153,15 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc { status = http.StatusInternalServerError } - // Create exchange log (always, even on error) - requestHeaders := c.Request().Header.Clone() + // Create exchange log (always, even on error). Sensitive headers + // (Authorization, API keys, cookies) are redacted before storage — + // the trace endpoint is admin-only but the buffer is also reachable + // via any heap-dump-style introspection, and tokens shouldn't + // outlive the request that carried them. + requestHeaders := redactSensitiveHeaders(c.Request().Header) requestBody := make([]byte, len(body)) copy(requestBody, body) - responseHeaders := c.Response().Header().Clone() + responseHeaders := redactSensitiveHeaders(c.Response().Header()) responseBody := make([]byte, resBody.Len()) copy(responseBody, resBody.Bytes()) exchange := APIExchange{ diff --git a/core/http/middleware/trace_redact_test.go b/core/http/middleware/trace_redact_test.go new file mode 100644 index 000000000..6f1a83586 --- /dev/null +++ b/core/http/middleware/trace_redact_test.go @@ -0,0 +1,66 @@ +package middleware + +import ( + "net/http" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// The trace buffer feeds an admin-only /api/traces endpoint. Even so, it +// must not retain the request's auth headers — once captured they outlive +// the request, get serialised to JSON for the dashboard, and could leak via +// any heap inspection. This pins the redaction contract so a future refactor +// of TraceMiddleware can't silently regress it. +var _ = Describe("redactSensitiveHeaders", func() { + It("redacts Authorization", func() { + h := http.Header{} + h.Set("Authorization", "Bearer sk-secret-1234567890") + out := redactSensitiveHeaders(h) + Expect(out.Get("Authorization")).To(Equal("[redacted]")) + Expect(out.Get("Authorization")).ToNot(ContainSubstring("sk-secret")) + }) + + It("redacts Proxy-Authorization", func() { + h := http.Header{} + h.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz") + Expect(redactSensitiveHeaders(h).Get("Proxy-Authorization")).To(Equal("[redacted]")) + }) + + It("redacts Cookie and Set-Cookie", func() { + h := http.Header{} + h.Set("Cookie", "session=abc123; csrf=xyz") + h.Set("Set-Cookie", "session=newvalue; HttpOnly") + out := redactSensitiveHeaders(h) + Expect(out.Get("Cookie")).To(Equal("[redacted]")) + Expect(out.Get("Set-Cookie")).To(Equal("[redacted]")) + }) + + It("redacts X-Api-Key (and case variants)", func() { + h := http.Header{} + h.Set("X-Api-Key", "key-1") + h.Set("xi-api-key", "key-2") + out := redactSensitiveHeaders(h) + Expect(out.Get("X-Api-Key")).To(Equal("[redacted]")) + Expect(out.Get("Xi-Api-Key")).To(Equal("[redacted]")) + }) + + It("preserves benign headers", func() { + h := http.Header{} + h.Set("Content-Type", "application/json") + h.Set("User-Agent", "curl/8.0") + h.Set("Accept", "*/*") + out := redactSensitiveHeaders(h) + Expect(out.Get("Content-Type")).To(Equal("application/json")) + Expect(out.Get("User-Agent")).To(Equal("curl/8.0")) + Expect(out.Get("Accept")).To(Equal("*/*")) + }) + + It("does not mutate the input header", func() { + h := http.Header{} + h.Set("Authorization", "Bearer abc") + _ = redactSensitiveHeaders(h) + Expect(h.Get("Authorization")).To(Equal("Bearer abc"), + "redactSensitiveHeaders must operate on a clone — caller's header must be untouched") + }) +}) diff --git a/core/http/react-ui/bun.lock b/core/http/react-ui/bun.lock index 66b1dbc1c..91204e725 100644 --- a/core/http/react-ui/bun.lock +++ b/core/http/react-ui/bun.lock @@ -13,6 +13,8 @@ "@codemirror/search": "^6.5.10", "@codemirror/state": "^6.5.2", "@codemirror/view": "^6.36.8", + "@fontsource-variable/geist": "^5.2.8", + "@fontsource-variable/geist-mono": "^5.2.7", "@fortawesome/fontawesome-free": "^6.7.2", "@lezer/highlight": "^1.2.1", "@modelcontextprotocol/ext-apps": "^1.2.2", @@ -169,6 +171,10 @@ "@eslint/plugin-kit": ["@eslint/plugin-kit@0.4.1", "", { "dependencies": { "@eslint/core": "^0.17.0", "levn": "^0.4.1" } }, "sha512-43/qtrDUokr7LJqoF2c3+RInu/t4zfrpYdoSDfYyhg52rwLV6TnOvdG4fXm7IkSB3wErkcmJS9iEhjVtOSEjjA=="], + "@fontsource-variable/geist": ["@fontsource-variable/geist@5.2.8", "", {}, "sha512-cJ6m9e+8MQ5dCYJsLylfZrgBh6KkG4bOLckB35Tr9J/EqdkEM6QllH5PxqP1dhTvFup+HtMRPuz9xOjxXJggxw=="], + + "@fontsource-variable/geist-mono": ["@fontsource-variable/geist-mono@5.2.7", "", {}, "sha512-ZKlZ5sjtalb2TwXKs400mAGDlt/+2ENLNySPx0wTz3bP3mWARCsUW+rpxzZc7e05d2qGch70pItt3K4qttbIYA=="], + "@fortawesome/fontawesome-free": ["@fortawesome/fontawesome-free@6.7.2", "", {}, "sha512-JUOtgFW6k9u4Y+xeIaEiLr3+cjoUPiAuLXoyKOJSia6Duzb7pq+A76P9ZdPDoAoxHdHzq6gE9/jKBGXlZT8FbA=="], "@gulpjs/to-absolute-glob": ["@gulpjs/to-absolute-glob@4.0.0", "", { "dependencies": { "is-negated-glob": "^1.0.0" } }, "sha512-kjotm7XJrJ6v+7knhPaRgaT6q8F8K2jiafwYdNHLzmV0uGLuZY43FK6smNSHUPrhq5kX2slCUy+RGG/xGqmIKA=="], diff --git a/core/http/react-ui/index.html b/core/http/react-ui/index.html index ee1a144cf..4f1c21ef5 100644 --- a/core/http/react-ui/index.html +++ b/core/http/react-ui/index.html @@ -5,9 +5,6 @@ LocalAI - - -
diff --git a/core/http/react-ui/package.json b/core/http/react-ui/package.json index 994a5139a..eceed908a 100644 --- a/core/http/react-ui/package.json +++ b/core/http/react-ui/package.json @@ -21,6 +21,8 @@ "@codemirror/search": "^6.5.10", "@codemirror/state": "^6.5.2", "@codemirror/view": "^6.36.8", + "@fontsource-variable/geist": "^5.2.8", + "@fontsource-variable/geist-mono": "^5.2.7", "@fortawesome/fontawesome-free": "^6.7.2", "@lezer/highlight": "^1.2.1", "@modelcontextprotocol/ext-apps": "^1.2.2", diff --git a/core/http/react-ui/public/locales/de/auth.json b/core/http/react-ui/public/locales/de/auth.json index 2ded3480a..a60238a20 100644 --- a/core/http/react-ui/public/locales/de/auth.json +++ b/core/http/react-ui/public/locales/de/auth.json @@ -10,7 +10,7 @@ "namePlaceholder": "Ihr Name (optional)", "password": "Passwort", "passwordPlaceholder": "Passwort eingeben...", - "newPasswordPlaceholder": "Mindestens 8 Zeichen", + "newPasswordPlaceholder": "Mindestens 12 Zeichen", "confirmPassword": "Passwort bestätigen", "confirmPasswordPlaceholder": "Passwort wiederholen", "inviteCodeLabel": "Einladungscode", @@ -70,7 +70,7 @@ "currentPasswordDescription": "Geben Sie Ihr aktuelles Passwort zur Bestätigung Ihrer Identität ein", "currentPasswordPlaceholder": "Aktuelles Passwort", "newPassword": "Neues Passwort", - "newPasswordDescription": "Muss mindestens 8 Zeichen haben", + "newPasswordDescription": "Muss mindestens 12 Zeichen haben", "newPasswordPlaceholder": "Neues Passwort", "confirmPassword": "Passwort bestätigen", "confirmPasswordDescription": "Geben Sie Ihr neues Passwort erneut ein", @@ -79,7 +79,7 @@ "changing": "Wird geändert...", "changed": "Passwort geändert", "passwordsDoNotMatch": "Passwörter stimmen nicht überein", - "tooShort": "Das neue Passwort muss mindestens 8 Zeichen haben", + "tooShort": "Das neue Passwort muss mindestens 12 Zeichen haben", "oauthOnly": "Passwortverwaltung ist für {{provider}}-Konten nicht verfügbar." }, "apiKeys": { diff --git a/core/http/react-ui/public/locales/en/auth.json b/core/http/react-ui/public/locales/en/auth.json index 69efa39ed..242bb5889 100644 --- a/core/http/react-ui/public/locales/en/auth.json +++ b/core/http/react-ui/public/locales/en/auth.json @@ -10,7 +10,7 @@ "namePlaceholder": "Your name (optional)", "password": "Password", "passwordPlaceholder": "Enter password...", - "newPasswordPlaceholder": "At least 8 characters", + "newPasswordPlaceholder": "At least 12 characters", "confirmPassword": "Confirm Password", "confirmPasswordPlaceholder": "Repeat password", "inviteCodeLabel": "Invite Code", @@ -70,7 +70,7 @@ "currentPasswordDescription": "Enter your existing password to verify your identity", "currentPasswordPlaceholder": "Current password", "newPassword": "New password", - "newPasswordDescription": "Must be at least 8 characters", + "newPasswordDescription": "Must be at least 12 characters", "newPasswordPlaceholder": "New password", "confirmPassword": "Confirm password", "confirmPasswordDescription": "Re-enter your new password", @@ -79,7 +79,7 @@ "changing": "Changing...", "changed": "Password changed", "passwordsDoNotMatch": "Passwords do not match", - "tooShort": "New password must be at least 8 characters", + "tooShort": "New password must be at least 12 characters", "oauthOnly": "Password management is not available for {{provider}} accounts." }, "apiKeys": { diff --git a/core/http/react-ui/public/locales/es/auth.json b/core/http/react-ui/public/locales/es/auth.json index c0bdfaef2..6c867d775 100644 --- a/core/http/react-ui/public/locales/es/auth.json +++ b/core/http/react-ui/public/locales/es/auth.json @@ -10,7 +10,7 @@ "namePlaceholder": "Tu nombre (opcional)", "password": "Contraseña", "passwordPlaceholder": "Introduce la contraseña...", - "newPasswordPlaceholder": "Al menos 8 caracteres", + "newPasswordPlaceholder": "Al menos 12 caracteres", "confirmPassword": "Confirmar contraseña", "confirmPasswordPlaceholder": "Repite la contraseña", "inviteCodeLabel": "Código de invitación", @@ -70,7 +70,7 @@ "currentPasswordDescription": "Introduce tu contraseña actual para verificar tu identidad", "currentPasswordPlaceholder": "Contraseña actual", "newPassword": "Nueva contraseña", - "newPasswordDescription": "Debe tener al menos 8 caracteres", + "newPasswordDescription": "Debe tener al menos 12 caracteres", "newPasswordPlaceholder": "Nueva contraseña", "confirmPassword": "Confirmar contraseña", "confirmPasswordDescription": "Vuelve a introducir tu nueva contraseña", @@ -79,7 +79,7 @@ "changing": "Cambiando...", "changed": "Contraseña cambiada", "passwordsDoNotMatch": "Las contraseñas no coinciden", - "tooShort": "La nueva contraseña debe tener al menos 8 caracteres", + "tooShort": "La nueva contraseña debe tener al menos 12 caracteres", "oauthOnly": "La gestión de contraseña no está disponible para cuentas {{provider}}." }, "apiKeys": { diff --git a/core/http/react-ui/public/locales/it/auth.json b/core/http/react-ui/public/locales/it/auth.json index 38a606e79..a42258ee6 100644 --- a/core/http/react-ui/public/locales/it/auth.json +++ b/core/http/react-ui/public/locales/it/auth.json @@ -10,7 +10,7 @@ "namePlaceholder": "Il tuo nome (opzionale)", "password": "Password", "passwordPlaceholder": "Inserisci la password...", - "newPasswordPlaceholder": "Almeno 8 caratteri", + "newPasswordPlaceholder": "Almeno 12 caratteri", "confirmPassword": "Conferma password", "confirmPasswordPlaceholder": "Ripeti la password", "inviteCodeLabel": "Codice invito", @@ -70,7 +70,7 @@ "currentPasswordDescription": "Inserisci la tua password attuale per verificare la tua identità", "currentPasswordPlaceholder": "Password attuale", "newPassword": "Nuova password", - "newPasswordDescription": "Deve avere almeno 8 caratteri", + "newPasswordDescription": "Deve avere almeno 12 caratteri", "newPasswordPlaceholder": "Nuova password", "confirmPassword": "Conferma password", "confirmPasswordDescription": "Reinserisci la nuova password", @@ -79,7 +79,7 @@ "changing": "Modifica in corso...", "changed": "Password modificata", "passwordsDoNotMatch": "Le password non coincidono", - "tooShort": "La nuova password deve avere almeno 8 caratteri", + "tooShort": "La nuova password deve avere almeno 12 caratteri", "oauthOnly": "La gestione della password non è disponibile per gli account {{provider}}." }, "apiKeys": { diff --git a/core/http/react-ui/public/locales/zh-CN/auth.json b/core/http/react-ui/public/locales/zh-CN/auth.json index 4f3229f50..16233fdf1 100644 --- a/core/http/react-ui/public/locales/zh-CN/auth.json +++ b/core/http/react-ui/public/locales/zh-CN/auth.json @@ -10,7 +10,7 @@ "namePlaceholder": "您的姓名(可选)", "password": "密码", "passwordPlaceholder": "输入密码...", - "newPasswordPlaceholder": "至少 8 个字符", + "newPasswordPlaceholder": "至少 12 个字符", "confirmPassword": "确认密码", "confirmPasswordPlaceholder": "再次输入密码", "inviteCodeLabel": "邀请码", @@ -70,7 +70,7 @@ "currentPasswordDescription": "输入您的现有密码以验证身份", "currentPasswordPlaceholder": "当前密码", "newPassword": "新密码", - "newPasswordDescription": "至少需要 8 个字符", + "newPasswordDescription": "至少需要 12 个字符", "newPasswordPlaceholder": "新密码", "confirmPassword": "确认密码", "confirmPasswordDescription": "再次输入新密码", @@ -79,7 +79,7 @@ "changing": "正在更改...", "changed": "密码已更改", "passwordsDoNotMatch": "密码不匹配", - "tooShort": "新密码至少需要 8 个字符", + "tooShort": "新密码至少需要 12 个字符", "oauthOnly": "{{provider}} 账户不支持密码管理。" }, "apiKeys": { diff --git a/core/http/react-ui/src/components/CanvasPanel.jsx b/core/http/react-ui/src/components/CanvasPanel.jsx index 489b00fe4..14b24700b 100644 --- a/core/http/react-ui/src/components/CanvasPanel.jsx +++ b/core/http/react-ui/src/components/CanvasPanel.jsx @@ -1,6 +1,7 @@ import { useState, useEffect, useRef } from 'react' import { renderMarkdown } from '../utils/markdown' import { getArtifactIcon } from '../utils/artifacts' +import { safeHref } from '../utils/url' import DOMPurify from 'dompurify' import hljs from 'highlight.js' @@ -70,7 +71,7 @@ export default function CanvasPanel({ artifacts, selectedId, onSelect, onClose } return (
) } @@ -78,7 +79,7 @@ export default function CanvasPanel({ artifacts, selectedId, onSelect, onClose } return ( ) } diff --git a/core/http/react-ui/src/main.jsx b/core/http/react-ui/src/main.jsx index c03ce78ad..ec62ed766 100644 --- a/core/http/react-ui/src/main.jsx +++ b/core/http/react-ui/src/main.jsx @@ -7,6 +7,8 @@ import { AuthProvider } from './context/AuthContext' import { router } from './router' import './i18n' import '@fortawesome/fontawesome-free/css/all.min.css' +import '@fontsource-variable/geist' +import '@fontsource-variable/geist-mono' import './index.css' import './theme.css' import './App.css' diff --git a/core/http/react-ui/src/pages/Account.jsx b/core/http/react-ui/src/pages/Account.jsx index f8fb10cf9..8f9f3d028 100644 --- a/core/http/react-ui/src/pages/Account.jsx +++ b/core/http/react-ui/src/pages/Account.jsx @@ -131,6 +131,8 @@ function SecurityTab({ addToast }) { const [newPw, setNewPw] = useState('') const [confirmPw, setConfirmPw] = useState('') const [saving, setSaving] = useState(false) + const [weakWarning, setWeakWarning] = useState('') + const [acknowledgeWeak, setAcknowledgeWeak] = useState(false) const handleSubmit = async (e) => { e.preventDefault() @@ -138,19 +140,21 @@ function SecurityTab({ addToast }) { addToast(t('account.security.passwordsDoNotMatch'), 'error') return } - if (newPw.length < 8) { - addToast(t('account.security.tooShort'), 'error') - return - } setSaving(true) try { - await profileApi.changePassword(currentPw, newPw) + await profileApi.changePassword(currentPw, newPw, acknowledgeWeak) addToast(t('account.security.changed'), 'success') setCurrentPw('') setNewPw('') setConfirmPw('') + setWeakWarning('') + setAcknowledgeWeak(false) } catch (err) { - addToast(err.message, 'error') + if (err.body?.overridable) { + setWeakWarning(err.body.error || err.message) + } else { + addToast(err.message, 'error') + } } finally { setSaving(false) } @@ -186,9 +190,12 @@ function SecurityTab({ addToast }) { type="password" className="input account-input-sm" value={newPw} - onChange={(e) => setNewPw(e.target.value)} + onChange={(e) => { + setNewPw(e.target.value) + setWeakWarning('') + setAcknowledgeWeak(false) + }} placeholder={t('account.security.newPasswordPlaceholder')} - minLength={8} disabled={saving} required /> @@ -204,6 +211,22 @@ function SecurityTab({ addToast }) { required /> + {weakWarning && ( + +
+
{weakWarning}
+ +
+
+ )}
diff --git a/core/http/react-ui/src/utils/api.js b/core/http/react-ui/src/utils/api.js index 55f8326e8..78f0b4f68 100644 --- a/core/http/react-ui/src/utils/api.js +++ b/core/http/react-ui/src/utils/api.js @@ -450,8 +450,10 @@ export const adminUsersApi = { deleteQuota: (id, quotaId) => fetchJSON(`/api/auth/admin/users/${encodeURIComponent(id)}/quotas/${encodeURIComponent(quotaId)}`, { method: 'DELETE', }), - resetPassword: (id, password) => fetchJSON(`/api/auth/admin/users/${encodeURIComponent(id)}/password`, { - method: 'PUT', body: JSON.stringify({ password }), headers: { 'Content-Type': 'application/json' }, + resetPassword: (id, password, acknowledgeWeak = false) => fetchJSON(`/api/auth/admin/users/${encodeURIComponent(id)}/password`, { + method: 'PUT', + body: JSON.stringify({ password, acknowledge_weak_password: acknowledgeWeak }), + headers: { 'Content-Type': 'application/json' }, }), } @@ -464,8 +466,9 @@ export const profileApi = { updateProfile: (name, avatarUrl) => fetchJSON('/api/auth/profile', { method: 'PUT', body: JSON.stringify({ name, avatar_url: avatarUrl || '' }), headers: { 'Content-Type': 'application/json' }, }), - changePassword: (currentPassword, newPassword) => fetchJSON('/api/auth/password', { - method: 'PUT', body: JSON.stringify({ current_password: currentPassword, new_password: newPassword }), + changePassword: (currentPassword, newPassword, acknowledgeWeak = false) => fetchJSON('/api/auth/password', { + method: 'PUT', + body: JSON.stringify({ current_password: currentPassword, new_password: newPassword, acknowledge_weak_password: acknowledgeWeak }), headers: { 'Content-Type': 'application/json' }, }), } diff --git a/core/http/react-ui/src/utils/mcpClientStorage.js b/core/http/react-ui/src/utils/mcpClientStorage.js index 2e5511504..4abd5dbd0 100644 --- a/core/http/react-ui/src/utils/mcpClientStorage.js +++ b/core/http/react-ui/src/utils/mcpClientStorage.js @@ -2,12 +2,52 @@ import { generateId } from './format' const STORAGE_KEY = 'localai_client_mcp_servers' +// localStorage is shared across same-origin pages; an XSS that lands once can +// poison persisted MCP server entries to attempt header injection or to feed +// a non-http URL into the fetch path. Validate every entry on load and drop +// anything that doesn't match the expected shape. +function sanitiseServer(s) { + if (!s || typeof s !== 'object') return null + const id = typeof s.id === 'string' ? s.id : '' + const name = typeof s.name === 'string' ? s.name : '' + const url = typeof s.url === 'string' ? s.url : '' + if (!url) return null + // fetch() refuses non-http schemes anyway, but reject early so they + // can't get persisted back out. URL parsing also catches malformed values. + let parsed + try { + parsed = new URL(url) + } catch { + return null + } + if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') return null + + const headers = {} + if (s.headers && typeof s.headers === 'object' && !Array.isArray(s.headers)) { + for (const [k, v] of Object.entries(s.headers)) { + // Drop CRLF / control chars to block header injection through poisoned storage. + if (typeof k !== 'string' || typeof v !== 'string') continue + if (/[\x00-\x1f\x7f]/.test(k) || /[\x00-\x1f\x7f]/.test(v)) continue + headers[k] = v + } + } + return { + id, + name, + url, + headers, + useProxy: s.useProxy !== false, + } +} + export function loadClientMCPServers() { try { const stored = localStorage.getItem(STORAGE_KEY) if (stored) { const data = JSON.parse(stored) - if (Array.isArray(data)) return data + if (Array.isArray(data)) { + return data.map(sanitiseServer).filter(Boolean) + } } } catch (_e) { // ignore diff --git a/core/http/react-ui/src/utils/url.js b/core/http/react-ui/src/utils/url.js new file mode 100644 index 000000000..e7eb4ba4d --- /dev/null +++ b/core/http/react-ui/src/utils/url.js @@ -0,0 +1,31 @@ +// safeHref returns the input URL only if its scheme is on a small allowlist +// (http, https, mailto, tel) or if it's a relative/anchor link. Anything else +// — most importantly `javascript:` and `data:` — collapses to '#'. Use this +// for any whose URL comes from gallery JSON, agent tool calls, +// or any other source the operator hasn't fully vetted. +// +// React already escapes attribute values, so the only XSS path on a hyperlink +// is the URI itself. javascript: in is inert in modern browsers, +// but still fires on click. +const ALLOWED_SCHEMES = ['http:', 'https:', 'mailto:', 'tel:'] + +export function safeHref(url) { + if (typeof url !== 'string' || url === '') return '#' + const trimmed = url.trim() + if (trimmed === '') return '#' + // Relative paths, fragment links, and protocol-relative URLs are fine. + if (trimmed.startsWith('/') || trimmed.startsWith('#') || trimmed.startsWith('?')) { + return trimmed + } + if (trimmed.startsWith('//')) return trimmed + // Heuristic: if there's no colon before the first slash, it's a relative path. + const colonIdx = trimmed.indexOf(':') + if (colonIdx === -1) return trimmed + const slashIdx = trimmed.indexOf('/') + if (slashIdx !== -1 && slashIdx < colonIdx) return trimmed + // There is a scheme — allowlist-check it. Browsers ignore tabs/newlines + // inside the scheme (`java\tscript:...`), so we strip control chars first. + const scheme = trimmed.slice(0, colonIdx).toLowerCase().replace(/[\x00-\x1f]/g, '') + if (ALLOWED_SCHEMES.includes(scheme + ':')) return trimmed + return '#' +} diff --git a/core/http/route_coverage_test.go b/core/http/route_coverage_test.go new file mode 100644 index 000000000..aced87abb --- /dev/null +++ b/core/http/route_coverage_test.go @@ -0,0 +1,225 @@ +//go:build auth + +package http_test + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/config" + . "github.com/mudler/LocalAI/core/http" + "github.com/mudler/LocalAI/pkg/system" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// Every API-prefixed route registered by API() must either reject anonymous +// traffic with 401 or appear on the explicit public allowlist below. The +// test fails on routes that ship without an auth decision; adding a new +// public surface should be deliberate, not a side effect. +var _ = Describe("Route auth coverage", func() { + var ( + app *echo.Echo + tmpdir string + c context.Context + cancel context.CancelFunc + appInst *application.Application + ) + + BeforeEach(func() { + var err error + tmpdir, err = os.MkdirTemp("", "route-coverage-") + Expect(err).ToNot(HaveOccurred()) + + modelDir := filepath.Join(tmpdir, "models") + Expect(os.Mkdir(modelDir, 0750)).To(Succeed()) + bDir := filepath.Join(tmpdir, "backends") + Expect(os.Mkdir(bDir, 0750)).To(Succeed()) + + c, cancel = context.WithCancel(context.Background()) + + systemState, err := system.GetSystemState( + system.WithBackendPath(bDir), + system.WithModelPath(modelDir), + ) + Expect(err).ToNot(HaveOccurred()) + + // Auth enabled, no legacy keys, no admin user pre-created. With auth + // enabled the global middleware MUST reject anonymous API requests + // regardless of admin presence. + appInst, err = application.New( + config.WithContext(c), + config.WithSystemState(systemState), + config.WithAuthEnabled(true), + config.WithAuthDatabaseURL(":memory:"), + config.WithAuthAPIKeyHMACSecret("test-secret-for-route-coverage"), + ) + Expect(err).ToNot(HaveOccurred()) + + app, err = API(appInst) + Expect(err).ToNot(HaveOccurred()) + }) + + AfterEach(func() { + cancel() + Expect(os.RemoveAll(tmpdir)).To(Succeed()) + }) + + It("rejects anonymous traffic on every API route except the explicit allowlist", func() { + // Routes that are intentionally reachable without authentication. + // Each entry needs a justification comment. + expectedPublicPrefixes := []string{ + // Auth flow itself — login, registration, OAuth callbacks + "/api/auth/", + + // Distributed-mode node self-service: authenticated by a + // registration token presented in the request body, not by the + // global auth middleware. Verified separately in node tests. + "/api/node/register", + + // Branding read for login screen + public branding asset server. + // Mutating /api/branding/asset/* routes are exempt from global + // auth but admin-gated by route-level middleware + // (TestBrandingRoutes_AdminGatingHolds pins that contract). + "/api/branding", + "/branding/", + + // Health and metadata used by orchestrators / load balancers + "/healthz", + "/readyz", + + // Static asset surfaces used by the UI before login + "/favicon.svg", + "/static/", + "/assets/", + "/locales/", + "/generated-audio/", + "/generated-images/", + "/generated-videos/", + } + expectedPublicExact := map[string]bool{ + // SPA shell + redirects — UI handles login client-side + "/": true, + "/app": true, + "/browse": true, + "/swagger": true, + "/swagger/": true, + "/swagger/*": true, + "/oauth/start": true, + } + + // Per-route exemptions for distributed-mode node endpoints whose + // pattern carries a path param (so they don't fit a flat prefix). + // These authenticate via registration token at the handler layer. + nodeSelfPattern := regexp.MustCompile(`^/api/node/[^/]+/(heartbeat|drain|deregister)$`) + + // Concretize a route pattern into a URL suitable for httptest. + // Echo path params come back as ":name" and wildcards as "*". + concretize := func(pattern string) string { + parts := strings.Split(pattern, "/") + for i, p := range parts { + if strings.HasPrefix(p, ":") { + parts[i] = "test" + } else if p == "*" { + parts[i] = "test" + } + } + return strings.Join(parts, "/") + } + + isAPI := func(p string) bool { + apiPrefixes := []string{ + "/api/", "/v1/", "/models/", "/backends/", "/backend/", + "/tts", "/vad", "/video", "/stores/", "/system", + "/ws/", "/generated-", "/chat/", "/completions", + "/edits", "/embeddings", "/audio/", "/images/", + "/messages", "/responses", + } + if p == "/metrics" { + return true + } + for _, pre := range apiPrefixes { + if strings.HasPrefix(p, pre) { + return true + } + } + return false + } + + isAllowlisted := func(p string) bool { + if expectedPublicExact[p] { + return true + } + for _, pre := range expectedPublicPrefixes { + if strings.HasPrefix(p, pre) { + return true + } + } + if nodeSelfPattern.MatchString(p) { + return true + } + return false + } + + var leaks []string + seen := map[string]bool{} + for _, r := range app.Routes() { + // Echo registers automatic HEAD routes for GETs; auth check is + // identical, so dedupe. + key := r.Method + " " + r.Path + if seen[key] { + continue + } + seen[key] = true + + // Only inspect API surface — UI/static paths are intentionally + // reachable for SPA hydration before login. + if !isAPI(r.Path) { + continue + } + if isAllowlisted(r.Path) { + continue + } + + req := httptest.NewRequest(r.Method, concretize(r.Path), nil) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + + // We accept: + // 401: middleware rejected — the desired outcome + // 404: route not actually reachable in this minimal config + // (e.g. distributed-only routes); not an auth leak + // 405: method not allowed; auth never ran but isn't a leak + if rec.Code == http.StatusUnauthorized || + rec.Code == http.StatusNotFound || + rec.Code == http.StatusMethodNotAllowed { + continue + } + + leaks = append(leaks, " "+r.Method+" "+r.Path+ + " → "+http.StatusText(rec.Code)+ + " (got "+strconv.Itoa(rec.Code)+")") + } + + if len(leaks) > 0 { + Fail("Routes reachable without authentication:\n" + + strings.Join(leaks, "\n") + + "\n\nIf the route is intentionally public, add it to " + + "expectedPublicPrefixes or expectedPublicExact in " + + "core/http/route_coverage_test.go with a justification " + + "comment. Otherwise, gate it behind the auth middleware " + + "(automatic for /api/, /v1/, /models/, /backends/, etc.) " + + "or RequireAdmin / RequireFeature.") + } + }) +}) diff --git a/core/http/routes/auth.go b/core/http/routes/auth.go index 7d74b350c..3f42adbf8 100644 --- a/core/http/routes/auth.go +++ b/core/http/routes/auth.go @@ -190,6 +190,13 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { authRL := newRateLimiter(1*time.Minute, 5) authRateLimitMw := rateLimitMiddleware(authRL) + // Separate, more permissive limiter for OAuth/OIDC callbacks. Corporate + // SSO often funnels many real users through one outbound IP, so the 5/min + // password-style cap is too tight here; 60/min still bounds a flood that + // would otherwise pin token-exchange traffic to the IdP. + oauthRL := newRateLimiter(1*time.Minute, 60) + oauthRateLimitMw := rateLimitMiddleware(oauthRL) + // Start background goroutine to periodically prune stale IP entries go func() { ticker := time.NewTicker(10 * time.Minute) @@ -200,6 +207,7 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { return case <-ticker.C: authRL.cleanup() + oauthRL.cleanup() } } }() @@ -274,13 +282,13 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { e.GET("/api/auth/github/login", oauthMgr.LoginHandler(auth.ProviderGitHub)) e.GET("/api/auth/github/callback", oauthMgr.CallbackHandler( auth.ProviderGitHub, db, appConfig.Auth.AdminEmail, appConfig.Auth.RegistrationMode, appConfig.Auth.APIKeyHMACSecret, - )) + ), oauthRateLimitMw) } if appConfig.Auth.OIDCClientID != "" { e.GET("/api/auth/oidc/login", oauthMgr.LoginHandler(auth.ProviderOIDC)) e.GET("/api/auth/oidc/callback", oauthMgr.CallbackHandler( auth.ProviderOIDC, db, appConfig.Auth.AdminEmail, appConfig.Auth.RegistrationMode, appConfig.Auth.APIKeyHMACSecret, - )) + ), oauthRateLimitMw) } } } @@ -292,10 +300,11 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { } var body struct { - Email string `json:"email"` - Password string `json:"password"` - Name string `json:"name"` - InviteCode string `json:"inviteCode"` + Email string `json:"email"` + Password string `json:"password"` + Name string `json:"name"` + InviteCode string `json:"inviteCode"` + AcknowledgeWeakPassword bool `json:"acknowledge_weak_password"` } if err := c.Bind(&body); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request"}) @@ -311,8 +320,8 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { if _, err := mail.ParseAddress(body.Email); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid email address"}) } - if len(body.Password) < 8 { - return c.JSON(http.StatusBadRequest, map[string]string{"error": "password must be at least 8 characters"}) + if err := auth.ValidatePasswordStrength(body.Password, auth.PasswordPolicy{AllowWeak: body.AcknowledgeWeakPassword}); err != nil { + return c.JSON(http.StatusBadRequest, auth.PasswordError(err)) } hash, err := auth.HashPassword(body.Password) @@ -579,8 +588,9 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { } var body struct { - CurrentPassword string `json:"current_password"` - NewPassword string `json:"new_password"` + CurrentPassword string `json:"current_password"` + NewPassword string `json:"new_password"` + AcknowledgeWeakPassword bool `json:"acknowledge_weak_password"` } if err := c.Bind(&body); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request"}) @@ -590,8 +600,8 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { return c.JSON(http.StatusBadRequest, map[string]string{"error": "current and new passwords are required"}) } - if len(body.NewPassword) < 8 { - return c.JSON(http.StatusBadRequest, map[string]string{"error": "new password must be at least 8 characters"}) + if err := auth.ValidatePasswordStrength(body.NewPassword, auth.PasswordPolicy{AllowWeak: body.AcknowledgeWeakPassword}); err != nil { + return c.JSON(http.StatusBadRequest, auth.PasswordError(err)) } // Verify current password @@ -900,14 +910,15 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) { } var body struct { - Password string `json:"password"` + Password string `json:"password"` + AcknowledgeWeakPassword bool `json:"acknowledge_weak_password"` } if err := c.Bind(&body); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request"}) } - if len(body.Password) < 8 { - return c.JSON(http.StatusBadRequest, map[string]string{"error": "password must be at least 8 characters"}) + if err := auth.ValidatePasswordStrength(body.Password, auth.PasswordPolicy{AllowWeak: body.AcknowledgeWeakPassword}); err != nil { + return c.JSON(http.StatusBadRequest, auth.PasswordError(err)) } hash, err := auth.HashPassword(body.Password) diff --git a/core/http/routes/auth_profile_test.go b/core/http/routes/auth_profile_test.go new file mode 100644 index 000000000..a80582abf --- /dev/null +++ b/core/http/routes/auth_profile_test.go @@ -0,0 +1,165 @@ +//go:build auth + +package routes_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/application" + "github.com/mudler/LocalAI/core/config" + authpkg "github.com/mudler/LocalAI/core/http/auth" + . "github.com/mudler/LocalAI/core/http" + "github.com/mudler/LocalAI/pkg/system" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// PUT /api/auth/profile must accept ONLY name and avatar_url. A user that +// posts {"role":"admin","email":"...","status":"active","password_hash":"..."} +// must not be able to mutate any of those fields. The current handler uses +// an explicit local body struct + gorm Updates(map) with a column allowlist; +// this test pins that contract so a future refactor (e.g. c.Bind(&user)) +// can't silently regress to mass-assignment. +var _ = Describe("PUT /api/auth/profile field-tampering", func() { + var ( + app *echo.Echo + appCtx context.Context + cancel context.CancelFunc + tmpdir string + alice authpkg.User + appAt *application.Application + ) + + BeforeEach(func() { + var err error + tmpdir, err = os.MkdirTemp("", "profile-tamper-") + Expect(err).ToNot(HaveOccurred()) + + modelDir := filepath.Join(tmpdir, "models") + Expect(os.Mkdir(modelDir, 0750)).To(Succeed()) + bDir := filepath.Join(tmpdir, "backends") + Expect(os.Mkdir(bDir, 0750)).To(Succeed()) + + appCtx, cancel = context.WithCancel(context.Background()) + + systemState, err := system.GetSystemState( + system.WithBackendPath(bDir), + system.WithModelPath(modelDir), + ) + Expect(err).ToNot(HaveOccurred()) + + appAt, err = application.New( + config.WithContext(appCtx), + config.WithSystemState(systemState), + config.WithAuthEnabled(true), + config.WithAuthDatabaseURL(":memory:"), + config.WithAuthAPIKeyHMACSecret("test-secret-profile-tamper"), + ) + Expect(err).ToNot(HaveOccurred()) + + app, err = API(appAt) + Expect(err).ToNot(HaveOccurred()) + + // Seed a non-admin user directly in the DB. + alice = authpkg.User{ + ID: "alice-id", + Email: "alice@example.com", + Name: "Alice", + Provider: authpkg.ProviderLocal, + PasswordHash: "$2a$10$abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUV", + Role: authpkg.RoleUser, + Status: "active", + } + Expect(appAt.AuthDB().Create(&alice).Error).To(Succeed()) + }) + + AfterEach(func() { + cancel() + Expect(os.RemoveAll(tmpdir)).To(Succeed()) + }) + + // Mint a real API key for alice once per spec — the auth middleware + // resolves Bearer tokens against the DB, which is the same path a + // browser session takes after extraction. This avoids forging session + // cookies and exercises the production auth flow end-to-end. + callProfile := func(body map[string]any) (int, map[string]any) { + key, _, err := authpkg.CreateAPIKey(appAt.AuthDB(), alice.ID, "test", authpkg.RoleUser, "test-secret-profile-tamper", nil) + Expect(err).ToNot(HaveOccurred()) + + raw, err := json.Marshal(body) + Expect(err).ToNot(HaveOccurred()) + req := httptest.NewRequest(http.MethodPut, "/api/auth/profile", bytes.NewReader(raw)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Sec-Fetch-Site", "same-origin") + req.Header.Set("Authorization", "Bearer "+key) + rec := httptest.NewRecorder() + app.ServeHTTP(rec, req) + var out map[string]any + _ = json.Unmarshal(rec.Body.Bytes(), &out) + return rec.Code, out + } + + It("ignores attempts to set role: admin", func() { + code, _ := callProfile(map[string]any{ + "name": "Alice 2", + "avatar_url": "https://example.com/a.png", + "role": "admin", + }) + Expect(code).To(Equal(http.StatusOK)) + + var fresh authpkg.User + Expect(appAt.AuthDB().First(&fresh, "id = ?", alice.ID).Error).To(Succeed()) + Expect(fresh.Role).To(Equal(authpkg.RoleUser), + "role must remain RoleUser; mass-assignment of role is forbidden") + Expect(fresh.Name).To(Equal("Alice 2")) + Expect(fresh.AvatarURL).To(Equal("https://example.com/a.png")) + }) + + It("ignores attempts to override email, status, password_hash, provider, and id", func() { + body := map[string]any{ + "name": "Alice 3", + "avatar_url": "", + "email": "attacker@example.com", + "status": "frozen", + "password_hash": "$2a$10$attacker-controlled-hash", + "provider": "github", + "id": "different-id", + } + code, _ := callProfile(body) + Expect(code).To(Equal(http.StatusOK)) + + var fresh authpkg.User + Expect(appAt.AuthDB().First(&fresh, "id = ?", alice.ID).Error).To(Succeed()) + Expect(fresh.Email).To(Equal(alice.Email), "email must be immutable through profile update") + Expect(fresh.Status).To(Equal(alice.Status), "status must be immutable") + Expect(fresh.PasswordHash).To(Equal(alice.PasswordHash), "password_hash must be immutable") + Expect(fresh.Provider).To(Equal(alice.Provider), "provider must be immutable") + Expect(fresh.ID).To(Equal(alice.ID), "id must be immutable") + }) + + It("requires a non-empty name", func() { + code, body := callProfile(map[string]any{"name": ""}) + Expect(code).To(Equal(http.StatusBadRequest)) + Expect(body["error"]).To(ContainSubstring("name is required")) + }) + + It("rejects oversized avatar_url", func() { + long := make([]byte, 1024) + for i := range long { + long[i] = 'x' + } + code, _ := callProfile(map[string]any{ + "name": "Alice", + "avatar_url": "https://" + string(long), + }) + Expect(code).To(Equal(http.StatusBadRequest)) + }) +}) diff --git a/core/services/agentpool/agent_pool.go b/core/services/agentpool/agent_pool.go index 41178c589..9216add9b 100644 --- a/core/services/agentpool/agent_pool.go +++ b/core/services/agentpool/agent_pool.go @@ -672,12 +672,18 @@ func (s *AgentPoolService) ensureCollectionForUser(userID, name string) error { } // UploadToCollectionForUser uploads to a collection for a specific user. +// The filename arrives from a multipart upload; the vendored backend may or +// may not sanitise it, so strip any directory components at the boundary. func (s *AgentPoolService) UploadToCollectionForUser(userID, collection, filename string, fileBody io.Reader) (string, error) { backend, err := s.CollectionsBackendForUser(userID) if err != nil { return "", err } - return backend.Upload(collection, filename, fileBody) + base := filepath.Base(filename) + if base == "." || base == ".." || base == "/" || base == "" { + return "", fmt.Errorf("invalid filename") + } + return backend.Upload(collection, base, fileBody) } // CollectionEntryExistsForUser checks if an entry exists in a user's collection. diff --git a/core/services/finetune/service.go b/core/services/finetune/service.go index d64e5ad42..84d50d80e 100644 --- a/core/services/finetune/service.go +++ b/core/services/finetune/service.go @@ -731,13 +731,19 @@ func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message strin } // UploadDataset handles dataset file upload and returns the local path. +// The filename comes straight from a multipart upload, so strip directory +// components — anything else risks a write under a sibling path. func (s *FineTuneService) UploadDataset(filename string, data []byte) (string, error) { uploadDir := filepath.Join(s.fineTuneBaseDir(), "datasets") if err := os.MkdirAll(uploadDir, 0750); err != nil { return "", fmt.Errorf("failed to create dataset directory: %w", err) } - filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+filename) + base := filepath.Base(filename) + if base == "." || base == ".." || base == "/" || base == "" { + return "", fmt.Errorf("invalid filename") + } + filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+base) if err := os.WriteFile(filePath, data, 0640); err != nil { return "", fmt.Errorf("failed to write dataset: %w", err) } diff --git a/go.mod b/go.mod index 6db79faa9..bcb67f933 100644 --- a/go.mod +++ b/go.mod @@ -61,6 +61,7 @@ require ( github.com/testcontainers/testcontainers-go v0.42.0 github.com/testcontainers/testcontainers-go/modules/nats v0.42.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0 + github.com/timbutler/zxcvbn v1.0.4 go.opentelemetry.io/otel v1.43.0 go.opentelemetry.io/otel/exporters/prometheus v0.65.0 go.opentelemetry.io/otel/metric v1.43.0 diff --git a/go.sum b/go.sum index 03f950be0..ba6f4f7ce 100644 --- a/go.sum +++ b/go.sum @@ -1144,6 +1144,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/timbutler/zxcvbn v1.0.4 h1:nTUa8UpLhIxhUBag42fQcwiC8AtTxNVbQMbmxyxLfXg= +github.com/timbutler/zxcvbn v1.0.4/go.mod h1:Cl20mGFz9+SXvTRebBcwMUDqZUvCfSnb+XMznbTKo2U= github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA= github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 09023f3e2..6199e6ce3 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -272,15 +272,11 @@ func (s URI) ResolveURL() string { } func removePartialFile(tmpFilePath string) error { - _, err := os.Stat(tmpFilePath) - if err == nil { - xlog.Debug("Removing temporary file", "file", tmpFilePath) - err = os.Remove(tmpFilePath) - if err != nil { - err1 := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, err) - xlog.Warn("failed to remove temporary download file", "error", err1) - return err1 - } + xlog.Debug("Removing temporary file", "file", tmpFilePath) + if err := os.Remove(tmpFilePath); err != nil && !errors.Is(err, os.ErrNotExist) { + err1 := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, err) + xlog.Warn("failed to remove temporary download file", "error", err1) + return err1 } return nil } @@ -578,20 +574,28 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string default: } - err = os.Rename(tmpFilePath, filePath) - if err != nil { - return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err) - } - + // Invariant: verify the streamed hash before promoting the temp file to + // the final path. Renaming first would leave tampered content reachable + // to subsequent readers even though we return an error. if sha != "" { - // Verify SHA calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil)) if calculatedSHA != sha { xlog.Debug("SHA mismatch for file", "file", filePath, "calculated", calculatedSHA, "metadata", sha) + _ = removePartialFile(tmpFilePath) return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha) } } else { - xlog.Debug("SHA missing. Skipping validation", "file", filePath) + // Visible at the default log level so missing-digest configs are + // noticed; silent acceptance was the historical bug. + xlog.Warn("downloading without integrity check — supplied SHA is empty", + "file", filePath, + "url", url, + ) + } + + err = os.Rename(tmpFilePath, filePath) + if err != nil { + return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err) } xlog.Info("File downloaded and verified", "file", filePath) diff --git a/pkg/downloader/uri_test.go b/pkg/downloader/uri_test.go index 3d2d3cbcb..0a25f82a1 100644 --- a/pkg/downloader/uri_test.go +++ b/pkg/downloader/uri_test.go @@ -295,6 +295,49 @@ var _ = Describe("Download Test", func() { err = uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {}) Expect(err).ToNot(HaveOccurred()) }) + + // A file that fails its SHA check must not be usable. The historical + // implementation renamed the temp file to its final path *before* + // verifying the hash, so a mismatch returned an error but left a + // tampered file at the destination — the next caller (e.g. a backend + // launcher) could pick it up and run with it. + It("does not leave a corrupted file at the destination on SHA mismatch", func() { + mockServer := getMockServer(true) + defer mockServer.Close() + uri := URI(mockServer.URL) + + // Use a clearly-wrong expected SHA; the server will return real + // data with a different hash. + wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000" + err := uri.DownloadFile(filePath, wrongSHA, 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("SHA")) + + // The file must not exist at the final destination. + _, statErr := os.Stat(filePath) + Expect(os.IsNotExist(statErr)).To(BeTrue(), + "download with wrong SHA left a file at %s — a subsequent caller could load tampered content", filePath) + }) + + // A download without an expected digest is a supply-chain footgun. + // The downloader allows it (backend installs pass through here + // today and don't yet ship a digest) but it is the caller's + // responsibility to know when integrity is required. The downloader + // emits a WARN log on every empty-digest download to make this + // visible at the default log level. + It("succeeds with empty SHA but emits an integrity warning", func() { + mockServer := getMockServer(true) + defer mockServer.Close() + uri := URI(mockServer.URL) + + // No assertion on logs (we don't capture xlog output here), + // but the call must succeed so existing backend installs do + // not regress. + err := uri.DownloadFile(filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {}) + Expect(err).ToNot(HaveOccurred()) + _, statErr := os.Stat(filePath) + Expect(statErr).ToNot(HaveOccurred()) + }) }) AfterEach(func() { diff --git a/pkg/utils/urlfetch.go b/pkg/utils/urlfetch.go index d32a1ba0a..072c07b21 100644 --- a/pkg/utils/urlfetch.go +++ b/pkg/utils/urlfetch.go @@ -49,7 +49,7 @@ func ValidateExternalURL(rawURL string) error { return fmt.Errorf("unable to parse resolved IP: %s", ipStr) } - if !isPublicIP(ip) { + if !IsPublicIP(ip) { return fmt.Errorf("requests to internal network addresses are not allowed") } } @@ -57,7 +57,11 @@ func ValidateExternalURL(rawURL string) error { return nil } -func isPublicIP(ip net.IP) bool { +// IsPublicIP reports whether ip refers to a host on the public internet, i.e. +// not loopback, link-local, private (RFC 1918 / RFC 4193), or unspecified. +// Covers 0.0.0.0/8, ::/128, and IPv4-mapped IPv6 wrapping a private address — +// holes hand-rolled CIDR lists tend to miss. +func IsPublicIP(ip net.IP) bool { if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() || @@ -66,7 +70,6 @@ func isPublicIP(ip net.IP) bool { return false } - // Block IPv4-mapped IPv6 addresses that wrap private IPv4 if ip4 := ip.To4(); ip4 != nil { return !ip4.IsLoopback() && !ip4.IsLinkLocalUnicast() &&