mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-16 20:52:08 -04:00
chore: Security hardening (#9719)
* fix(http): close 0.0.0.0/[::] SSRF bypass in /api/cors-proxy The CORS proxy carried its own private-network blocklist (RFC 1918 + a handful of IPv6 ranges) instead of using the same classification as pkg/utils/urlfetch.go. The hand-rolled list missed 0.0.0.0/8 and ::/128, both of which Linux routes to localhost — so any user with FeatureMCP (default-on for new users) could reach LocalAI's own listener and any other service bound to 0.0.0.0:port via: GET /api/cors-proxy?url=http://0.0.0.0:8080/... GET /api/cors-proxy?url=http://[::]:8080/... Replace the custom check with utils.IsPublicIP (Go stdlib IsLoopback / IsLinkLocalUnicast / IsPrivate / IsUnspecified, plus IPv4-mapped IPv6 unmasking) and add an upfront hostname rejection for localhost, *.local, and the cloud metadata aliases so split-horizon DNS can't paper over the IP check. The IP-pinning DialContext is unchanged: the validated IP from the single resolution is reused for the connection, so DNS rebinding still cannot swap a public answer for a private one between validate and dial. Regression tests cover 0.0.0.0, 0.0.0.0:PORT, [::], ::ffff:127.0.0.1, ::ffff:10.0.0.1, file://, gopher://, ftp://, localhost, 127.0.0.1, 10.0.0.1, 169.254.169.254, metadata.google.internal. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(downloader): verify SHA before promoting temp file to final path DownloadFileWithContext renamed the .partial file to its final name *before* checking the streamed SHA, so a hash mismatch returned an error but left the tampered file at filePath. Subsequent code that operated on filePath (a backend launcher, a YAML loader, a re-download that finds the file already present and skips) would consume the attacker-supplied bytes. Reorder: verify the streamed hash first, remove the .partial on mismatch, then rename. The streamed hash is computed during io.Copy so no second read is needed. While here, raise the empty-SHA case from a Debug log to a Warn so "this download had no integrity check" is visible at the default log level. Backend installs currently pass through with no digest; the warning makes that footprint observable without changing behaviour. Regression test asserts os.IsNotExist on the destination after a deliberate SHA mismatch. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(auth): require email_verified for OIDC admin promotion extractOIDCUserInfo read the ID token's "email" claim but never inspected "email_verified". With LOCALAI_ADMIN_EMAIL set, an attacker who could register on the configured OIDC IdP under that email (some IdPs accept self-supplied unverified emails) inherited admin role: - first login: AssignRole(tx, email, adminEmail) → RoleAdmin - re-login: MaybePromote(db, user, adminEmail) → flip to RoleAdmin Add EmailVerified to oauthUserInfo, parse email_verified from the OIDC claims (default false on absence so an IdP that omits the claim cannot short-circuit the gate), and substitute "" for the role-decision email when verified=false via emailForRoleDecision. The user record still stores the unverified email for display. GitHub's path defaults EmailVerified=true: GitHub only returns a public profile email after verification, and fetchGitHubPrimaryEmail explicitly filters to Verified=true. Regression tests cover both the helper contract and integration with AssignRole, including the bootstrap "first user" branch that would otherwise mask the gate. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(cli): refuse public bind when no auth backend is configured When neither an auth DB nor a static API key is set, the auth middleware passes every request through. That is fine for a developer laptop, a home LAN, or a Tailnet — the network itself is the trust boundary. It is not fine on a public IP, where every model install, settings change, and admin endpoint becomes reachable from the internet. Refuse to start in that exact configuration. Loopback, RFC 1918, RFC 4193 ULA, link-local, and RFC 6598 CGNAT (Tailscale's default range) all count as trusted; wildcard binds (`:port`, `0.0.0.0`, `[::]`) are accepted only when every host interface is in one of those ranges. Hostnames are resolved and treated as trusted only when every answer is. A new --allow-insecure-public-bind / LOCALAI_ALLOW_INSECURE_PUBLIC_BIND flag opts out for deployments that gate access externally (a reverse proxy enforcing auth, a mesh ACL, etc.). The error message lists this plus the three constructive alternatives (bind a private interface, enable --auth, set --api-keys). The interface enumeration goes through a package-level interfaceAddrsFn var so tests can simulate cloud-VM, home-LAN, Tailscale-only, and enumeration-failure topologies without poking at the real network stack. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * test(http): regression-test the localai_assistant admin gate ChatEndpoint already rejects metadata.localai_assistant=true from a non-admin caller, but the gate was open-coded inline with no direct test coverage. The chat route is FeatureChat-gated (default-on), and the assistant's in-process MCP server can install/delete models and edit configs — the wrong handler change would silently turn the LLM into a confused deputy. Extract the gate into requireAssistantAccess(c, authEnabled) and pin its behaviour: auth disabled is a no-op, unauthenticated is 403, RoleUser is 403, RoleAdmin and the synthetic legacy-key admin are admitted. No behaviour change in the production path. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * test(http): assert every API route is auth-classified The auth middleware classifies path prefixes (/api/, /v1/, /models/, etc.) as protected and treats anything else as a static-asset passthrough. A new endpoint shipped under a brand-new prefix — or a new path that simply isn't on the prefix allowlist — would be reachable anonymously. Walk every route registered by API() with auth enabled and a fresh in-memory database (no users, no keys), and assert each API-prefixed route returns 401 / 404 / 405 to an anonymous request. Public surfaces (/api/auth/*, /api/branding, /api/node/* token-authenticated routes, /healthz, branding asset server, generated-content server, static assets) are explicit allowlist entries with comments justifying them. Build-tagged 'auth' so it runs against the SQLite-backed auth DB (matches the existing auth suite). Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * test(http): pin agent endpoint per-user isolation contract agents.go's getUserID / effectiveUserID / canImpersonateUser / wantsAllUsers helpers are the single trust boundary for cross-user access on agent, agent-jobs, collections, and skills routes. A regression there is the difference between "regular user reads their own data" and "regular user reads anyone's data via ?user_id=victim". Lock in the contract: - effectiveUserID ignores ?user_id= for unauthenticated and RoleUser - effectiveUserID honours it for RoleAdmin and ProviderAgentWorker - wantsAllUsers requires admin AND the literal "true" string - canImpersonateUser is admin OR agent-worker, never plain RoleUser No production change — this commit only adds tests. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(downloader): drop redundant stat in removePartialFile The stat-then-remove pattern is a TOCTOU window and a wasted syscall — os.Remove already returns ErrNotExist for the missing-file case, so trust that and treat it as a no-op. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(http): redact secrets from trace buffer and distribution-token logs The /api/traces buffer captured Authorization, Cookie, Set-Cookie, and API-key headers verbatim from every request when tracing was enabled. The endpoint is admin-only but the buffer is reachable via any heap-style introspection and the captured tokens otherwise outlive the request. Strip those header values at capture time. Body redaction is left to a follow-up — the prompts are usually the operator's own and JSON-walking is invasive. Distribution tokens were also logged in plaintext from core/explorer/discovery.go; logs forward to syslog/journald and outlive the token. Redact those to a short prefix/suffix instead. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(auth): rate-limit OAuth callbacks separately from password endpoints The shared 5/min/IP limit on auth endpoints is right for password-style flows but too tight for OAuth callbacks: corporate SSO funnels many real users through one outbound IP and would trip the limit. Add a separate 60/min/IP limiter for /api/auth/{github,oidc}/callback so callbacks are bounded against floods without breaking shared-IP deployments. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(gallery): verify backend tarball sha256 when set in gallery entry GalleryBackend gained an optional sha256 field; the install path now threads it through to the existing downloader hash-verify (which already streams, verifies, and rolls back on mismatch). Galleries without sha256 keep working; the empty-SHA path still emits the existing "downloading without integrity check" warning. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * test(http): pin CSRF coverage on multipart endpoints The CSRF middleware in app.go is global (e.Use) so it covers every multipart upload route — branding assets, fine-tune datasets, audio transforms, agent collections. Pin that contract: cross-site multipart POSTs are rejected; same-origin / same-site / API-key clients are not. Also pins the SameSite=Lax fallback path the skipper relies on when Sec-Fetch-Site is absent. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(http): XSS hardening — CSP headers, safe href, base-href escape, SVG sandbox Several closely related XSS-prevention changes spanning the SPA shell, the React UI, and the branding asset server: - New SecurityHeaders middleware sets CSP, X-Content-Type-Options, X-Frame-Options, and Referrer-Policy on every response. The CSP keeps script-src permissive because the Vite bundle relies on inline + eval'd scripts; tightening that requires moving to a nonce-based policy. - The <base href> injection in the SPA shell escaped attacker-controllable Host / X-Forwarded-Host headers — a single quote in the host header broke out of the attribute. Pass through SecureBaseHref (html.EscapeString). - Three React sinks rendering untrusted content via dangerouslySetInnerHTML switch to text-node rendering with whiteSpace: pre-wrap: user message bodies in Chat.jsx and AgentChat.jsx, and the agent activity log in AgentChat.jsx. The hand-rolled escape on the agent user-message variant is replaced by the same plain-text path. - New safeHref util collapses non-allowlisted URI schemes (most importantly javascript:) to '#'. Applied to gallery `<a href={url}>` links in Models / Backends / Manage and to canvas artifact links — these come from gallery JSON or assistant tool calls and must be treated as untrusted. - The branding asset server attaches a sandbox CSP plus same-origin CORP to .svg responses. The React UI loads logos via <img>, but the same URL is also reachable via direct navigation; this prevents script execution if a hostile SVG slipped past upload validation. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(http): bound HTTP server with read-header and idle timeouts A net/http server with no timeouts is trivially Slowloris-able and leaks idle keep-alive connections. Set ReadHeaderTimeout (30s) to plug the slow-headers attack and IdleTimeout (120s) to cap keep-alive sockets. ReadTimeout and WriteTimeout stay at 0 because request bodies can be multi-GB model uploads and SSE / chat completions stream for many minutes; operators who need tighter per-request bounds should terminate slow clients at a reverse proxy. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * test(auth): pin PUT /api/auth/profile field-tampering contract The handler uses an explicit local body struct (only name and avatar_url) plus a gorm Updates(map) with a column allowlist, so an attacker posting {"role":"admin","email":"...","password_hash":"..."} can't mass-assign those fields. Lock that down with a regression test so a future "let's just c.Bind(&user)" refactor breaks loudly. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(services): strip directory components from multipart upload filenames UploadDataset and UploadToCollectionForUser took the raw multipart file.Filename and joined it into a destination path. The fine-tune upload was incidentally safe because of a UUID prefix that fused any leading '..' to a literal segment, but the protection is fragile. UploadToCollectionForUser handed the filename to a vendored backend without sanitising at all. Strip to filepath.Base at both boundaries and reject the trivial unsafe values ("", ".", "..", "/"). Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(react-ui): validate persisted MCP server entries on load localStorage is shared across same-origin pages; an XSS that lands once can poison persisted MCP server config to attempt header injection or to feed a non-http URL into the fetch path on subsequent loads. Validate every entry: types must match, URL must parse with http(s) scheme, header keys/values must be control-char-free. Drop anything that doesn't fit. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(http): close X-Forwarded-Prefix open redirect The reverse-proxy support concatenated X-Forwarded-Prefix into the redirect target without validation, so a forged header value of "//evil.com" turned the SPA-shell redirect helper at /, /browse, and /browse/* into a 301 to //evil.com/app. The path-strip middleware had the same shape on its prefix-trailing-slash redirect. Add SafeForwardedPrefix at the middleware boundary: must start with a single '/', no protocol-relative '//' opener, no scheme, no backslash, no control characters. Apply at both consumers; misconfig trips the validator and the header is dropped. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(http): refuse wildcard CORS when LOCALAI_CORS=true with empty allowlist When LOCALAI_CORS=true but LOCALAI_CORS_ALLOW_ORIGINS was empty, Echo's CORSWithConfig saw an empty allow-list and fell back to its default AllowOrigins=["*"]. An operator who flipped the strict-CORS feature flag without populating the list got the opposite of what they asked for. Echo never sets Allow-Credentials: true so this isn't directly exploitable (cookies aren't sent under wildcard CORS), but the misconfiguration trap is worth closing. Skip the registration and warn. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(auth): zxcvbn password strength check with user-acknowledged override The previous policy was len < 8, which let through "Password1" and the rest of the credential-stuffing corpus. LocalAI has no second factor yet, so the bar needs to sit higher. Add ValidatePasswordStrength using github.com/timbutler/zxcvbn (an actively-maintained fork of the trustelem port; v1.0.4, April 2024): - min 12 chars, max 72 (bcrypt's truncation point) - reject NUL bytes (some bcrypt callers truncate at the first NUL) - require zxcvbn score >= 3 ("safely unguessable, ~10^8 guesses to break"); the hint list ["localai", "local-ai", "admin"] penalises passwords built from the app's own branding zxcvbn produces false positives sometimes (a strong-looking password that happens to match a dictionary word) and operators occasionally need to set a known-weak password (kiosk demos, CI rigs). Add an acknowledgement path: PasswordPolicy{AllowWeak: true} skips the entropy check while still enforcing the hard rules. The structured PasswordErrorResponse marks weak-password rejections as Overridable so the UI can surface a "use this anyway" checkbox. Wired through register, self-service password change, and admin password reset on both the server and the React UI. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(react-ui): drop HTML5 minLength on new-password inputs minLength={12} on the new-password input let the browser block the form submit silently before any JS or network call ran. The browser focused the field, showed a brief native tooltip, and that was that — no toast, no fetch, no clue. Reproducible by typing fewer than 12 chars on the second password change of a session. The JS-level length check in handleSubmit already shows a toast and the server rejects with a structured error, so the HTML5 attribute was redundant defence anyway. Drop it. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(react-ui): bundle Geist fonts locally instead of fetching from Google The new CSP correctly refused to apply styles from fonts.googleapis.com because style-src is locked to 'self' and 'unsafe-inline'. Loosening the CSP would defeat its purpose; the right fix is to stop reaching out to a third-party CDN for fonts on every page load. Add @fontsource-variable/geist and @fontsource-variable/geist-mono as npm deps and import them once at boot. Drop the <link rel="preconnect"> and external stylesheet from index.html. Side benefit: no third-party tracking via Referer / IP on every UI load, no failure mode when offline / behind a captive portal. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * fix(react-ui): refresh i18n strings to reflect 12-char password minimum The translations still said "at least 8 characters" everywhere — the client-side toast on a too-short password change told the user the wrong floor. Update tooShort and newPasswordPlaceholder / newPasswordDescription across all five locales (en, es, it, de, zh-CN) to match the real ValidatePasswordStrength rule. Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(auth): make password length-floor overridable like the entropy check The 12-char minimum was a policy choice, not a technical invariant — only "non-empty", "<= 72 bytes", and "no NUL bytes" are real bcrypt constraints. Treating length-12 as a hard rule was inconsistent with the entropy check (already overridable) and friction for use cases where the account is just a name on a session, not a security boundary (single-user kiosk, CI rig, lab demo). Restructure ValidatePasswordStrength: - Hard rules (always enforced): non-empty, <= MaxPasswordLength, no NUL byte - Policy rules (skipped when AllowWeak=true): length >= 12, zxcvbn score >= 3 PasswordError now marks password_too_short as Overridable too. The React forms generalised from `error_code === 'password_too_weak'` to `overridable === true`, and the JS-side preflight length checks were removed (server is source of truth, returns the same checkbox flow). Assisted-by: Claude:claude-opus-4-7 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com> --------- Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
committed by
GitHub
parent
e5d7b84216
commit
670259ce43
@@ -70,6 +70,7 @@ type RunCMD struct {
|
||||
OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
|
||||
UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"`
|
||||
DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"`
|
||||
AllowInsecurePublicBind bool `env:"LOCALAI_ALLOW_INSECURE_PUBLIC_BIND" default:"false" help:"Allow binding the API to a public-internet address without any authentication configured. Without this flag the server refuses to start when the bind address is public (or a wildcard on a host with a public interface) and no auth backend or static API key is set. Loopback, RFC 1918 LAN, ULA, link-local, and CGNAT (Tailscale) ranges are accepted regardless." group:"hardening"`
|
||||
DisableMetricsEndpoint bool `env:"LOCALAI_DISABLE_METRICS_ENDPOINT,DISABLE_METRICS_ENDPOINT" default:"false" help:"Disable the /metrics endpoint" group:"api"`
|
||||
HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/app(/.*)?$,^/browse(/.*)?$,^/login/?$,^/explorer/?$,^/assets/.*$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"`
|
||||
Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"`
|
||||
@@ -516,6 +517,17 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
return fmt.Errorf("LocalAI failed to start: %w.\nTroubleshooting steps:\n 1. Check that your models directory exists and is accessible: %s\n 2. Verify model config files are valid YAML: 'local-ai util usecase-heuristic <config>'\n 3. Check available disk space and file permissions\n 4. Run with --log-level=debug for more details\nSee https://localai.io/basics/troubleshooting/ for more help", err, r.ModelsPath)
|
||||
}
|
||||
|
||||
// Refuse to bind a public-internet address without authentication unless
|
||||
// the operator has explicitly opted in. The auth middleware degrades to
|
||||
// pass-through when there is no auth DB and no legacy keys; on a loopback,
|
||||
// LAN, or VPN that's the historical "trusted network" deployment, but on
|
||||
// a public IP it makes every model, gallery install, settings change, and
|
||||
// admin endpoint reachable by anyone who can connect to the port.
|
||||
authConfigured := app.AuthDB() != nil || len(r.APIKeys) > 0
|
||||
if err := requireAuthOrTrustedBind(r.Address, authConfigured, r.AllowInsecurePublicBind); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
appHTTP, err := http.API(app)
|
||||
if err != nil {
|
||||
xlog.Error("error during HTTP App construction", "error", err)
|
||||
|
||||
126
core/cli/run_safety.go
Normal file
126
core/cli/run_safety.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
// interfaceAddrsFn is the host-interface enumeration call. Tests swap it to
|
||||
// simulate cloud-VM, home-LAN, and Tailscale-only topologies without poking
|
||||
// the real network stack.
|
||||
var interfaceAddrsFn = net.InterfaceAddrs
|
||||
|
||||
// requireAuthOrTrustedBind fails closed when the server would otherwise bind a
|
||||
// public-internet-reachable address with no authentication configured. Loopback,
|
||||
// RFC 1918, ULA, link-local, and CGNAT (Tailscale's default range) are all
|
||||
// trusted. Wildcard binds are trusted only when every host interface is.
|
||||
//
|
||||
// Operators with an external gating layer (e.g. a reverse proxy that enforces
|
||||
// auth) can opt out via --allow-insecure-public-bind.
|
||||
func requireAuthOrTrustedBind(address string, authConfigured, allowInsecurePublicBind bool) error {
|
||||
if authConfigured || allowInsecurePublicBind {
|
||||
return nil
|
||||
}
|
||||
if isTrustedBind(address) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf(`refusing to start: API bound to public address %q with no authentication configured.
|
||||
|
||||
When auth is disabled, the server has no idea who is calling it — every model,
|
||||
gallery install, settings change, and admin endpoint is reachable by anyone
|
||||
who can connect to the port. That is acceptable on a loopback, LAN, or VPN
|
||||
address but not on a public IP.
|
||||
|
||||
Pick one:
|
||||
1. Bind to a private/LAN/VPN interface only (e.g. --address 10.0.0.5:8080)
|
||||
2. Enable user authentication: --auth (or LOCALAI_AUTH=true), then sign in
|
||||
3. Set a static API key: --api-keys <key> (LOCALAI_API_KEY=<key>)
|
||||
4. Allow the public bind anyway: --allow-insecure-public-bind (only when an
|
||||
external system is gating access to this
|
||||
listener)`, address)
|
||||
}
|
||||
|
||||
// isTrustedBind reports whether `address` binds only to addresses that are
|
||||
// local, on a private LAN, or on a VPN. Hostnames it can't classify cleanly
|
||||
// are rejected.
|
||||
func isTrustedBind(address string) bool {
|
||||
if address == "" {
|
||||
return false
|
||||
}
|
||||
host, _, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return allInterfacesTrusted()
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip.IsUnspecified() {
|
||||
return allInterfacesTrusted()
|
||||
}
|
||||
return isPrivateOrLocalIP(ip)
|
||||
}
|
||||
// Hostname — every resolved address must be trusted. A name resolving to
|
||||
// a mix of public and private addresses fails closed.
|
||||
addrs, err := net.LookupHost(host)
|
||||
if err != nil || len(addrs) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, a := range addrs {
|
||||
ip := net.ParseIP(a)
|
||||
if ip == nil || !isPrivateOrLocalIP(ip) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// isPrivateOrLocalIP returns true for loopback, RFC 1918 / RFC 4193 private,
|
||||
// link-local, and RFC 6598 CGNAT addresses. CGNAT (100.64/10) gets the
|
||||
// special case because the Go stdlib doesn't classify it as private but
|
||||
// Tailscale and similar overlay VPNs hand them out.
|
||||
func isPrivateOrLocalIP(ip net.IP) bool {
|
||||
if ip.IsUnspecified() {
|
||||
return false
|
||||
}
|
||||
if !utils.IsPublicIP(ip) {
|
||||
return true
|
||||
}
|
||||
ip4 := ip.To4()
|
||||
return ip4 != nil && ip4[0] == 100 && (ip4[1]&0xc0) == 64
|
||||
}
|
||||
|
||||
// allInterfacesTrusted reports whether every IP assigned to a local interface
|
||||
// is private/local. A wildcard bind on a host with even one public interface
|
||||
// is genuinely exposing that public interface.
|
||||
//
|
||||
// Returns false on enumeration failure or when the host has no addresses
|
||||
// at all — we can't prove the bind is safe.
|
||||
func allInterfacesTrusted() bool {
|
||||
addrs, err := interfaceAddrsFn()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
sawAny := false
|
||||
for _, a := range addrs {
|
||||
var ip net.IP
|
||||
switch v := a.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
}
|
||||
if ip == nil || ip.IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
sawAny = true
|
||||
if !isPrivateOrLocalIP(ip) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return sawAny
|
||||
}
|
||||
154
core/cli/run_safety_test.go
Normal file
154
core/cli/run_safety_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// withInterfaceAddrs swaps interfaceAddrsFn for the duration of one spec.
|
||||
// Ginkgo's DeferCleanup restores the original after the spec finishes, so
|
||||
// concurrent specs can each pretend to be running on a different host.
|
||||
func withInterfaceAddrs(cidrs ...string) {
|
||||
original := interfaceAddrsFn
|
||||
interfaceAddrsFn = func() ([]net.Addr, error) {
|
||||
out := make([]net.Addr, 0, len(cidrs))
|
||||
for _, c := range cidrs {
|
||||
ip, ipnet, err := net.ParseCIDR(c)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
out = append(out, &net.IPNet{IP: ip, Mask: ipnet.Mask})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
DeferCleanup(func() { interfaceAddrsFn = original })
|
||||
}
|
||||
|
||||
func withInterfaceAddrsErr(err error) {
|
||||
original := interfaceAddrsFn
|
||||
interfaceAddrsFn = func() ([]net.Addr, error) { return nil, err }
|
||||
DeferCleanup(func() { interfaceAddrsFn = original })
|
||||
}
|
||||
|
||||
var _ = Describe("requireAuthOrTrustedBind", func() {
|
||||
BeforeEach(func() {
|
||||
// Default to "host has only loopback" — the literal-IP cases below
|
||||
// don't touch interfaceAddrsFn but the wildcard cases do, and a
|
||||
// loopback-only host is the safest default for those.
|
||||
withInterfaceAddrs("127.0.0.1/8", "::1/128")
|
||||
})
|
||||
|
||||
It("permits any bind when auth is configured", func() {
|
||||
Expect(requireAuthOrTrustedBind("0.0.0.0:8080", true, false)).To(Succeed())
|
||||
Expect(requireAuthOrTrustedBind("203.0.113.5:8080", true, false)).To(Succeed())
|
||||
})
|
||||
|
||||
It("permits any bind when --allow-insecure-public-bind is set", func() {
|
||||
Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, true)).To(Succeed())
|
||||
Expect(requireAuthOrTrustedBind("203.0.113.5:8080", false, true)).To(Succeed())
|
||||
})
|
||||
|
||||
Context("literal IP binds", func() {
|
||||
It("permits loopback", func() {
|
||||
Expect(requireAuthOrTrustedBind("127.0.0.1:8080", false, false)).To(Succeed())
|
||||
Expect(requireAuthOrTrustedBind("[::1]:8080", false, false)).To(Succeed())
|
||||
Expect(requireAuthOrTrustedBind("127.5.4.3:8080", false, false)).To(Succeed())
|
||||
})
|
||||
|
||||
DescribeTable("permits private LAN ranges",
|
||||
func(addr string) {
|
||||
Expect(requireAuthOrTrustedBind(addr, false, false)).To(Succeed())
|
||||
},
|
||||
Entry("RFC 1918 — 10/8", "10.0.0.5:8080"),
|
||||
Entry("RFC 1918 — 172.16/12", "172.16.5.5:8080"),
|
||||
Entry("RFC 1918 — 192.168/16", "192.168.1.5:8080"),
|
||||
Entry("IPv6 ULA — fc00::/7", "[fc00::1]:8080"),
|
||||
Entry("IPv6 ULA — fd00::/8", "[fd12:3456:789a::1]:8080"),
|
||||
)
|
||||
|
||||
It("permits link-local addresses", func() {
|
||||
Expect(requireAuthOrTrustedBind("169.254.10.10:8080", false, false)).To(Succeed())
|
||||
Expect(requireAuthOrTrustedBind("[fe80::1]:8080", false, false)).To(Succeed())
|
||||
})
|
||||
|
||||
It("permits CGNAT (Tailscale default)", func() {
|
||||
Expect(requireAuthOrTrustedBind("100.64.0.5:8080", false, false)).To(Succeed())
|
||||
Expect(requireAuthOrTrustedBind("100.127.255.1:8080", false, false)).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects boundary addresses just outside CGNAT", func() {
|
||||
Expect(requireAuthOrTrustedBind("100.63.255.255:8080", false, false)).To(HaveOccurred())
|
||||
Expect(requireAuthOrTrustedBind("100.128.0.0:8080", false, false)).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects public IPv4", func() {
|
||||
Expect(requireAuthOrTrustedBind("8.8.8.8:8080", false, false)).To(HaveOccurred())
|
||||
Expect(requireAuthOrTrustedBind("203.0.113.5:8080", false, false)).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects public IPv6", func() {
|
||||
Expect(requireAuthOrTrustedBind("[2001:db8::1]:8080", false, false)).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("wildcard bind (`:port`, 0.0.0.0, ::)", func() {
|
||||
It("permits when every interface is private/loopback", func() {
|
||||
withInterfaceAddrs("127.0.0.1/8", "::1/128", "10.0.0.5/24", "fc00::1/64")
|
||||
Expect(requireAuthOrTrustedBind(":8080", false, false)).To(Succeed())
|
||||
Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(Succeed())
|
||||
Expect(requireAuthOrTrustedBind("[::]:8080", false, false)).To(Succeed())
|
||||
})
|
||||
|
||||
It("permits when interfaces are loopback + Tailscale CGNAT", func() {
|
||||
withInterfaceAddrs("127.0.0.1/8", "::1/128", "100.65.10.20/32")
|
||||
Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects when ANY interface has a public IP", func() {
|
||||
withInterfaceAddrs("127.0.0.1/8", "::1/128", "10.0.0.5/24", "203.0.113.42/24")
|
||||
Expect(requireAuthOrTrustedBind(":8080", false, false)).To(HaveOccurred())
|
||||
Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(HaveOccurred())
|
||||
Expect(requireAuthOrTrustedBind("[::]:8080", false, false)).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("fails closed when interface enumeration errors", func() {
|
||||
withInterfaceAddrsErr(errors.New("enumeration disabled"))
|
||||
Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("fails closed when the host has no addresses at all", func() {
|
||||
withInterfaceAddrs()
|
||||
Expect(requireAuthOrTrustedBind("0.0.0.0:8080", false, false)).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("hostname binds", func() {
|
||||
It("permits 'localhost' (resolves to loopback)", func() {
|
||||
Expect(requireAuthOrTrustedBind("localhost:8080", false, false)).To(Succeed())
|
||||
})
|
||||
})
|
||||
|
||||
Context("malformed input", func() {
|
||||
It("rejects an address with no port", func() {
|
||||
Expect(requireAuthOrTrustedBind("8080", false, false)).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects an empty address", func() {
|
||||
Expect(requireAuthOrTrustedBind("", false, false)).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
Context("error message", func() {
|
||||
It("guides the operator with all four escape hatches", func() {
|
||||
err := requireAuthOrTrustedBind("203.0.113.5:8080", false, false)
|
||||
Expect(err).To(HaveOccurred())
|
||||
msg := err.Error()
|
||||
Expect(msg).To(ContainSubstring("--auth"))
|
||||
Expect(msg).To(ContainSubstring("--api-keys"))
|
||||
Expect(msg).To(ContainSubstring("--allow-insecure-public-bind"))
|
||||
Expect(msg).To(ContainSubstring("LAN"))
|
||||
Expect(msg).To(ContainSubstring("VPN"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -21,6 +21,15 @@ type DiscoveryServer struct {
|
||||
errorThreshold int
|
||||
}
|
||||
|
||||
// redactToken obfuscates a distribution token for log output; tokens may be
|
||||
// retained beyond their lifetime via syslog/journald.
|
||||
func redactToken(t string) string {
|
||||
if len(t) <= 8 {
|
||||
return "[redacted]"
|
||||
}
|
||||
return t[:4] + "…" + t[len(t)-4:]
|
||||
}
|
||||
|
||||
// NewDiscoveryServer creates a new DiscoveryServer with the given Database.
|
||||
// it keeps the db state in sync with the network state
|
||||
func NewDiscoveryServer(db *Database, dur time.Duration, failureThreshold int) *DiscoveryServer {
|
||||
@@ -92,10 +101,10 @@ func (s *DiscoveryServer) runBackground() {
|
||||
}
|
||||
}
|
||||
|
||||
xlog.Debug("Network clusters", "network", token, "count", len(ledgerK))
|
||||
xlog.Debug("Network clusters", "network", redactToken(token), "count", len(ledgerK))
|
||||
if len(ledgerK) != 0 {
|
||||
for _, k := range ledgerK {
|
||||
xlog.Debug("Clusterdata", "network", token, "cluster", k)
|
||||
xlog.Debug("Clusterdata", "network", redactToken(token), "cluster", k)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,7 +137,7 @@ func (s *DiscoveryServer) deleteFailedConnections() {
|
||||
for _, t := range s.database.TokenList() {
|
||||
data, _ := s.database.Get(t)
|
||||
if data.Failures > s.errorThreshold {
|
||||
xlog.Info("Token has been removed from the database", "token", t)
|
||||
xlog.Info("Token has been removed from the database", "token", redactToken(t))
|
||||
s.database.Delete(t)
|
||||
}
|
||||
}
|
||||
|
||||
55
core/gallery/backend_sha256_test.go
Normal file
55
core/gallery/backend_sha256_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Backend gallery integrity check. Operators populate `sha256:` on each
|
||||
// backend gallery entry; the install path now passes that value into the
|
||||
// downloader (which already knows how to hash-verify and roll back on
|
||||
// mismatch). This test pins the YAML wire format so a future refactor of
|
||||
// GalleryBackend can't drop the field silently.
|
||||
var _ = Describe("GalleryBackend.SHA256 wire format", func() {
|
||||
It("parses sha256 from YAML", func() {
|
||||
data := []byte(`name: test-backend
|
||||
uri: https://example.com/backend.tar.gz
|
||||
sha256: deadbeefcafef00d
|
||||
`)
|
||||
var b GalleryBackend
|
||||
Expect(yaml.Unmarshal(data, &b)).To(Succeed())
|
||||
Expect(b.SHA256).To(Equal("deadbeefcafef00d"))
|
||||
})
|
||||
|
||||
It("parses sha256 from JSON", func() {
|
||||
// The struct is JSON-tagged for HTTP API responses too.
|
||||
var b GalleryBackend
|
||||
// Round-trip via YAML to JSON to keep the test framework simple.
|
||||
b.Metadata.Name = "x"
|
||||
b.URI = "https://example.com/x.tar.gz"
|
||||
b.SHA256 = "deadbeefcafef00d"
|
||||
out, err := yaml.Marshal(&b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(out)).To(ContainSubstring("sha256: deadbeefcafef00d"))
|
||||
})
|
||||
|
||||
It("omits sha256 when empty", func() {
|
||||
b := GalleryBackend{Metadata: Metadata{Name: "x"}, URI: "https://example.com/x.tar.gz"}
|
||||
out, err := yaml.Marshal(&b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(string(out)).ToNot(ContainSubstring("sha256:"),
|
||||
"empty SHA256 must use omitempty so old galleries don't gain a noisy field")
|
||||
})
|
||||
|
||||
It("defaults SHA256 to empty for galleries that don't specify it", func() {
|
||||
// Old galleries without sha256: keep working. The downloader emits a
|
||||
// runtime warning ("downloading without integrity check") which is
|
||||
// the deliberate carrot-stick toward populating the field.
|
||||
var b GalleryBackend
|
||||
Expect(yaml.Unmarshal([]byte(`name: legacy-backend
|
||||
uri: https://example.com/legacy.tar.gz
|
||||
`), &b)).To(Succeed())
|
||||
Expect(b.SHA256).To(Equal(""))
|
||||
})
|
||||
})
|
||||
@@ -36,6 +36,9 @@ type GalleryBackend struct {
|
||||
Version string `json:"version,omitempty" yaml:"version,omitempty"`
|
||||
Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"`
|
||||
CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"`
|
||||
// SHA256 is the expected sha256 of the backend tarball at URI / Mirrors.
|
||||
// Empty disables the integrity check; OCI URIs carry their own digest.
|
||||
SHA256 string `json:"sha256,omitempty" yaml:"sha256,omitempty"`
|
||||
}
|
||||
|
||||
func (backend *GalleryBackend) FindBestBackendFromMeta(systemState *system.SystemState, backends GalleryElements[*GalleryBackend]) *GalleryBackend {
|
||||
|
||||
@@ -222,7 +222,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
}
|
||||
} else {
|
||||
xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)
|
||||
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
|
||||
if err := uri.DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err != nil {
|
||||
xlog.Debug("Backend download failed, trying fallback", "backendPath", backendPath, "error", err)
|
||||
|
||||
// resetBackendPath cleans up partial state from a failed OCI extraction
|
||||
@@ -243,7 +243,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
default:
|
||||
}
|
||||
resetBackendPath()
|
||||
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err == nil {
|
||||
success = true
|
||||
xlog.Debug("Downloaded backend from mirror", "uri", config.URI, "backendPath", backendPath)
|
||||
break
|
||||
@@ -256,7 +256,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
if fallbackURI != string(config.URI) {
|
||||
resetBackendPath()
|
||||
xlog.Info("Trying fallback URI", "original", config.URI, "fallback", fallbackURI)
|
||||
if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err == nil {
|
||||
xlog.Info("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath)
|
||||
success = true
|
||||
} else {
|
||||
@@ -265,7 +265,7 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
|
||||
resetBackendPath()
|
||||
devFallbackURI := fallbackURI + "-" + devSuffix
|
||||
xlog.Info("Trying development fallback URI", "fallback", devFallbackURI)
|
||||
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
|
||||
if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, config.SHA256, 1, 1, downloadStatus); err == nil {
|
||||
xlog.Info("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
|
||||
success = true
|
||||
} else {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
@@ -160,6 +161,11 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
})
|
||||
}
|
||||
|
||||
// Security headers (CSP, X-Content-Type-Options, X-Frame-Options,
|
||||
// Referrer-Policy). Set early so every response — including 404s and
|
||||
// errors — picks them up.
|
||||
e.Use(httpMiddleware.SecurityHeaders())
|
||||
|
||||
// Custom logger middleware using xlog
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
@@ -277,13 +283,19 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
e.Use(auth.RequireQuota(application.AuthDB()))
|
||||
}
|
||||
|
||||
// CORS middleware
|
||||
// CORS middleware. When CORS=true the operator must also specify the
|
||||
// allowed origins; an empty allowlist would otherwise let Echo fall back
|
||||
// to AllowOrigins=["*"], which is almost never what someone enabling
|
||||
// "strict CORS" intended.
|
||||
if application.ApplicationConfig().CORS {
|
||||
corsConfig := middleware.CORSConfig{}
|
||||
if application.ApplicationConfig().CORSAllowOrigins != "" {
|
||||
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
|
||||
if application.ApplicationConfig().CORSAllowOrigins == "" {
|
||||
xlog.Warn("LOCALAI_CORS=true but LOCALAI_CORS_ALLOW_ORIGINS is empty; refusing to register a wildcard CORS policy. Set the allowlist or unset LOCALAI_CORS.")
|
||||
} else {
|
||||
corsConfig := middleware.CORSConfig{
|
||||
AllowOrigins: strings.Split(application.ApplicationConfig().CORSAllowOrigins, ","),
|
||||
}
|
||||
e.Use(middleware.CORSWithConfig(corsConfig))
|
||||
}
|
||||
e.Use(middleware.CORSWithConfig(corsConfig))
|
||||
} else {
|
||||
e.Use(middleware.CORS())
|
||||
}
|
||||
@@ -424,10 +436,11 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
if err != nil {
|
||||
return c.String(http.StatusNotFound, "React UI not built")
|
||||
}
|
||||
// Inject <base href> for reverse-proxy support
|
||||
// Inject <base href> for reverse-proxy support; baseURL comes
|
||||
// from attacker-controllable Host / X-Forwarded-Host headers.
|
||||
baseURL := httpMiddleware.BaseURL(c)
|
||||
if baseURL != "" {
|
||||
baseTag := `<base href="` + baseURL + `" />`
|
||||
baseTag := `<base href="` + httpMiddleware.SecureBaseHref(baseURL) + `" />`
|
||||
indexHTML = []byte(strings.Replace(string(indexHTML), "<head>", "<head>\n "+baseTag, 1))
|
||||
}
|
||||
return c.HTMLBlob(http.StatusOK, indexHTML)
|
||||
@@ -440,9 +453,11 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
e.GET("/app", serveIndex)
|
||||
e.GET("/app/*", serveIndex)
|
||||
|
||||
// prefixRedirect performs a redirect that preserves X-Forwarded-Prefix for reverse-proxy support.
|
||||
// prefixRedirect performs a redirect that preserves X-Forwarded-Prefix
|
||||
// for reverse-proxy support. The prefix is forgeable on misconfigured
|
||||
// proxy chains, so reject anything that isn't a same-origin path.
|
||||
prefixRedirect := func(c echo.Context, target string) error {
|
||||
if prefix := c.Request().Header.Get("X-Forwarded-Prefix"); prefix != "" {
|
||||
if prefix, ok := httpMiddleware.SafeForwardedPrefix(c.Request().Header.Get("X-Forwarded-Prefix")); ok {
|
||||
target = strings.TrimSuffix(prefix, "/") + target
|
||||
}
|
||||
return c.Redirect(http.StatusMovedPermanently, target)
|
||||
@@ -490,6 +505,21 @@ func API(application *application.Application) (*echo.Echo, error) {
|
||||
|
||||
// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route
|
||||
|
||||
// HTTP server timeouts.
|
||||
//
|
||||
// - ReadHeaderTimeout: bounds the slow-headers Slowloris case. 30s is
|
||||
// enough for a real client on a poor connection but cuts off a
|
||||
// drip-feeding attacker.
|
||||
// - IdleTimeout: bounds idle keep-alive connections.
|
||||
//
|
||||
// We deliberately leave ReadTimeout and WriteTimeout at 0:
|
||||
// - Request bodies can be multi-GB model/dataset uploads.
|
||||
// - Chat-completion and SSE responses can stream for many minutes.
|
||||
// Operators who need stricter limits should front the server with a
|
||||
// reverse proxy that terminates slow clients per-request.
|
||||
e.Server.ReadHeaderTimeout = 30 * time.Second
|
||||
e.Server.IdleTimeout = 120 * time.Second
|
||||
|
||||
// Log startup message
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
xlog.Info("LocalAI API server shutting down")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build auth
|
||||
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
|
||||
@@ -30,11 +30,16 @@ type providerEntry struct {
|
||||
}
|
||||
|
||||
// oauthUserInfo is a provider-agnostic representation of an authenticated user.
|
||||
// EmailVerified MUST reflect upstream verification: AssignRole compares Email
|
||||
// against the configured admin email, so an unverified claim of a privileged
|
||||
// address must not be honoured. Callers that cannot prove verification set
|
||||
// EmailVerified=false.
|
||||
type oauthUserInfo struct {
|
||||
Subject string
|
||||
Email string
|
||||
Name string
|
||||
AvatarURL string
|
||||
Subject string
|
||||
Email string
|
||||
EmailVerified bool
|
||||
Name string
|
||||
AvatarURL string
|
||||
}
|
||||
|
||||
// OAuthManager manages multiple OAuth/OIDC providers.
|
||||
@@ -236,10 +241,15 @@ func (m *OAuthManager) CallbackHandler(providerName string, db *gorm.DB, adminEm
|
||||
email = strings.ToLower(strings.TrimSpace(userInfo.Email))
|
||||
}
|
||||
|
||||
role := AssignRole(tx, email, adminEmail)
|
||||
// roleEmail is what AssignRole and NeedsInviteOrApproval
|
||||
// use to short-circuit on admin-email matches. Pass the
|
||||
// unverified-email-substituted form so an IdP-supplied
|
||||
// copy of LOCALAI_ADMIN_EMAIL doesn't bypass either gate.
|
||||
roleEmail := emailForRoleDecision(email, userInfo.EmailVerified)
|
||||
role := AssignRole(tx, roleEmail, adminEmail)
|
||||
status := StatusActive
|
||||
|
||||
if NeedsInviteOrApproval(tx, email, adminEmail, registrationMode) {
|
||||
if NeedsInviteOrApproval(tx, roleEmail, adminEmail, registrationMode) {
|
||||
if registrationMode == "invite" {
|
||||
if inviteCode == "" {
|
||||
return fmt.Errorf("invite_required")
|
||||
@@ -294,8 +304,11 @@ func (m *OAuthManager) CallbackHandler(providerName string, db *gorm.DB, adminEm
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"error": "account pending approval"})
|
||||
}
|
||||
|
||||
// Maybe promote on login
|
||||
MaybePromote(db, user, adminEmail)
|
||||
// Same gate as roleEmail above: only verified emails can flip an
|
||||
// existing user to admin via the LOCALAI_ADMIN_EMAIL match.
|
||||
if userInfo.EmailVerified {
|
||||
MaybePromote(db, user, adminEmail)
|
||||
}
|
||||
|
||||
// Create session
|
||||
sessionID, err := CreateSession(db, user.ID, hmacSecret)
|
||||
@@ -321,20 +334,26 @@ func extractOIDCUserInfo(ctx context.Context, verifier *oidc.IDTokenVerifier, to
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Picture string `json:"picture"`
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified *bool `json:"email_verified"`
|
||||
Name string `json:"name"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse ID token claims: %w", err)
|
||||
}
|
||||
|
||||
// Default to false on absence: an IdP that doesn't issue the claim is
|
||||
// not asserting verification, and we must not promote on its email.
|
||||
verified := claims.EmailVerified != nil && *claims.EmailVerified
|
||||
|
||||
return &oauthUserInfo{
|
||||
Subject: claims.Sub,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
AvatarURL: claims.Picture,
|
||||
Subject: claims.Sub,
|
||||
Email: claims.Email,
|
||||
EmailVerified: verified,
|
||||
Name: claims.Name,
|
||||
AvatarURL: claims.Picture,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -353,16 +372,19 @@ type githubEmail struct {
|
||||
}
|
||||
|
||||
// fetchGitHubUserInfoAsOAuth fetches GitHub user info and returns it as oauthUserInfo.
|
||||
// GitHub only surfaces verified emails (public profile email and the
|
||||
// /user/emails Verified=true filter), so a non-empty email is always verified.
|
||||
func fetchGitHubUserInfoAsOAuth(ctx context.Context, accessToken string) (*oauthUserInfo, error) {
|
||||
info, err := fetchGitHubUserInfo(ctx, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &oauthUserInfo{
|
||||
Subject: fmt.Sprintf("%d", info.ID),
|
||||
Email: info.Email,
|
||||
Name: info.Name,
|
||||
AvatarURL: info.AvatarURL,
|
||||
Subject: fmt.Sprintf("%d", info.ID),
|
||||
Email: info.Email,
|
||||
EmailVerified: info.Email != "",
|
||||
Name: info.Name,
|
||||
AvatarURL: info.AvatarURL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
17
core/http/auth/oauth_email_decision.go
Normal file
17
core/http/auth/oauth_email_decision.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package auth
|
||||
|
||||
// emailForRoleDecision returns the email value to use for role assignment
|
||||
// (admin promotion, admin-email invite bypass) when handling an OAuth/OIDC
|
||||
// callback. An unverified email must NOT be honoured for these checks —
|
||||
// otherwise an attacker who can register on the configured IdP with an
|
||||
// unverified copy of LOCALAI_ADMIN_EMAIL would inherit admin role on first
|
||||
// login (via AssignRole) or on every subsequent login (via MaybePromote).
|
||||
//
|
||||
// Profile/display uses of the email are unaffected: those happen elsewhere
|
||||
// in the callback and treat the email as user-supplied advisory data.
|
||||
func emailForRoleDecision(email string, verified bool) string {
|
||||
if !verified {
|
||||
return ""
|
||||
}
|
||||
return email
|
||||
}
|
||||
61
core/http/auth/oauth_email_decision_test.go
Normal file
61
core/http/auth/oauth_email_decision_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
//go:build auth
|
||||
|
||||
package auth
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var _ = Describe("emailForRoleDecision", func() {
|
||||
It("returns the email when verified", func() {
|
||||
Expect(emailForRoleDecision("admin@example.com", true)).
|
||||
To(Equal("admin@example.com"))
|
||||
})
|
||||
|
||||
It("returns empty when not verified", func() {
|
||||
Expect(emailForRoleDecision("admin@example.com", false)).
|
||||
To(Equal(""))
|
||||
})
|
||||
|
||||
It("returns empty when email is empty regardless of flag", func() {
|
||||
Expect(emailForRoleDecision("", true)).To(Equal(""))
|
||||
Expect(emailForRoleDecision("", false)).To(Equal(""))
|
||||
})
|
||||
|
||||
Context("integration with AssignRole", func() {
|
||||
var db *gorm.DB
|
||||
|
||||
BeforeEach(func() {
|
||||
db, _ = InitDB(":memory:")
|
||||
// Seed at least one user so the "first user becomes admin"
|
||||
// branch doesn't hide the gate we're testing.
|
||||
seed := &User{
|
||||
ID: "seed-user",
|
||||
Email: "seed@example.com",
|
||||
Provider: ProviderGitHub,
|
||||
Subject: "seed",
|
||||
Role: RoleAdmin,
|
||||
Status: StatusActive,
|
||||
}
|
||||
Expect(db.Create(seed).Error).To(Succeed())
|
||||
})
|
||||
|
||||
It("does NOT promote on unverified email matching admin email", func() {
|
||||
role := AssignRole(db, emailForRoleDecision("admin@example.com", false), "admin@example.com")
|
||||
Expect(role).To(Equal(RoleUser),
|
||||
"unverified IdP claim of admin email must not yield admin role")
|
||||
})
|
||||
|
||||
It("DOES promote on verified email matching admin email", func() {
|
||||
role := AssignRole(db, emailForRoleDecision("admin@example.com", true), "admin@example.com")
|
||||
Expect(role).To(Equal(RoleAdmin))
|
||||
})
|
||||
|
||||
It("ignores email when admin email is unset, regardless of verification", func() {
|
||||
role := AssignRole(db, emailForRoleDecision("any@example.com", true), "")
|
||||
Expect(role).To(Equal(RoleUser))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,6 +1,100 @@
|
||||
package auth
|
||||
|
||||
import "golang.org/x/crypto/bcrypt"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/timbutler/zxcvbn"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// MinPasswordLength is the floor for any new password. LocalAI does not
|
||||
// (yet) support a second factor, so the bar sits above NIST's 8-char
|
||||
// recommendation for MFA-protected accounts.
|
||||
const MinPasswordLength = 12
|
||||
|
||||
// MaxPasswordLength matches bcrypt's 72-byte truncation. Accepting longer
|
||||
// inputs creates a confusing UX where two "different" passwords hash to
|
||||
// the same value because bcrypt silently dropped the suffix.
|
||||
const MaxPasswordLength = 72
|
||||
|
||||
// MinPasswordScore is the minimum zxcvbn score (0-4) we accept. 3 means
|
||||
// "safely unguessable: moderate protection from offline slow-hash scenario"
|
||||
// per Dropbox's scoring; 4 is the highest.
|
||||
const MinPasswordScore = 3
|
||||
|
||||
// passwordContextHints are tokens fed to zxcvbn so it penalises passwords
|
||||
// built from the application's own name or branding.
|
||||
var passwordContextHints = []string{"localai", "local-ai", "admin"}
|
||||
|
||||
// ErrPasswordEmpty is returned for a zero-length password. Always rejected;
|
||||
// not overridable — bcrypt comparison on an empty string is its own hazard
|
||||
// and there's no realistic legitimate use.
|
||||
var ErrPasswordEmpty = errors.New("password must not be empty")
|
||||
|
||||
// ErrPasswordTooShort is returned when the password is below
|
||||
// MinPasswordLength. Overridable — short is a policy choice, not a
|
||||
// technical constraint.
|
||||
var ErrPasswordTooShort = fmt.Errorf("password is shorter than %d characters; pick a longer one or acknowledge the weak password to use it anyway", MinPasswordLength)
|
||||
|
||||
// ErrPasswordTooLong is returned when the password exceeds MaxPasswordLength.
|
||||
// Not overridable — bcrypt silently truncates at 72 bytes.
|
||||
var ErrPasswordTooLong = fmt.Errorf("password must be at most %d characters", MaxPasswordLength)
|
||||
|
||||
// ErrPasswordNullByte is returned when the password contains a NUL byte —
|
||||
// some bcrypt callers truncate at the first NUL, which would let an
|
||||
// attacker register "abc\x00garbage" and authenticate as "abc". Not
|
||||
// overridable.
|
||||
var ErrPasswordNullByte = errors.New("password must not contain null bytes")
|
||||
|
||||
// ErrPasswordTooWeak is returned when zxcvbn scores the password below
|
||||
// MinPasswordScore. Overridable — an operator may legitimately want a
|
||||
// known-weak password (kiosk demo, CI rig, false positive on zxcvbn).
|
||||
var ErrPasswordTooWeak = errors.New("password is too easy to guess; pick a longer or less common one, or acknowledge the weak password to use it anyway")
|
||||
|
||||
// PasswordPolicy controls which checks ValidatePasswordStrength enforces.
|
||||
// AllowWeak skips the policy-level checks (length floor, zxcvbn score) but
|
||||
// the technical invariants (non-empty, max length, no NUL bytes) always
|
||||
// apply.
|
||||
type PasswordPolicy struct {
|
||||
AllowWeak bool
|
||||
}
|
||||
|
||||
// ValidatePasswordStrength enforces the password policy. Callers should use
|
||||
// this for every register / change-password / admin-reset flow. Pass an
|
||||
// optional PasswordPolicy{AllowWeak: true} to skip the policy-level checks;
|
||||
// the technical invariants still apply.
|
||||
func ValidatePasswordStrength(password string, policy ...PasswordPolicy) error {
|
||||
// Hard rules — always enforced. These aren't policy, they're invariants
|
||||
// the bcrypt layer below us depends on.
|
||||
if len(password) == 0 {
|
||||
return ErrPasswordEmpty
|
||||
}
|
||||
if len(password) > MaxPasswordLength {
|
||||
return ErrPasswordTooLong
|
||||
}
|
||||
if strings.ContainsRune(password, 0) {
|
||||
return ErrPasswordNullByte
|
||||
}
|
||||
|
||||
allowWeak := false
|
||||
if len(policy) > 0 {
|
||||
allowWeak = policy[0].AllowWeak
|
||||
}
|
||||
if allowWeak {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Policy-level checks — bypassable via AllowWeak.
|
||||
if len(password) < MinPasswordLength {
|
||||
return ErrPasswordTooShort
|
||||
}
|
||||
if zxcvbn.PasswordStrength(password, passwordContextHints).Score < MinPasswordScore {
|
||||
return ErrPasswordTooWeak
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HashPassword returns a bcrypt hash of the given password.
|
||||
func HashPassword(password string) (string, error) {
|
||||
@@ -12,3 +106,33 @@ func HashPassword(password string) (string, error) {
|
||||
func CheckPassword(hash, password string) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
|
||||
}
|
||||
|
||||
// PasswordErrorResponse describes a password-policy rejection in a
|
||||
// machine-readable form so the UI can choose whether to offer an "use
|
||||
// this anyway" override (only when Overridable is true).
|
||||
type PasswordErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
ErrorCode string `json:"error_code"`
|
||||
Overridable bool `json:"overridable"`
|
||||
}
|
||||
|
||||
// PasswordError returns a structured response for a ValidatePasswordStrength
|
||||
// error. err must be one of the package-level password errors.
|
||||
func PasswordError(err error) PasswordErrorResponse {
|
||||
r := PasswordErrorResponse{Error: err.Error()}
|
||||
switch {
|
||||
case errors.Is(err, ErrPasswordEmpty):
|
||||
r.ErrorCode = "password_empty"
|
||||
case errors.Is(err, ErrPasswordTooShort):
|
||||
r.ErrorCode = "password_too_short"
|
||||
r.Overridable = true
|
||||
case errors.Is(err, ErrPasswordTooLong):
|
||||
r.ErrorCode = "password_too_long"
|
||||
case errors.Is(err, ErrPasswordNullByte):
|
||||
r.ErrorCode = "password_null_byte"
|
||||
case errors.Is(err, ErrPasswordTooWeak):
|
||||
r.ErrorCode = "password_too_weak"
|
||||
r.Overridable = true
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
100
core/http/auth/password_test.go
Normal file
100
core/http/auth/password_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Password policy", func() {
|
||||
Describe("ValidatePasswordStrength", func() {
|
||||
// Anything below MinPasswordScore is rejected. zxcvbn scores are subject
|
||||
// to the embedded dictionary; if the underlying dictionary changes we
|
||||
// want the test to break loudly so we can re-baseline rather than
|
||||
// silently accept weaker passwords.
|
||||
DescribeTable("rejects weak inputs",
|
||||
func(pw string) {
|
||||
Expect(auth.ValidatePasswordStrength(pw)).ToNot(Succeed())
|
||||
},
|
||||
Entry("too short", "Tr0ub4dor"),
|
||||
Entry("empty", ""),
|
||||
Entry("common: password", "password1234"),
|
||||
Entry("common: 12345", "12345678901234"),
|
||||
Entry("common: qwerty", "qwertyuiopas"),
|
||||
Entry("keyboard run", "qwertyuiop12"),
|
||||
Entry("app branding only", "localailocalai"),
|
||||
Entry("repeated word", "passwordpassword"),
|
||||
)
|
||||
|
||||
DescribeTable("accepts strong inputs",
|
||||
func(pw string) {
|
||||
Expect(auth.ValidatePasswordStrength(pw)).To(Succeed())
|
||||
},
|
||||
Entry("diceware-style passphrase", "correct horse battery staple unicycle"),
|
||||
Entry("random-ish 14-char with punctuation", "Th3-Quick~Br0wn-F0x.Jumps"),
|
||||
Entry("random mixed", "q9V$mZ1pL7nB3w"),
|
||||
)
|
||||
|
||||
Context("length boundaries", func() {
|
||||
It("returns ErrPasswordEmpty for empty input", func() {
|
||||
Expect(auth.ValidatePasswordStrength("")).To(MatchError(auth.ErrPasswordEmpty))
|
||||
})
|
||||
It("returns ErrPasswordTooShort just under the floor", func() {
|
||||
short := strings.Repeat("a", auth.MinPasswordLength-1)
|
||||
Expect(auth.ValidatePasswordStrength(short)).To(MatchError(auth.ErrPasswordTooShort))
|
||||
})
|
||||
It("returns ErrPasswordTooLong just over the ceiling", func() {
|
||||
long := strings.Repeat("a", auth.MaxPasswordLength+1)
|
||||
Expect(auth.ValidatePasswordStrength(long)).To(MatchError(auth.ErrPasswordTooLong))
|
||||
})
|
||||
})
|
||||
|
||||
It("rejects passwords containing NUL bytes", func() {
|
||||
Expect(auth.ValidatePasswordStrength("abcdef\x00ghijklmnop")).To(MatchError(auth.ErrPasswordNullByte))
|
||||
})
|
||||
|
||||
// AllowWeak skips the policy-level checks (length floor + entropy) but
|
||||
// the technical invariants (empty / max-length / NUL byte) always apply.
|
||||
Context("with AllowWeak override", func() {
|
||||
weak := "password1234"
|
||||
|
||||
It("rejects the weak password by default", func() {
|
||||
Expect(auth.ValidatePasswordStrength(weak)).To(MatchError(auth.ErrPasswordTooWeak))
|
||||
})
|
||||
It("accepts the weak password when AllowWeak is set", func() {
|
||||
Expect(auth.ValidatePasswordStrength(weak, auth.PasswordPolicy{AllowWeak: true})).To(Succeed())
|
||||
})
|
||||
It("bypasses the length floor", func() {
|
||||
Expect(auth.ValidatePasswordStrength("short", auth.PasswordPolicy{AllowWeak: true})).To(Succeed())
|
||||
})
|
||||
It("does not bypass the empty-input check", func() {
|
||||
Expect(auth.ValidatePasswordStrength("", auth.PasswordPolicy{AllowWeak: true})).To(MatchError(auth.ErrPasswordEmpty))
|
||||
})
|
||||
It("does not bypass the NUL-byte check", func() {
|
||||
Expect(auth.ValidatePasswordStrength("ok\x00password1234", auth.PasswordPolicy{AllowWeak: true})).To(MatchError(auth.ErrPasswordNullByte))
|
||||
})
|
||||
It("does not bypass the max-length check", func() {
|
||||
long := strings.Repeat("a", auth.MaxPasswordLength+1)
|
||||
Expect(auth.ValidatePasswordStrength(long, auth.PasswordPolicy{AllowWeak: true})).To(MatchError(auth.ErrPasswordTooLong))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("PasswordError", func() {
|
||||
DescribeTable("produces a structured response",
|
||||
func(err error, code string, overridable bool) {
|
||||
r := auth.PasswordError(err)
|
||||
Expect(r.ErrorCode).To(Equal(code))
|
||||
Expect(r.Overridable).To(Equal(overridable))
|
||||
Expect(r.Error).ToNot(BeEmpty())
|
||||
},
|
||||
Entry("empty", auth.ErrPasswordEmpty, "password_empty", false),
|
||||
Entry("too short", auth.ErrPasswordTooShort, "password_too_short", true),
|
||||
Entry("too long", auth.ErrPasswordTooLong, "password_too_long", false),
|
||||
Entry("null byte", auth.ErrPasswordNullByte, "password_null_byte", false),
|
||||
Entry("too weak", auth.ErrPasswordTooWeak, "password_too_weak", true),
|
||||
)
|
||||
})
|
||||
})
|
||||
145
core/http/csrf_multipart_test.go
Normal file
145
core/http/csrf_multipart_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
echoMiddleware "github.com/labstack/echo/v4/middleware"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// CSRF on multipart endpoints. The protection in core/http/app.go is global
|
||||
// (e.Use), so it covers /api/branding/asset/:kind, /api/finetune, audio
|
||||
// transforms, etc. This test rebuilds the same middleware config in
|
||||
// isolation and pins the contract: cross-site POSTs are rejected; same-site
|
||||
// POSTs and Authorization-header requests are not.
|
||||
//
|
||||
// Booting the whole application via API() per-spec costs tens of seconds and
|
||||
// has external dependencies, so we deliberately reconstruct just the
|
||||
// middleware here. If the app.go config drifts from this test, fix the
|
||||
// constants in the test rather than the app.
|
||||
var _ = Describe("CSRF coverage on multipart endpoints", func() {
|
||||
var app *echo.Echo
|
||||
|
||||
BeforeEach(func() {
|
||||
app = echo.New()
|
||||
app.Use(echoMiddleware.CSRFWithConfig(echoMiddleware.CSRFConfig{
|
||||
Skipper: func(c echo.Context) bool {
|
||||
if c.Request().Header.Get("Authorization") != "" {
|
||||
return true
|
||||
}
|
||||
if c.Request().Header.Get("x-api-key") != "" || c.Request().Header.Get("xi-api-key") != "" {
|
||||
return true
|
||||
}
|
||||
if c.Request().Header.Get("Sec-Fetch-Site") == "" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
},
|
||||
AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) {
|
||||
if c.Request().Header.Get("Sec-Fetch-Site") == "same-site" {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
},
|
||||
}))
|
||||
app.POST("/api/branding/asset/:kind", func(c echo.Context) error {
|
||||
return c.NoContent(http.StatusOK)
|
||||
})
|
||||
})
|
||||
|
||||
multipartBody := func() (*bytes.Buffer, string) {
|
||||
buf := &bytes.Buffer{}
|
||||
w := multipart.NewWriter(buf)
|
||||
fw, err := w.CreateFormFile("file", "logo.svg")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, _ = fw.Write([]byte(`<svg xmlns="http://www.w3.org/2000/svg" />`))
|
||||
Expect(w.Close()).To(Succeed())
|
||||
return buf, w.FormDataContentType()
|
||||
}
|
||||
|
||||
It("rejects a cross-site multipart POST", func() {
|
||||
body, ct := multipartBody()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body)
|
||||
req.Header.Set("Content-Type", ct)
|
||||
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
// Echo's CSRF returns 400 (missing csrf token) when AllowSecFetchSite
|
||||
// returns false — what we care about is that the request did not
|
||||
// reach the handler with status 200.
|
||||
Expect(rec.Code).To(BeNumerically(">=", 400),
|
||||
"cross-site POST must be rejected; got %d", rec.Code)
|
||||
Expect(rec.Code).To(BeNumerically("<", 500),
|
||||
"cross-site POST must be rejected with 4xx; got %d", rec.Code)
|
||||
})
|
||||
|
||||
It("allows a same-origin multipart POST", func() {
|
||||
body, ct := multipartBody()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body)
|
||||
req.Header.Set("Content-Type", ct)
|
||||
req.Header.Set("Sec-Fetch-Site", "same-origin")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK),
|
||||
"same-origin POST must reach the handler; got %d body=%s", rec.Code, rec.Body.String())
|
||||
})
|
||||
|
||||
It("allows a same-site multipart POST", func() {
|
||||
body, ct := multipartBody()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body)
|
||||
req.Header.Set("Content-Type", ct)
|
||||
req.Header.Set("Sec-Fetch-Site", "same-site")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK),
|
||||
"same-site POST must reach the handler; got %d body=%s", rec.Code, rec.Body.String())
|
||||
})
|
||||
|
||||
It("skips CSRF for Authorization header clients (cross-site is fine)", func() {
|
||||
body, ct := multipartBody()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body)
|
||||
req.Header.Set("Content-Type", ct)
|
||||
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||
req.Header.Set("Authorization", "Bearer something")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
// Skipper short-circuits CSRF; the handler is reached.
|
||||
Expect(rec.Code).To(Equal(http.StatusOK),
|
||||
"Authorization header must skip CSRF; got %d body=%s", rec.Code, rec.Body.String())
|
||||
})
|
||||
|
||||
It("skips CSRF for x-api-key clients", func() {
|
||||
body, ct := multipartBody()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body)
|
||||
req.Header.Set("Content-Type", ct)
|
||||
req.Header.Set("Sec-Fetch-Site", "cross-site")
|
||||
req.Header.Set("x-api-key", "something")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK),
|
||||
"x-api-key must skip CSRF; got %d", rec.Code)
|
||||
})
|
||||
|
||||
It("falls through when Sec-Fetch-Site is absent (relies on SameSite=Lax cookie elsewhere)", func() {
|
||||
// Older browsers and some reverse proxies strip Sec-Fetch-Site. The
|
||||
// skipper returns true in that case; the auth-cookie SameSite=Lax
|
||||
// attribute is the actual defense (cookies aren't sent on cross-site
|
||||
// POSTs, so auth would 401 the request). This test just pins the
|
||||
// skipper behavior — the SameSite contract lives in oauth.go /
|
||||
// session.go.
|
||||
body, ct := multipartBody()
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/branding/asset/logo", body)
|
||||
req.Header.Set("Content-Type", ct)
|
||||
// no Sec-Fetch-Site
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK),
|
||||
"missing Sec-Fetch-Site must skip CSRF (SameSite cookie is the fallback); got %d", rec.Code)
|
||||
})
|
||||
})
|
||||
178
core/http/endpoints/localai/agents_isolation_test.go
Normal file
178
core/http/endpoints/localai/agents_isolation_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Per-user isolation contract for agent endpoints: a regular user is scoped
|
||||
// to their own data, admins and agent-worker service accounts can override
|
||||
// scope via ?user_id=, and ?all_users=true is admin-only.
|
||||
var _ = Describe("Agent endpoint per-user isolation", func() {
|
||||
var e *echo.Echo
|
||||
|
||||
makeContext := func(query string) echo.Context {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/agents"+query, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
return e.NewContext(req, rec)
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
})
|
||||
|
||||
Describe("getUserID", func() {
|
||||
It("returns empty when no user is in context", func() {
|
||||
c := makeContext("")
|
||||
Expect(getUserID(c)).To(Equal(""))
|
||||
})
|
||||
|
||||
It("returns the authenticated user's ID", func() {
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser})
|
||||
Expect(getUserID(c)).To(Equal("alice"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("isAdminUser", func() {
|
||||
It("is false when no user is in context", func() {
|
||||
c := makeContext("")
|
||||
Expect(isAdminUser(c)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is false for a regular user", func() {
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser})
|
||||
Expect(isAdminUser(c)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is true for admin", func() {
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin})
|
||||
Expect(isAdminUser(c)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("canImpersonateUser", func() {
|
||||
It("is false for unauthenticated", func() {
|
||||
c := makeContext("")
|
||||
Expect(canImpersonateUser(c)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is false for a regular user", func() {
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser})
|
||||
Expect(canImpersonateUser(c)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is true for admin", func() {
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin})
|
||||
Expect(canImpersonateUser(c)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("is true for agent-worker service accounts", func() {
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{
|
||||
ID: "worker-1", Role: auth.RoleUser,
|
||||
Provider: auth.ProviderAgentWorker,
|
||||
})
|
||||
Expect(canImpersonateUser(c)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("ignores user-supplied 'provider' since the field is set server-side at registration", func() {
|
||||
// Defense-in-depth: even if a user could somehow inject a
|
||||
// Provider field into their session row (they can't via any
|
||||
// supported flow), the role check is independent and a
|
||||
// non-admin RoleUser still fails the impersonation gate.
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{
|
||||
ID: "alice", Role: auth.RoleUser,
|
||||
// User claims to be an agent worker.
|
||||
Provider: auth.ProviderAgentWorker,
|
||||
})
|
||||
// Per the current contract canImpersonate accepts this — but
|
||||
// the path that mints ProviderAgentWorker is server-only
|
||||
// (see core/services/nodes/registration.go). Pin the
|
||||
// expectation so a future refactor that changes the rule
|
||||
// here also has to change the lock below.
|
||||
Expect(canImpersonateUser(c)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Describe("effectiveUserID", func() {
|
||||
It("returns the caller's own ID when no query param is set", func() {
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser})
|
||||
Expect(effectiveUserID(c)).To(Equal("alice"))
|
||||
})
|
||||
|
||||
It("ignores ?user_id= for a regular user (no impersonation)", func() {
|
||||
c := makeContext("?user_id=bob")
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser})
|
||||
Expect(effectiveUserID(c)).To(Equal("alice"),
|
||||
"regular user must not be able to scope queries to another user")
|
||||
})
|
||||
|
||||
It("ignores ?user_id= for an unauthenticated caller", func() {
|
||||
c := makeContext("?user_id=alice")
|
||||
// no auth_user set
|
||||
Expect(effectiveUserID(c)).To(Equal(""))
|
||||
})
|
||||
|
||||
It("honors ?user_id= for admin", func() {
|
||||
c := makeContext("?user_id=alice")
|
||||
c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin})
|
||||
Expect(effectiveUserID(c)).To(Equal("alice"))
|
||||
})
|
||||
|
||||
It("honors ?user_id= for agent-worker service accounts", func() {
|
||||
c := makeContext("?user_id=alice")
|
||||
c.Set("auth_user", &auth.User{
|
||||
ID: "worker-1", Role: auth.RoleUser,
|
||||
Provider: auth.ProviderAgentWorker,
|
||||
})
|
||||
Expect(effectiveUserID(c)).To(Equal("alice"))
|
||||
})
|
||||
|
||||
It("falls back to caller's own ID when impersonation is allowed but query is empty", func() {
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin})
|
||||
Expect(effectiveUserID(c)).To(Equal("root"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("wantsAllUsers", func() {
|
||||
It("is false when ?all_users is not set", func() {
|
||||
c := makeContext("")
|
||||
c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin})
|
||||
Expect(wantsAllUsers(c)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("is false for a regular user even with ?all_users=true", func() {
|
||||
c := makeContext("?all_users=true")
|
||||
c.Set("auth_user", &auth.User{ID: "alice", Role: auth.RoleUser})
|
||||
Expect(wantsAllUsers(c)).To(BeFalse(),
|
||||
"regular user must not be able to fan out to all users")
|
||||
})
|
||||
|
||||
It("is true for admin with ?all_users=true", func() {
|
||||
c := makeContext("?all_users=true")
|
||||
c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin})
|
||||
Expect(wantsAllUsers(c)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("only accepts the literal 'true' string", func() {
|
||||
// Sanity — the query string parser should be strict about the
|
||||
// value, otherwise typos and case variants might bypass.
|
||||
c := makeContext("?all_users=1")
|
||||
c.Set("auth_user", &auth.User{ID: "root", Role: auth.RoleAdmin})
|
||||
Expect(wantsAllUsers(c)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -340,6 +340,17 @@ func ServeBrandingAssetEndpoint(appConfig *config.ApplicationConfig) echo.Handle
|
||||
|
||||
path := filepath.Join(appConfig.DynamicConfigsDir, brandingDirName, file)
|
||||
c.Response().Header().Set("Cache-Control", "public, max-age=300")
|
||||
// SVG can carry <script> and event handlers that fire when the
|
||||
// asset is loaded as a top-level document (direct navigation or
|
||||
// <object>/<iframe> embed). Sandbox so a malicious SVG that slipped
|
||||
// past upload validation cannot execute.
|
||||
if strings.EqualFold(filepath.Ext(file), ".svg") {
|
||||
c.Response().Header().Set(
|
||||
"Content-Security-Policy",
|
||||
"default-src 'none'; style-src 'unsafe-inline'; sandbox",
|
||||
)
|
||||
c.Response().Header().Set("Cross-Origin-Resource-Policy", "same-origin")
|
||||
}
|
||||
return c.File(path)
|
||||
}
|
||||
}
|
||||
|
||||
88
core/http/endpoints/localai/branding_test.go
Normal file
88
core/http/endpoints/localai/branding_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// SVG branding assets are loaded by the React UI via <img>, which never
|
||||
// executes script. But the same URL is reachable via direct navigation, in
|
||||
// which case the browser does run script tags inside the SVG. The serve
|
||||
// handler must lock the response down so an attacker who got an admin to
|
||||
// upload a hostile logo can't pivot to same-origin XSS.
|
||||
var _ = Describe("Branding SVG hardening", func() {
|
||||
var (
|
||||
dir string
|
||||
app *echo.Echo
|
||||
appCfg *config.ApplicationConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
dir, err = os.MkdirTemp("", "branding-test-*")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
brandingDir := filepath.Join(dir, "branding")
|
||||
Expect(os.MkdirAll(brandingDir, 0o750)).To(Succeed())
|
||||
|
||||
svg := []byte(`<svg xmlns="http://www.w3.org/2000/svg"><script>alert(1)</script></svg>`)
|
||||
Expect(os.WriteFile(filepath.Join(brandingDir, "logo.svg"), svg, 0o644)).To(Succeed())
|
||||
|
||||
appCfg = config.NewApplicationConfig()
|
||||
appCfg.DynamicConfigsDir = dir
|
||||
appCfg.Branding.LogoFile = "logo.svg"
|
||||
|
||||
app = echo.New()
|
||||
app.GET("/branding/asset/:kind", ServeBrandingAssetEndpoint(appCfg))
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
Expect(os.RemoveAll(dir)).To(Succeed())
|
||||
})
|
||||
|
||||
It("returns the SVG body", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/branding/asset/logo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("<svg"))
|
||||
})
|
||||
|
||||
It("attaches a strict CSP that blocks script execution", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/branding/asset/logo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
csp := rec.Header().Get("Content-Security-Policy")
|
||||
Expect(csp).ToNot(BeEmpty(), "SVG branding assets must ship a CSP")
|
||||
Expect(csp).To(ContainSubstring("default-src 'none'"))
|
||||
Expect(csp).To(ContainSubstring("sandbox"))
|
||||
})
|
||||
|
||||
It("attaches Cross-Origin-Resource-Policy: same-origin to SVG", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/branding/asset/logo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
Expect(rec.Header().Get("Cross-Origin-Resource-Policy")).To(Equal("same-origin"))
|
||||
})
|
||||
|
||||
It("does not attach the SVG-specific CSP to non-SVG assets", func() {
|
||||
// Replace the SVG with a PNG-named file.
|
||||
Expect(os.Remove(filepath.Join(dir, "branding", "logo.svg"))).To(Succeed())
|
||||
Expect(os.WriteFile(filepath.Join(dir, "branding", "logo.png"), []byte("\x89PNG\r\n\x1a\n"), 0o644)).To(Succeed())
|
||||
appCfg.Branding.LogoFile = "logo.png"
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/branding/asset/logo", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
// The SVG-specific lockdown should not apply to PNG.
|
||||
Expect(rec.Header().Get("Content-Security-Policy")).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
@@ -12,39 +12,16 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
var privateNetworks []*net.IPNet
|
||||
|
||||
func init() {
|
||||
for _, cidr := range []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"169.254.0.0/16",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
"fe80::/10",
|
||||
} {
|
||||
_, network, _ := net.ParseCIDR(cidr)
|
||||
privateNetworks = append(privateNetworks, network)
|
||||
}
|
||||
}
|
||||
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
for _, network := range privateNetworks {
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CORSProxyEndpoint proxies HTTP requests to external MCP servers,
|
||||
// solving CORS issues for browser-based MCP connections.
|
||||
// The target URL is passed as a query parameter: /api/cors-proxy?url=https://...
|
||||
//
|
||||
// SSRF guard: the resolved IP is classified via utils.IsPublicIP and the
|
||||
// same IP is reused for the connection (DNS-rebinding mitigation).
|
||||
func CORSProxyEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
targetURL := c.QueryParam("url")
|
||||
@@ -61,12 +38,24 @@ func CORSProxyEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "only http and https schemes are supported"})
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(parsed.Hostname())
|
||||
if err != nil {
|
||||
hostname := parsed.Hostname()
|
||||
if hostname == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "URL has no hostname"})
|
||||
}
|
||||
// Reject internal hostnames before DNS — split-horizon DNS or hosts
|
||||
// files could otherwise map them to addresses the CIDR check accepts.
|
||||
lowerHost := strings.ToLower(hostname)
|
||||
if lowerHost == "localhost" || strings.HasSuffix(lowerHost, ".local") ||
|
||||
lowerHost == "metadata.google.internal" || lowerHost == "instance-data" {
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"error": "requests to internal hosts are not allowed"})
|
||||
}
|
||||
|
||||
ips, err := net.LookupIP(hostname)
|
||||
if err != nil || len(ips) == 0 {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "cannot resolve hostname"})
|
||||
}
|
||||
for _, ip := range ips {
|
||||
if isPrivateIP(ip) {
|
||||
if !utils.IsPublicIP(ip) {
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"error": "requests to private networks are not allowed"})
|
||||
}
|
||||
}
|
||||
|
||||
110
core/http/endpoints/localai/cors_proxy_test.go
Normal file
110
core/http/endpoints/localai/cors_proxy_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// SSRF guards on the CORS proxy. Notably 0.0.0.0/8 and :: route to localhost
|
||||
// on Linux, and ::ffff:127.0.0.1 (IPv4-mapped IPv6) reaches 127.0.0.1 —
|
||||
// hand-rolled CIDR blocklists frequently miss these.
|
||||
var _ = Describe("CORSProxy SSRF guards", func() {
|
||||
var app *echo.Echo
|
||||
|
||||
BeforeEach(func() {
|
||||
app = echo.New()
|
||||
appConfig := config.NewApplicationConfig()
|
||||
app.GET("/api/cors-proxy", CORSProxyEndpoint(appConfig))
|
||||
app.POST("/api/cors-proxy", CORSProxyEndpoint(appConfig))
|
||||
})
|
||||
|
||||
rejectsTarget := func(target string) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/cors-proxy?url="+target, nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
// Any 4xx is acceptable — we only care that the request was rejected
|
||||
// before a connection was attempted to the local network.
|
||||
Expect(rec.Code).To(BeNumerically(">=", 400),
|
||||
"expected proxy to reject %s, got %d body=%s", target, rec.Code, rec.Body.String())
|
||||
Expect(rec.Code).To(BeNumerically("<", 500),
|
||||
"expected proxy to reject %s with 4xx, got %d body=%s", target, rec.Code, rec.Body.String())
|
||||
}
|
||||
|
||||
It("rejects http://0.0.0.0/ (routes to localhost on Linux)", func() {
|
||||
rejectsTarget("http://0.0.0.0/anything")
|
||||
})
|
||||
|
||||
It("rejects http://0.0.0.0:PORT/ (catches loopback bind on any port)", func() {
|
||||
rejectsTarget("http://0.0.0.0:8080/")
|
||||
})
|
||||
|
||||
It("rejects http://[::]/ (IPv6 unspecified)", func() {
|
||||
rejectsTarget("http://[::]/")
|
||||
})
|
||||
|
||||
It("rejects http://[::ffff:127.0.0.1]/ (IPv4-mapped IPv6 loopback)", func() {
|
||||
rejectsTarget("http://[::ffff:127.0.0.1]/")
|
||||
})
|
||||
|
||||
It("rejects http://[::ffff:10.0.0.1]/ (IPv4-mapped IPv6 RFC1918)", func() {
|
||||
rejectsTarget("http://[::ffff:10.0.0.1]/")
|
||||
})
|
||||
|
||||
It("rejects file:// scheme", func() {
|
||||
rejectsTarget("file:///etc/passwd")
|
||||
})
|
||||
|
||||
It("rejects gopher:// scheme", func() {
|
||||
rejectsTarget("gopher://attacker.example.com:1234/")
|
||||
})
|
||||
|
||||
It("rejects ftp:// scheme", func() {
|
||||
rejectsTarget("ftp://example.com/")
|
||||
})
|
||||
|
||||
It("rejects http://localhost/", func() {
|
||||
rejectsTarget("http://localhost/")
|
||||
})
|
||||
|
||||
It("rejects http://127.0.0.1/", func() {
|
||||
rejectsTarget("http://127.0.0.1/")
|
||||
})
|
||||
|
||||
It("rejects http://10.0.0.1/", func() {
|
||||
rejectsTarget("http://10.0.0.1/")
|
||||
})
|
||||
|
||||
It("rejects http://169.254.169.254/ (cloud metadata)", func() {
|
||||
rejectsTarget("http://169.254.169.254/latest/meta-data/")
|
||||
})
|
||||
|
||||
It("rejects http://metadata.google.internal/", func() {
|
||||
rejectsTarget("http://metadata.google.internal/computeMetadata/v1/")
|
||||
})
|
||||
|
||||
It("rejects requests with no url parameter", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/cors-proxy", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
})
|
||||
|
||||
// Sanity: confirm the test runner machine resolves 0.0.0.0 to itself —
|
||||
// otherwise the test could pass for the wrong reason.
|
||||
It("baseline: 0.0.0.0 is classified as unspecified by Go stdlib", func() {
|
||||
ip := net.ParseIP("0.0.0.0")
|
||||
Expect(ip.IsUnspecified()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("baseline: :: is classified as unspecified by Go stdlib", func() {
|
||||
ip := net.ParseIP("::")
|
||||
Expect(ip.IsUnspecified()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
@@ -463,15 +462,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// per-model MCP servers in the same chat session by design).
|
||||
assistantMode := mcpTools.LocalAIAssistantFromMetadata(input.Metadata)
|
||||
if assistantMode {
|
||||
// Defense-in-depth admin gate: the chat route is feature-gated
|
||||
// (FeatureChat), but the assistant tools mutate server state, so
|
||||
// re-check role here even when the deployment chose to skip
|
||||
// FeatureLocalAIAssistant on the route.
|
||||
if startupOptions.Auth.Enabled {
|
||||
user := auth.GetUser(c)
|
||||
if user == nil || user.Role != auth.RoleAdmin {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "localai_assistant requires admin")
|
||||
}
|
||||
if err := requireAssistantAccess(c, startupOptions.Auth.Enabled); err != nil {
|
||||
return err
|
||||
}
|
||||
// Read the disable flag live: an admin can flip it via /api/settings
|
||||
// and the next request must see the change without a restart.
|
||||
|
||||
25
core/http/endpoints/openai/chat_assistant_gate.go
Normal file
25
core/http/endpoints/openai/chat_assistant_gate.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
)
|
||||
|
||||
// requireAssistantAccess gates a chat request that asked for the LocalAI
|
||||
// Assistant tool surface (metadata.localai_assistant=true). The assistant's
|
||||
// in-process MCP server can install models, edit configs and trigger backend
|
||||
// upgrades, so it must be admin-only — the chat route itself only enforces
|
||||
// FeatureChat (default-on for every user). When auth is disabled the gate is
|
||||
// a no-op; the operator already chose to trust every caller.
|
||||
func requireAssistantAccess(c echo.Context, authEnabled bool) error {
|
||||
if !authEnabled {
|
||||
return nil
|
||||
}
|
||||
user := auth.GetUser(c)
|
||||
if user == nil || user.Role != auth.RoleAdmin {
|
||||
return echo.NewHTTPError(http.StatusForbidden, "localai_assistant requires admin")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
82
core/http/endpoints/openai/chat_assistant_gate_test.go
Normal file
82
core/http/endpoints/openai/chat_assistant_gate_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// A non-admin caller who asks for metadata.localai_assistant=true must be
|
||||
// refused. The assistant's in-process MCP server can install/delete models,
|
||||
// edit configs, and rebrand the server, so a user-level caller driving it
|
||||
// via prompt-injected tool calls would be a confused deputy.
|
||||
var _ = Describe("requireAssistantAccess", func() {
|
||||
var (
|
||||
e *echo.Echo
|
||||
c echo.Context
|
||||
)
|
||||
|
||||
makeContext := func() (echo.Context, *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
return e.NewContext(req, rec), rec
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
c, _ = makeContext()
|
||||
})
|
||||
|
||||
Context("when auth is disabled", func() {
|
||||
It("admits any caller", func() {
|
||||
Expect(requireAssistantAccess(c, false)).To(BeNil())
|
||||
})
|
||||
|
||||
It("admits even when no user is in context", func() {
|
||||
Expect(requireAssistantAccess(c, false)).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when auth is enabled", func() {
|
||||
It("rejects an unauthenticated caller with 403", func() {
|
||||
err := requireAssistantAccess(c, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
httpErr, ok := err.(*echo.HTTPError)
|
||||
Expect(ok).To(BeTrue(), "expected echo.HTTPError, got %T", err)
|
||||
Expect(httpErr.Code).To(Equal(http.StatusForbidden))
|
||||
Expect(httpErr.Message).To(ContainSubstring("admin"))
|
||||
})
|
||||
|
||||
It("rejects a regular user with 403", func() {
|
||||
c.Set("auth_user", &auth.User{ID: "u-1", Role: auth.RoleUser})
|
||||
err := requireAssistantAccess(c, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
httpErr, ok := err.(*echo.HTTPError)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(httpErr.Code).To(Equal(http.StatusForbidden))
|
||||
})
|
||||
|
||||
It("rejects a user with empty role with 403", func() {
|
||||
c.Set("auth_user", &auth.User{ID: "u-2", Role: ""})
|
||||
err := requireAssistantAccess(c, true)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("admits an admin", func() {
|
||||
c.Set("auth_user", &auth.User{ID: "admin-1", Role: auth.RoleAdmin})
|
||||
Expect(requireAssistantAccess(c, true)).To(BeNil())
|
||||
})
|
||||
|
||||
It("admits a synthetic admin from a legacy API key", func() {
|
||||
// Legacy API key callers get a synthetic admin user from the
|
||||
// auth middleware. They must continue to work — that's the
|
||||
// shape every existing single-key deployment has today.
|
||||
c.Set("auth_user", &auth.User{ID: "legacy-api-key", Role: auth.RoleAdmin})
|
||||
Expect(requireAssistantAccess(c, true)).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
34
core/http/middleware/forwarded_prefix.go
Normal file
34
core/http/middleware/forwarded_prefix.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package middleware
|
||||
|
||||
import "strings"
|
||||
|
||||
// SafeForwardedPrefix validates an X-Forwarded-Prefix header value before we
|
||||
// concatenate it into a redirect target or use it for path stripping. An
|
||||
// untrusted value like "//evil.com" or "http://evil.com" turns the
|
||||
// reverse-proxy support into an open redirect.
|
||||
//
|
||||
// Returns the trimmed, validated value and true on success; "" and false
|
||||
// when the value is unsafe and should be ignored.
|
||||
func SafeForwardedPrefix(raw string) (string, bool) {
|
||||
s := strings.TrimSpace(raw)
|
||||
if s == "" {
|
||||
return "", false
|
||||
}
|
||||
// Must be a path: starts with a single '/' and doesn't begin a
|
||||
// protocol-relative URL.
|
||||
if !strings.HasPrefix(s, "/") || strings.HasPrefix(s, "//") {
|
||||
return "", false
|
||||
}
|
||||
// Backslashes are interpreted as forward slashes by some clients but
|
||||
// not by Echo's router; reject to avoid bypasses.
|
||||
if strings.ContainsAny(s, "\\") {
|
||||
return "", false
|
||||
}
|
||||
// No control characters or whitespace inside the path.
|
||||
for _, c := range s {
|
||||
if c < 0x20 || c == 0x7f {
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
return s, true
|
||||
}
|
||||
50
core/http/middleware/forwarded_prefix_test.go
Normal file
50
core/http/middleware/forwarded_prefix_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// SafeForwardedPrefix gates every X-Forwarded-Prefix consumer (the path
|
||||
// stripper and the SPA-shell redirect helpers). The threat is operator-
|
||||
// trusted reverse-proxy headers being forgeable via a misconfigured chain;
|
||||
// any value the attacker can inject must not escape the local origin.
|
||||
var _ = Describe("SafeForwardedPrefix", func() {
|
||||
DescribeTable("accepts well-formed path prefixes",
|
||||
func(in string) {
|
||||
out, ok := SafeForwardedPrefix(in)
|
||||
Expect(ok).To(BeTrue(), "expected %q to validate", in)
|
||||
Expect(out).To(Equal(in))
|
||||
},
|
||||
Entry("simple", "/api"),
|
||||
Entry("nested", "/api/v1"),
|
||||
Entry("trailing slash", "/api/"),
|
||||
Entry("with hyphens and dots", "/api-v1.beta"),
|
||||
)
|
||||
|
||||
DescribeTable("rejects values that would escape the origin",
|
||||
func(in string) {
|
||||
_, ok := SafeForwardedPrefix(in)
|
||||
Expect(ok).To(BeFalse(), "expected %q to be rejected", in)
|
||||
},
|
||||
Entry("empty", ""),
|
||||
Entry("whitespace only", " "),
|
||||
Entry("protocol-relative", "//evil.com"),
|
||||
Entry("protocol-relative with path", "//evil.com/x"),
|
||||
Entry("absolute http URL", "http://evil.com"),
|
||||
Entry("absolute https URL", "https://evil.com/x"),
|
||||
Entry("javascript scheme", "javascript:alert(1)"),
|
||||
Entry("data scheme", "data:text/html,foo"),
|
||||
Entry("missing leading slash", "api"),
|
||||
Entry("backslash injection", "/foo\\evil.com"),
|
||||
Entry("CR injection", "/foo\rLocation: //evil.com"),
|
||||
Entry("LF injection", "/foo\nSet-Cookie: x=y"),
|
||||
Entry("NUL byte", "/foo\x00bar"),
|
||||
)
|
||||
|
||||
It("trims surrounding whitespace before validating", func() {
|
||||
out, ok := SafeForwardedPrefix(" /api ")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(out).To(Equal("/api"))
|
||||
})
|
||||
})
|
||||
53
core/http/middleware/security_headers.go
Normal file
53
core/http/middleware/security_headers.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"html"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
)
|
||||
|
||||
// SecurityHeaders sets headers that limit the blast radius of any XSS bug
|
||||
// that slips through. The CSP keeps script-src permissive because the Vite
|
||||
// bundle relies on inline + eval'd scripts; tightening it requires moving
|
||||
// to a nonce-based policy.
|
||||
func SecurityHeaders() echo.MiddlewareFunc {
|
||||
const csp = "default-src 'self'; " +
|
||||
"script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; " +
|
||||
"style-src 'self' 'unsafe-inline'; " +
|
||||
"img-src 'self' data: blob: https:; " +
|
||||
"media-src 'self' data: blob:; " +
|
||||
"font-src 'self' data:; " +
|
||||
"connect-src 'self' ws: wss: https:; " +
|
||||
"frame-src 'self' blob:; " +
|
||||
"worker-src 'self' blob:; " +
|
||||
"object-src 'none'; " +
|
||||
"base-uri 'self'; " +
|
||||
"form-action 'self'; " +
|
||||
"frame-ancestors 'self'"
|
||||
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
h := c.Response().Header()
|
||||
if h.Get("Content-Security-Policy") == "" {
|
||||
h.Set("Content-Security-Policy", csp)
|
||||
}
|
||||
if h.Get("X-Content-Type-Options") == "" {
|
||||
h.Set("X-Content-Type-Options", "nosniff")
|
||||
}
|
||||
if h.Get("X-Frame-Options") == "" {
|
||||
h.Set("X-Frame-Options", "SAMEORIGIN")
|
||||
}
|
||||
if h.Get("Referrer-Policy") == "" {
|
||||
h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
}
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SecureBaseHref escapes a base URL value for safe interpolation into a
|
||||
// `<base href="...">` attribute. baseURL is built from Host /
|
||||
// X-Forwarded-Host, both attacker-controllable on most reverse-proxy setups.
|
||||
func SecureBaseHref(s string) string {
|
||||
return html.EscapeString(s)
|
||||
}
|
||||
113
core/http/middleware/security_headers_test.go
Normal file
113
core/http/middleware/security_headers_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("SecurityHeaders", func() {
|
||||
var e *echo.Echo
|
||||
|
||||
BeforeEach(func() {
|
||||
e = echo.New()
|
||||
e.Use(SecurityHeaders())
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
})
|
||||
|
||||
It("sets Content-Security-Policy", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
csp := rec.Header().Get("Content-Security-Policy")
|
||||
Expect(csp).ToNot(BeEmpty())
|
||||
Expect(csp).To(ContainSubstring("default-src 'self'"))
|
||||
Expect(csp).To(ContainSubstring("frame-ancestors 'self'"))
|
||||
Expect(csp).To(ContainSubstring("object-src 'none'"))
|
||||
Expect(csp).To(ContainSubstring("base-uri 'self'"))
|
||||
})
|
||||
|
||||
It("sets X-Content-Type-Options: nosniff", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
Expect(rec.Header().Get("X-Content-Type-Options")).To(Equal("nosniff"))
|
||||
})
|
||||
|
||||
It("sets X-Frame-Options: SAMEORIGIN", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
Expect(rec.Header().Get("X-Frame-Options")).To(Equal("SAMEORIGIN"))
|
||||
})
|
||||
|
||||
It("sets Referrer-Policy", func() {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
Expect(rec.Header().Get("Referrer-Policy")).To(Equal("strict-origin-when-cross-origin"))
|
||||
})
|
||||
|
||||
It("does not overwrite a header a later handler set explicitly", func() {
|
||||
// Reset router so we can install a handler that sets CSP itself.
|
||||
e = echo.New()
|
||||
e.Use(SecurityHeaders())
|
||||
e.GET("/", func(c echo.Context) error {
|
||||
c.Response().Header().Set("Content-Security-Policy", "default-src 'none'")
|
||||
return c.String(http.StatusOK, "ok")
|
||||
})
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
// The middleware runs first (sets default), but a later handler may
|
||||
// want a tighter CSP for a specific response. The middleware should
|
||||
// only set headers that aren't already present — but since Echo
|
||||
// middleware runs around the handler, the middleware's Set calls
|
||||
// happen before the handler runs. So this is more of a smoke test
|
||||
// that the middleware doesn't actively clobber on the way out.
|
||||
Expect(rec.Header().Get("Content-Security-Policy")).To(Equal("default-src 'none'"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("SecureBaseHref", func() {
|
||||
It("escapes attribute-breaking characters", func() {
|
||||
out := SecureBaseHref(`"><script>alert(1)</script>`)
|
||||
Expect(out).ToNot(ContainSubstring(`"`))
|
||||
Expect(out).ToNot(ContainSubstring("<"))
|
||||
Expect(out).ToNot(ContainSubstring(">"))
|
||||
Expect(out).To(ContainSubstring("script"))
|
||||
})
|
||||
|
||||
It("escapes ampersands", func() {
|
||||
Expect(SecureBaseHref("https://example.com/?a=1&b=2")).
|
||||
To(Equal("https://example.com/?a=1&b=2"))
|
||||
})
|
||||
|
||||
It("escapes single quotes", func() {
|
||||
Expect(SecureBaseHref(`x' onload='alert(1)`)).
|
||||
To(ContainSubstring("'"))
|
||||
})
|
||||
|
||||
It("leaves benign URLs alone", func() {
|
||||
Expect(SecureBaseHref("https://example.com/app/")).
|
||||
To(Equal("https://example.com/app/"))
|
||||
})
|
||||
|
||||
It("encloses safely inside double-quoted attribute", func() {
|
||||
// The realistic attack: attacker sets X-Forwarded-Host: foo.com" onload="x.
|
||||
// Confirm the escaped form can't break out of the surrounding quotes.
|
||||
hostile := `foo.com" onload="alert(1)`
|
||||
out := SecureBaseHref(hostile)
|
||||
Expect(out).ToNot(ContainSubstring(`"`))
|
||||
// Wrapped in attribute context — no raw quote means no breakout.
|
||||
full := `<base href="` + out + `" />`
|
||||
Expect(strings.Count(full, `"`)).To(Equal(2))
|
||||
})
|
||||
})
|
||||
@@ -16,6 +16,11 @@ func StripPathPrefix() echo.MiddlewareFunc {
|
||||
originalPath := c.Request().URL.Path
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
validated, ok := SafeForwardedPrefix(prefix)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
prefix = validated
|
||||
if prefix != "" {
|
||||
normalizedPrefix := prefix
|
||||
if !strings.HasSuffix(prefix, "/") {
|
||||
|
||||
@@ -85,6 +85,29 @@ func initializeTracing(maxItems int) {
|
||||
doInitializeTracing()
|
||||
}
|
||||
|
||||
// sensitiveTraceHeaders is the set of header names whose values must not
|
||||
// land in the in-memory trace buffer. Keys are canonical — http.Header
|
||||
// stores them that way, so range yields canonical keys directly.
|
||||
var sensitiveTraceHeaders = map[string]struct{}{
|
||||
"Authorization": {},
|
||||
"Proxy-Authorization": {},
|
||||
"Cookie": {},
|
||||
"Set-Cookie": {},
|
||||
"X-Api-Key": {},
|
||||
"Xi-Api-Key": {},
|
||||
"X-Auth-Token": {},
|
||||
}
|
||||
|
||||
func redactSensitiveHeaders(h http.Header) http.Header {
|
||||
out := h.Clone()
|
||||
for k := range out {
|
||||
if _, ok := sensitiveTraceHeaders[k]; ok {
|
||||
out[k] = []string{"[redacted]"}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// TraceMiddleware intercepts and logs JSON API requests and responses
|
||||
func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
@@ -130,11 +153,15 @@ func TraceMiddleware(app *application.Application) echo.MiddlewareFunc {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
// Create exchange log (always, even on error)
|
||||
requestHeaders := c.Request().Header.Clone()
|
||||
// Create exchange log (always, even on error). Sensitive headers
|
||||
// (Authorization, API keys, cookies) are redacted before storage —
|
||||
// the trace endpoint is admin-only but the buffer is also reachable
|
||||
// via any heap-dump-style introspection, and tokens shouldn't
|
||||
// outlive the request that carried them.
|
||||
requestHeaders := redactSensitiveHeaders(c.Request().Header)
|
||||
requestBody := make([]byte, len(body))
|
||||
copy(requestBody, body)
|
||||
responseHeaders := c.Response().Header().Clone()
|
||||
responseHeaders := redactSensitiveHeaders(c.Response().Header())
|
||||
responseBody := make([]byte, resBody.Len())
|
||||
copy(responseBody, resBody.Bytes())
|
||||
exchange := APIExchange{
|
||||
|
||||
66
core/http/middleware/trace_redact_test.go
Normal file
66
core/http/middleware/trace_redact_test.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// The trace buffer feeds an admin-only /api/traces endpoint. Even so, it
|
||||
// must not retain the request's auth headers — once captured they outlive
|
||||
// the request, get serialised to JSON for the dashboard, and could leak via
|
||||
// any heap inspection. This pins the redaction contract so a future refactor
|
||||
// of TraceMiddleware can't silently regress it.
|
||||
var _ = Describe("redactSensitiveHeaders", func() {
|
||||
It("redacts Authorization", func() {
|
||||
h := http.Header{}
|
||||
h.Set("Authorization", "Bearer sk-secret-1234567890")
|
||||
out := redactSensitiveHeaders(h)
|
||||
Expect(out.Get("Authorization")).To(Equal("[redacted]"))
|
||||
Expect(out.Get("Authorization")).ToNot(ContainSubstring("sk-secret"))
|
||||
})
|
||||
|
||||
It("redacts Proxy-Authorization", func() {
|
||||
h := http.Header{}
|
||||
h.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz")
|
||||
Expect(redactSensitiveHeaders(h).Get("Proxy-Authorization")).To(Equal("[redacted]"))
|
||||
})
|
||||
|
||||
It("redacts Cookie and Set-Cookie", func() {
|
||||
h := http.Header{}
|
||||
h.Set("Cookie", "session=abc123; csrf=xyz")
|
||||
h.Set("Set-Cookie", "session=newvalue; HttpOnly")
|
||||
out := redactSensitiveHeaders(h)
|
||||
Expect(out.Get("Cookie")).To(Equal("[redacted]"))
|
||||
Expect(out.Get("Set-Cookie")).To(Equal("[redacted]"))
|
||||
})
|
||||
|
||||
It("redacts X-Api-Key (and case variants)", func() {
|
||||
h := http.Header{}
|
||||
h.Set("X-Api-Key", "key-1")
|
||||
h.Set("xi-api-key", "key-2")
|
||||
out := redactSensitiveHeaders(h)
|
||||
Expect(out.Get("X-Api-Key")).To(Equal("[redacted]"))
|
||||
Expect(out.Get("Xi-Api-Key")).To(Equal("[redacted]"))
|
||||
})
|
||||
|
||||
It("preserves benign headers", func() {
|
||||
h := http.Header{}
|
||||
h.Set("Content-Type", "application/json")
|
||||
h.Set("User-Agent", "curl/8.0")
|
||||
h.Set("Accept", "*/*")
|
||||
out := redactSensitiveHeaders(h)
|
||||
Expect(out.Get("Content-Type")).To(Equal("application/json"))
|
||||
Expect(out.Get("User-Agent")).To(Equal("curl/8.0"))
|
||||
Expect(out.Get("Accept")).To(Equal("*/*"))
|
||||
})
|
||||
|
||||
It("does not mutate the input header", func() {
|
||||
h := http.Header{}
|
||||
h.Set("Authorization", "Bearer abc")
|
||||
_ = redactSensitiveHeaders(h)
|
||||
Expect(h.Get("Authorization")).To(Equal("Bearer abc"),
|
||||
"redactSensitiveHeaders must operate on a clone — caller's header must be untouched")
|
||||
})
|
||||
})
|
||||
@@ -13,6 +13,8 @@
|
||||
"@codemirror/search": "^6.5.10",
|
||||
"@codemirror/state": "^6.5.2",
|
||||
"@codemirror/view": "^6.36.8",
|
||||
"@fontsource-variable/geist": "^5.2.8",
|
||||
"@fontsource-variable/geist-mono": "^5.2.7",
|
||||
"@fortawesome/fontawesome-free": "^6.7.2",
|
||||
"@lezer/highlight": "^1.2.1",
|
||||
"@modelcontextprotocol/ext-apps": "^1.2.2",
|
||||
@@ -169,6 +171,10 @@
|
||||
|
||||
"@eslint/plugin-kit": ["@eslint/plugin-kit@0.4.1", "", { "dependencies": { "@eslint/core": "^0.17.0", "levn": "^0.4.1" } }, "sha512-43/qtrDUokr7LJqoF2c3+RInu/t4zfrpYdoSDfYyhg52rwLV6TnOvdG4fXm7IkSB3wErkcmJS9iEhjVtOSEjjA=="],
|
||||
|
||||
"@fontsource-variable/geist": ["@fontsource-variable/geist@5.2.8", "", {}, "sha512-cJ6m9e+8MQ5dCYJsLylfZrgBh6KkG4bOLckB35Tr9J/EqdkEM6QllH5PxqP1dhTvFup+HtMRPuz9xOjxXJggxw=="],
|
||||
|
||||
"@fontsource-variable/geist-mono": ["@fontsource-variable/geist-mono@5.2.7", "", {}, "sha512-ZKlZ5sjtalb2TwXKs400mAGDlt/+2ENLNySPx0wTz3bP3mWARCsUW+rpxzZc7e05d2qGch70pItt3K4qttbIYA=="],
|
||||
|
||||
"@fortawesome/fontawesome-free": ["@fortawesome/fontawesome-free@6.7.2", "", {}, "sha512-JUOtgFW6k9u4Y+xeIaEiLr3+cjoUPiAuLXoyKOJSia6Duzb7pq+A76P9ZdPDoAoxHdHzq6gE9/jKBGXlZT8FbA=="],
|
||||
|
||||
"@gulpjs/to-absolute-glob": ["@gulpjs/to-absolute-glob@4.0.0", "", { "dependencies": { "is-negated-glob": "^1.0.0" } }, "sha512-kjotm7XJrJ6v+7knhPaRgaT6q8F8K2jiafwYdNHLzmV0uGLuZY43FK6smNSHUPrhq5kX2slCUy+RGG/xGqmIKA=="],
|
||||
|
||||
@@ -5,9 +5,6 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>LocalAI</title>
|
||||
<link rel="icon" type="image/svg+xml" href="/favicon.svg" />
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com" />
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin />
|
||||
<link href="https://fonts.googleapis.com/css2?family=Geist:wght@300..700&family=Geist+Mono:wght@300..700&display=swap" rel="stylesheet" />
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
|
||||
@@ -21,6 +21,8 @@
|
||||
"@codemirror/search": "^6.5.10",
|
||||
"@codemirror/state": "^6.5.2",
|
||||
"@codemirror/view": "^6.36.8",
|
||||
"@fontsource-variable/geist": "^5.2.8",
|
||||
"@fontsource-variable/geist-mono": "^5.2.7",
|
||||
"@fortawesome/fontawesome-free": "^6.7.2",
|
||||
"@lezer/highlight": "^1.2.1",
|
||||
"@modelcontextprotocol/ext-apps": "^1.2.2",
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"namePlaceholder": "Ihr Name (optional)",
|
||||
"password": "Passwort",
|
||||
"passwordPlaceholder": "Passwort eingeben...",
|
||||
"newPasswordPlaceholder": "Mindestens 8 Zeichen",
|
||||
"newPasswordPlaceholder": "Mindestens 12 Zeichen",
|
||||
"confirmPassword": "Passwort bestätigen",
|
||||
"confirmPasswordPlaceholder": "Passwort wiederholen",
|
||||
"inviteCodeLabel": "Einladungscode",
|
||||
@@ -70,7 +70,7 @@
|
||||
"currentPasswordDescription": "Geben Sie Ihr aktuelles Passwort zur Bestätigung Ihrer Identität ein",
|
||||
"currentPasswordPlaceholder": "Aktuelles Passwort",
|
||||
"newPassword": "Neues Passwort",
|
||||
"newPasswordDescription": "Muss mindestens 8 Zeichen haben",
|
||||
"newPasswordDescription": "Muss mindestens 12 Zeichen haben",
|
||||
"newPasswordPlaceholder": "Neues Passwort",
|
||||
"confirmPassword": "Passwort bestätigen",
|
||||
"confirmPasswordDescription": "Geben Sie Ihr neues Passwort erneut ein",
|
||||
@@ -79,7 +79,7 @@
|
||||
"changing": "Wird geändert...",
|
||||
"changed": "Passwort geändert",
|
||||
"passwordsDoNotMatch": "Passwörter stimmen nicht überein",
|
||||
"tooShort": "Das neue Passwort muss mindestens 8 Zeichen haben",
|
||||
"tooShort": "Das neue Passwort muss mindestens 12 Zeichen haben",
|
||||
"oauthOnly": "Passwortverwaltung ist für {{provider}}-Konten nicht verfügbar."
|
||||
},
|
||||
"apiKeys": {
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"namePlaceholder": "Your name (optional)",
|
||||
"password": "Password",
|
||||
"passwordPlaceholder": "Enter password...",
|
||||
"newPasswordPlaceholder": "At least 8 characters",
|
||||
"newPasswordPlaceholder": "At least 12 characters",
|
||||
"confirmPassword": "Confirm Password",
|
||||
"confirmPasswordPlaceholder": "Repeat password",
|
||||
"inviteCodeLabel": "Invite Code",
|
||||
@@ -70,7 +70,7 @@
|
||||
"currentPasswordDescription": "Enter your existing password to verify your identity",
|
||||
"currentPasswordPlaceholder": "Current password",
|
||||
"newPassword": "New password",
|
||||
"newPasswordDescription": "Must be at least 8 characters",
|
||||
"newPasswordDescription": "Must be at least 12 characters",
|
||||
"newPasswordPlaceholder": "New password",
|
||||
"confirmPassword": "Confirm password",
|
||||
"confirmPasswordDescription": "Re-enter your new password",
|
||||
@@ -79,7 +79,7 @@
|
||||
"changing": "Changing...",
|
||||
"changed": "Password changed",
|
||||
"passwordsDoNotMatch": "Passwords do not match",
|
||||
"tooShort": "New password must be at least 8 characters",
|
||||
"tooShort": "New password must be at least 12 characters",
|
||||
"oauthOnly": "Password management is not available for {{provider}} accounts."
|
||||
},
|
||||
"apiKeys": {
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"namePlaceholder": "Tu nombre (opcional)",
|
||||
"password": "Contraseña",
|
||||
"passwordPlaceholder": "Introduce la contraseña...",
|
||||
"newPasswordPlaceholder": "Al menos 8 caracteres",
|
||||
"newPasswordPlaceholder": "Al menos 12 caracteres",
|
||||
"confirmPassword": "Confirmar contraseña",
|
||||
"confirmPasswordPlaceholder": "Repite la contraseña",
|
||||
"inviteCodeLabel": "Código de invitación",
|
||||
@@ -70,7 +70,7 @@
|
||||
"currentPasswordDescription": "Introduce tu contraseña actual para verificar tu identidad",
|
||||
"currentPasswordPlaceholder": "Contraseña actual",
|
||||
"newPassword": "Nueva contraseña",
|
||||
"newPasswordDescription": "Debe tener al menos 8 caracteres",
|
||||
"newPasswordDescription": "Debe tener al menos 12 caracteres",
|
||||
"newPasswordPlaceholder": "Nueva contraseña",
|
||||
"confirmPassword": "Confirmar contraseña",
|
||||
"confirmPasswordDescription": "Vuelve a introducir tu nueva contraseña",
|
||||
@@ -79,7 +79,7 @@
|
||||
"changing": "Cambiando...",
|
||||
"changed": "Contraseña cambiada",
|
||||
"passwordsDoNotMatch": "Las contraseñas no coinciden",
|
||||
"tooShort": "La nueva contraseña debe tener al menos 8 caracteres",
|
||||
"tooShort": "La nueva contraseña debe tener al menos 12 caracteres",
|
||||
"oauthOnly": "La gestión de contraseña no está disponible para cuentas {{provider}}."
|
||||
},
|
||||
"apiKeys": {
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"namePlaceholder": "Il tuo nome (opzionale)",
|
||||
"password": "Password",
|
||||
"passwordPlaceholder": "Inserisci la password...",
|
||||
"newPasswordPlaceholder": "Almeno 8 caratteri",
|
||||
"newPasswordPlaceholder": "Almeno 12 caratteri",
|
||||
"confirmPassword": "Conferma password",
|
||||
"confirmPasswordPlaceholder": "Ripeti la password",
|
||||
"inviteCodeLabel": "Codice invito",
|
||||
@@ -70,7 +70,7 @@
|
||||
"currentPasswordDescription": "Inserisci la tua password attuale per verificare la tua identità",
|
||||
"currentPasswordPlaceholder": "Password attuale",
|
||||
"newPassword": "Nuova password",
|
||||
"newPasswordDescription": "Deve avere almeno 8 caratteri",
|
||||
"newPasswordDescription": "Deve avere almeno 12 caratteri",
|
||||
"newPasswordPlaceholder": "Nuova password",
|
||||
"confirmPassword": "Conferma password",
|
||||
"confirmPasswordDescription": "Reinserisci la nuova password",
|
||||
@@ -79,7 +79,7 @@
|
||||
"changing": "Modifica in corso...",
|
||||
"changed": "Password modificata",
|
||||
"passwordsDoNotMatch": "Le password non coincidono",
|
||||
"tooShort": "La nuova password deve avere almeno 8 caratteri",
|
||||
"tooShort": "La nuova password deve avere almeno 12 caratteri",
|
||||
"oauthOnly": "La gestione della password non è disponibile per gli account {{provider}}."
|
||||
},
|
||||
"apiKeys": {
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"namePlaceholder": "您的姓名(可选)",
|
||||
"password": "密码",
|
||||
"passwordPlaceholder": "输入密码...",
|
||||
"newPasswordPlaceholder": "至少 8 个字符",
|
||||
"newPasswordPlaceholder": "至少 12 个字符",
|
||||
"confirmPassword": "确认密码",
|
||||
"confirmPasswordPlaceholder": "再次输入密码",
|
||||
"inviteCodeLabel": "邀请码",
|
||||
@@ -70,7 +70,7 @@
|
||||
"currentPasswordDescription": "输入您的现有密码以验证身份",
|
||||
"currentPasswordPlaceholder": "当前密码",
|
||||
"newPassword": "新密码",
|
||||
"newPasswordDescription": "至少需要 8 个字符",
|
||||
"newPasswordDescription": "至少需要 12 个字符",
|
||||
"newPasswordPlaceholder": "新密码",
|
||||
"confirmPassword": "确认密码",
|
||||
"confirmPasswordDescription": "再次输入新密码",
|
||||
@@ -79,7 +79,7 @@
|
||||
"changing": "正在更改...",
|
||||
"changed": "密码已更改",
|
||||
"passwordsDoNotMatch": "密码不匹配",
|
||||
"tooShort": "新密码至少需要 8 个字符",
|
||||
"tooShort": "新密码至少需要 12 个字符",
|
||||
"oauthOnly": "{{provider}} 账户不支持密码管理。"
|
||||
},
|
||||
"apiKeys": {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState, useEffect, useRef } from 'react'
|
||||
import { renderMarkdown } from '../utils/markdown'
|
||||
import { getArtifactIcon } from '../utils/artifacts'
|
||||
import { safeHref } from '../utils/url'
|
||||
import DOMPurify from 'dompurify'
|
||||
import hljs from 'highlight.js'
|
||||
|
||||
@@ -70,7 +71,7 @@ export default function CanvasPanel({ artifacts, selectedId, onSelect, onClose }
|
||||
return (
|
||||
<div className="canvas-url-card">
|
||||
<i className="fas fa-external-link-alt" />
|
||||
<a href={current.url} target="_blank" rel="noopener noreferrer">{current.url}</a>
|
||||
<a href={safeHref(current.url)} target="_blank" rel="noopener noreferrer">{current.url}</a>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -78,7 +79,7 @@ export default function CanvasPanel({ artifacts, selectedId, onSelect, onClose }
|
||||
return (
|
||||
<div className="canvas-url-card">
|
||||
<i className="fas fa-file" />
|
||||
<a href={current.url} target="_blank" rel="noopener noreferrer" download={current.title}>{current.title}</a>
|
||||
<a href={safeHref(current.url)} target="_blank" rel="noopener noreferrer" download={current.title}>{current.title}</a>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import { AuthProvider } from './context/AuthContext'
|
||||
import { router } from './router'
|
||||
import './i18n'
|
||||
import '@fortawesome/fontawesome-free/css/all.min.css'
|
||||
import '@fontsource-variable/geist'
|
||||
import '@fontsource-variable/geist-mono'
|
||||
import './index.css'
|
||||
import './theme.css'
|
||||
import './App.css'
|
||||
|
||||
@@ -131,6 +131,8 @@ function SecurityTab({ addToast }) {
|
||||
const [newPw, setNewPw] = useState('')
|
||||
const [confirmPw, setConfirmPw] = useState('')
|
||||
const [saving, setSaving] = useState(false)
|
||||
const [weakWarning, setWeakWarning] = useState('')
|
||||
const [acknowledgeWeak, setAcknowledgeWeak] = useState(false)
|
||||
|
||||
const handleSubmit = async (e) => {
|
||||
e.preventDefault()
|
||||
@@ -138,19 +140,21 @@ function SecurityTab({ addToast }) {
|
||||
addToast(t('account.security.passwordsDoNotMatch'), 'error')
|
||||
return
|
||||
}
|
||||
if (newPw.length < 8) {
|
||||
addToast(t('account.security.tooShort'), 'error')
|
||||
return
|
||||
}
|
||||
setSaving(true)
|
||||
try {
|
||||
await profileApi.changePassword(currentPw, newPw)
|
||||
await profileApi.changePassword(currentPw, newPw, acknowledgeWeak)
|
||||
addToast(t('account.security.changed'), 'success')
|
||||
setCurrentPw('')
|
||||
setNewPw('')
|
||||
setConfirmPw('')
|
||||
setWeakWarning('')
|
||||
setAcknowledgeWeak(false)
|
||||
} catch (err) {
|
||||
addToast(err.message, 'error')
|
||||
if (err.body?.overridable) {
|
||||
setWeakWarning(err.body.error || err.message)
|
||||
} else {
|
||||
addToast(err.message, 'error')
|
||||
}
|
||||
} finally {
|
||||
setSaving(false)
|
||||
}
|
||||
@@ -186,9 +190,12 @@ function SecurityTab({ addToast }) {
|
||||
type="password"
|
||||
className="input account-input-sm"
|
||||
value={newPw}
|
||||
onChange={(e) => setNewPw(e.target.value)}
|
||||
onChange={(e) => {
|
||||
setNewPw(e.target.value)
|
||||
setWeakWarning('')
|
||||
setAcknowledgeWeak(false)
|
||||
}}
|
||||
placeholder={t('account.security.newPasswordPlaceholder')}
|
||||
minLength={8}
|
||||
disabled={saving}
|
||||
required
|
||||
/>
|
||||
@@ -204,6 +211,22 @@ function SecurityTab({ addToast }) {
|
||||
required
|
||||
/>
|
||||
</SettingRow>
|
||||
{weakWarning && (
|
||||
<SettingRow label="" description="">
|
||||
<div role="alert" style={{ display: 'flex', flexDirection: 'column', gap: 'var(--spacing-xs)' }}>
|
||||
<div className="login-warning">{weakWarning}</div>
|
||||
<label style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-xs)', fontSize: '0.875rem' }}>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={acknowledgeWeak}
|
||||
onChange={(e) => setAcknowledgeWeak(e.target.checked)}
|
||||
disabled={saving}
|
||||
/>
|
||||
Use this password anyway
|
||||
</label>
|
||||
</div>
|
||||
</SettingRow>
|
||||
)}
|
||||
</div>
|
||||
<div className="form-actions">
|
||||
<button
|
||||
|
||||
@@ -59,8 +59,7 @@ function AgentActivityGroup({ items }) {
|
||||
{items.map((item, idx) => (
|
||||
<div key={idx} className="chat-activity-item">
|
||||
<span className="chat-activity-item-label">{new Date(item.timestamp).toLocaleTimeString()}</span>
|
||||
<div className="chat-activity-item-content"
|
||||
dangerouslySetInnerHTML={{ __html: item.content }} />
|
||||
<div className="chat-activity-item-content" style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-word' }}>{item.content}</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
@@ -619,7 +618,7 @@ export default function AgentChat() {
|
||||
<div className="chat-message-bubble">
|
||||
<div className="chat-message-content">
|
||||
{role === 'user' ? (
|
||||
<div dangerouslySetInnerHTML={{ __html: msg.content.replace(/&/g, '&').replace(/</g, '<').replace(/>/g, '>').replace(/\n/g, '<br>') }} />
|
||||
<div style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-word' }}>{msg.content}</div>
|
||||
) : (
|
||||
<div dangerouslySetInnerHTML={{
|
||||
__html: canvasMode
|
||||
|
||||
@@ -8,6 +8,7 @@ import { useOperations } from '../hooks/useOperations'
|
||||
import { useDistributedMode } from '../hooks/useDistributedMode'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
import { renderMarkdown } from '../utils/markdown'
|
||||
import { safeHref } from '../utils/url'
|
||||
import ConfirmDialog from '../components/ConfirmDialog'
|
||||
import Toggle from '../components/Toggle'
|
||||
import NodeDistributionChip from '../components/NodeDistributionChip'
|
||||
@@ -860,7 +861,7 @@ function BackendDetail({ backend }) {
|
||||
{backend.urls?.length > 0 && (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: '2px' }}>
|
||||
{backend.urls.map((url, i) => (
|
||||
<a key={i} href={url} target="_blank" rel="noopener noreferrer" style={{ fontSize: '0.8125rem', color: 'var(--color-primary)', wordBreak: 'break-all' }}>
|
||||
<a key={i} href={safeHref(url)} target="_blank" rel="noopener noreferrer" style={{ fontSize: '0.8125rem', color: 'var(--color-primary)', wordBreak: 'break-all' }}>
|
||||
<i className="fas fa-external-link-alt" style={{ marginRight: 4, fontSize: '0.6875rem' }} />{url}
|
||||
</a>
|
||||
))}
|
||||
|
||||
@@ -259,7 +259,7 @@ function UserMessageContent({ content, files }) {
|
||||
const text = typeof content === 'string' ? content : content?.[0]?.text || ''
|
||||
return (
|
||||
<>
|
||||
<div dangerouslySetInnerHTML={{ __html: text.replace(/\n/g, '<br>') }} />
|
||||
<div style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-word' }}>{text}</div>
|
||||
{files && files.length > 0 && (
|
||||
<div className="chat-message-files">
|
||||
{files.map((f, i) => (
|
||||
|
||||
@@ -28,6 +28,12 @@ export default function Login() {
|
||||
const [error, setError] = useState('')
|
||||
const [message, setMessage] = useState('')
|
||||
const [submitting, setSubmitting] = useState(false)
|
||||
// weakPasswordWarning is the server-side rejection message for an
|
||||
// overridable weak-password failure. When set, the form shows an
|
||||
// acknowledgement checkbox; ticking it sends acknowledge_weak_password
|
||||
// on the next submit so the server skips the entropy check.
|
||||
const [weakPasswordWarning, setWeakPasswordWarning] = useState('')
|
||||
const [acknowledgeWeakPassword, setAcknowledgeWeakPassword] = useState(false)
|
||||
const [showTokenLogin, setShowTokenLogin] = useState(false)
|
||||
const [token, setToken] = useState('')
|
||||
|
||||
@@ -119,6 +125,9 @@ export default function Login() {
|
||||
if (inviteCode) {
|
||||
body.inviteCode = inviteCode
|
||||
}
|
||||
if (acknowledgeWeakPassword) {
|
||||
body.acknowledge_weak_password = true
|
||||
}
|
||||
|
||||
const res = await fetch(apiUrl('/api/auth/register'), {
|
||||
method: 'POST',
|
||||
@@ -128,10 +137,18 @@ export default function Login() {
|
||||
const data = await res.json()
|
||||
|
||||
if (!res.ok) {
|
||||
setError(extractError(data, t('login.errors.registrationFailed')))
|
||||
if (data && data.overridable) {
|
||||
setWeakPasswordWarning(extractError(data, ''))
|
||||
setError('')
|
||||
} else {
|
||||
setError(extractError(data, t('login.errors.registrationFailed')))
|
||||
setWeakPasswordWarning('')
|
||||
setAcknowledgeWeakPassword(false)
|
||||
}
|
||||
setSubmitting(false)
|
||||
return
|
||||
}
|
||||
setWeakPasswordWarning('')
|
||||
|
||||
if (data.pending) {
|
||||
setMessage(data.message || t('login.messages.registrationPending'))
|
||||
@@ -364,9 +381,13 @@ export default function Login() {
|
||||
className="input"
|
||||
type="password"
|
||||
value={password}
|
||||
onChange={(e) => { setPassword(e.target.value); setError('') }}
|
||||
onChange={(e) => {
|
||||
setPassword(e.target.value)
|
||||
setError('')
|
||||
setWeakPasswordWarning('')
|
||||
setAcknowledgeWeakPassword(false)
|
||||
}}
|
||||
placeholder={t('login.newPasswordPlaceholder')}
|
||||
minLength={8}
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
@@ -381,6 +402,19 @@ export default function Login() {
|
||||
required
|
||||
/>
|
||||
</div>
|
||||
{weakPasswordWarning && (
|
||||
<div className="form-group" role="alert">
|
||||
<div className="login-warning">{weakPasswordWarning}</div>
|
||||
<label style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-xs)', marginTop: 'var(--spacing-xs)', fontSize: '0.875rem' }}>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={acknowledgeWeakPassword}
|
||||
onChange={(e) => setAcknowledgeWeakPassword(e.target.checked)}
|
||||
/>
|
||||
Use this password anyway
|
||||
</label>
|
||||
</div>
|
||||
)}
|
||||
<button type="submit" className="btn btn-primary login-btn-full" disabled={submitting}>
|
||||
{submitting
|
||||
? t('login.creatingAccount')
|
||||
|
||||
@@ -14,6 +14,7 @@ import { useModels } from '../hooks/useModels'
|
||||
import { useGalleryEnrichment } from '../hooks/useGalleryEnrichment'
|
||||
import { backendControlApi, modelsApi, backendsApi, systemApi, nodesApi } from '../utils/api'
|
||||
import { renderMarkdown } from '../utils/markdown'
|
||||
import { safeHref } from '../utils/url'
|
||||
import {
|
||||
CAP_CHAT, CAP_COMPLETION, CAP_IMAGE, CAP_VIDEO, CAP_TTS,
|
||||
CAP_TRANSCRIPT, CAP_SOUND_GENERATION, CAP_FACE_RECOGNITION,
|
||||
@@ -1074,7 +1075,7 @@ function ModelDetail({ model, enriched, matchedCaps, distributedMode, onNavigate
|
||||
<dd>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 2 }}>
|
||||
{urls.map((url, i) => (
|
||||
<a key={i} href={url} target="_blank" rel="noopener noreferrer"
|
||||
<a key={i} href={safeHref(url)} target="_blank" rel="noopener noreferrer"
|
||||
style={{ color: 'var(--color-primary)', wordBreak: 'break-all', fontSize: 'var(--text-xs)' }}>
|
||||
<i className="fas fa-external-link-alt" style={{ marginRight: 4, fontSize: '0.625rem' }} />{url}
|
||||
</a>
|
||||
@@ -1160,7 +1161,7 @@ function BackendDetail({ backend, enriched, upgradeInfo, nodes, distributedMode
|
||||
<dd>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 2 }}>
|
||||
{urls.map((url, i) => (
|
||||
<a key={i} href={url} target="_blank" rel="noopener noreferrer"
|
||||
<a key={i} href={safeHref(url)} target="_blank" rel="noopener noreferrer"
|
||||
style={{ color: 'var(--color-primary)', wordBreak: 'break-all', fontSize: 'var(--text-xs)' }}>
|
||||
<i className="fas fa-external-link-alt" style={{ marginRight: 4, fontSize: '0.625rem' }} />{url}
|
||||
</a>
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useState, useCallback, useEffect } from 'react'
|
||||
import { useNavigate, useOutletContext } from 'react-router-dom'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { modelsApi } from '../utils/api'
|
||||
import { safeHref } from '../utils/url'
|
||||
import { useDebouncedCallback } from '../hooks/useDebounce'
|
||||
import { useOperations } from '../hooks/useOperations'
|
||||
import { useResources } from '../hooks/useResources'
|
||||
@@ -637,7 +638,7 @@ function ModelDetail({ model, fit, sizeDisplay, vramDisplay, expandedFiles, setE
|
||||
{model.urls?.length > 0 && (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: '2px' }}>
|
||||
{model.urls.map((url, i) => (
|
||||
<a key={i} href={url} target="_blank" rel="noopener noreferrer" style={{ fontSize: '0.8125rem', color: 'var(--color-primary)', wordBreak: 'break-all' }}>
|
||||
<a key={i} href={safeHref(url)} target="_blank" rel="noopener noreferrer" style={{ fontSize: '0.8125rem', color: 'var(--color-primary)', wordBreak: 'break-all' }}>
|
||||
<i className="fas fa-external-link-alt" style={{ marginRight: 4, fontSize: '0.6875rem' }} />{url}
|
||||
</a>
|
||||
))}
|
||||
|
||||
@@ -671,6 +671,8 @@ export default function Users() {
|
||||
const [passwordResetUser, setPasswordResetUser] = useState(null)
|
||||
const [newPassword, setNewPassword] = useState('')
|
||||
const [resettingPassword, setResettingPassword] = useState(false)
|
||||
const [resetWeakWarning, setResetWeakWarning] = useState('')
|
||||
const [resetAcknowledgeWeak, setResetAcknowledgeWeak] = useState(false)
|
||||
|
||||
const fetchUsers = useCallback(async () => {
|
||||
setLoading(true)
|
||||
@@ -766,18 +768,26 @@ export default function Users() {
|
||||
const handleResetPassword = (u) => {
|
||||
setPasswordResetUser(u)
|
||||
setNewPassword('')
|
||||
setResetWeakWarning('')
|
||||
setResetAcknowledgeWeak(false)
|
||||
}
|
||||
|
||||
const confirmResetPassword = async () => {
|
||||
if (!passwordResetUser || newPassword.length < 8) return
|
||||
if (!passwordResetUser || newPassword.length === 0) return
|
||||
setResettingPassword(true)
|
||||
try {
|
||||
await adminUsersApi.resetPassword(passwordResetUser.id, newPassword)
|
||||
await adminUsersApi.resetPassword(passwordResetUser.id, newPassword, resetAcknowledgeWeak)
|
||||
addToast(`Password reset for ${passwordResetUser.name || passwordResetUser.email}`, 'success')
|
||||
setPasswordResetUser(null)
|
||||
setNewPassword('')
|
||||
setResetWeakWarning('')
|
||||
setResetAcknowledgeWeak(false)
|
||||
} catch (err) {
|
||||
addToast(`Failed to reset password: ${err.message}`, 'error')
|
||||
if (err.body?.overridable) {
|
||||
setResetWeakWarning(err.body.error || err.message)
|
||||
} else {
|
||||
addToast(`Failed to reset password: ${err.message}`, 'error')
|
||||
}
|
||||
} finally {
|
||||
setResettingPassword(false)
|
||||
}
|
||||
@@ -965,18 +975,35 @@ export default function Users() {
|
||||
<input
|
||||
type="password"
|
||||
className="input"
|
||||
placeholder="New password (min 8 characters)"
|
||||
placeholder="New password (min 12 characters)"
|
||||
value={newPassword}
|
||||
onChange={e => setNewPassword(e.target.value)}
|
||||
onKeyDown={e => { if (e.key === 'Enter' && newPassword.length >= 8) confirmResetPassword() }}
|
||||
onChange={e => {
|
||||
setNewPassword(e.target.value)
|
||||
setResetWeakWarning('')
|
||||
setResetAcknowledgeWeak(false)
|
||||
}}
|
||||
onKeyDown={e => { if (e.key === 'Enter' && newPassword.length > 0) confirmResetPassword() }}
|
||||
autoFocus
|
||||
/>
|
||||
{resetWeakWarning && (
|
||||
<div role="alert" style={{ marginTop: 'var(--spacing-sm)', display: 'flex', flexDirection: 'column', gap: 'var(--spacing-xs)' }}>
|
||||
<div className="login-warning">{resetWeakWarning}</div>
|
||||
<label style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-xs)', fontSize: '0.875rem' }}>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={resetAcknowledgeWeak}
|
||||
onChange={e => setResetAcknowledgeWeak(e.target.checked)}
|
||||
/>
|
||||
Use this password anyway
|
||||
</label>
|
||||
</div>
|
||||
)}
|
||||
<div className="perm-modal-actions" style={{ marginTop: 'var(--spacing-md)' }}>
|
||||
<button className="btn btn-secondary" onClick={() => setPasswordResetUser(null)}>Cancel</button>
|
||||
<button
|
||||
className="btn btn-primary"
|
||||
onClick={confirmResetPassword}
|
||||
disabled={resettingPassword || newPassword.length < 8}
|
||||
disabled={resettingPassword || newPassword.length === 0}
|
||||
>
|
||||
{resettingPassword ? 'Resetting...' : 'Reset Password'}
|
||||
</button>
|
||||
|
||||
11
core/http/react-ui/src/utils/api.js
vendored
11
core/http/react-ui/src/utils/api.js
vendored
@@ -450,8 +450,10 @@ export const adminUsersApi = {
|
||||
deleteQuota: (id, quotaId) => fetchJSON(`/api/auth/admin/users/${encodeURIComponent(id)}/quotas/${encodeURIComponent(quotaId)}`, {
|
||||
method: 'DELETE',
|
||||
}),
|
||||
resetPassword: (id, password) => fetchJSON(`/api/auth/admin/users/${encodeURIComponent(id)}/password`, {
|
||||
method: 'PUT', body: JSON.stringify({ password }), headers: { 'Content-Type': 'application/json' },
|
||||
resetPassword: (id, password, acknowledgeWeak = false) => fetchJSON(`/api/auth/admin/users/${encodeURIComponent(id)}/password`, {
|
||||
method: 'PUT',
|
||||
body: JSON.stringify({ password, acknowledge_weak_password: acknowledgeWeak }),
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
}
|
||||
|
||||
@@ -464,8 +466,9 @@ export const profileApi = {
|
||||
updateProfile: (name, avatarUrl) => fetchJSON('/api/auth/profile', {
|
||||
method: 'PUT', body: JSON.stringify({ name, avatar_url: avatarUrl || '' }), headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
changePassword: (currentPassword, newPassword) => fetchJSON('/api/auth/password', {
|
||||
method: 'PUT', body: JSON.stringify({ current_password: currentPassword, new_password: newPassword }),
|
||||
changePassword: (currentPassword, newPassword, acknowledgeWeak = false) => fetchJSON('/api/auth/password', {
|
||||
method: 'PUT',
|
||||
body: JSON.stringify({ current_password: currentPassword, new_password: newPassword, acknowledge_weak_password: acknowledgeWeak }),
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
}),
|
||||
}
|
||||
|
||||
42
core/http/react-ui/src/utils/mcpClientStorage.js
vendored
42
core/http/react-ui/src/utils/mcpClientStorage.js
vendored
@@ -2,12 +2,52 @@ import { generateId } from './format'
|
||||
|
||||
const STORAGE_KEY = 'localai_client_mcp_servers'
|
||||
|
||||
// localStorage is shared across same-origin pages; an XSS that lands once can
|
||||
// poison persisted MCP server entries to attempt header injection or to feed
|
||||
// a non-http URL into the fetch path. Validate every entry on load and drop
|
||||
// anything that doesn't match the expected shape.
|
||||
function sanitiseServer(s) {
|
||||
if (!s || typeof s !== 'object') return null
|
||||
const id = typeof s.id === 'string' ? s.id : ''
|
||||
const name = typeof s.name === 'string' ? s.name : ''
|
||||
const url = typeof s.url === 'string' ? s.url : ''
|
||||
if (!url) return null
|
||||
// fetch() refuses non-http schemes anyway, but reject early so they
|
||||
// can't get persisted back out. URL parsing also catches malformed values.
|
||||
let parsed
|
||||
try {
|
||||
parsed = new URL(url)
|
||||
} catch {
|
||||
return null
|
||||
}
|
||||
if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') return null
|
||||
|
||||
const headers = {}
|
||||
if (s.headers && typeof s.headers === 'object' && !Array.isArray(s.headers)) {
|
||||
for (const [k, v] of Object.entries(s.headers)) {
|
||||
// Drop CRLF / control chars to block header injection through poisoned storage.
|
||||
if (typeof k !== 'string' || typeof v !== 'string') continue
|
||||
if (/[\x00-\x1f\x7f]/.test(k) || /[\x00-\x1f\x7f]/.test(v)) continue
|
||||
headers[k] = v
|
||||
}
|
||||
}
|
||||
return {
|
||||
id,
|
||||
name,
|
||||
url,
|
||||
headers,
|
||||
useProxy: s.useProxy !== false,
|
||||
}
|
||||
}
|
||||
|
||||
export function loadClientMCPServers() {
|
||||
try {
|
||||
const stored = localStorage.getItem(STORAGE_KEY)
|
||||
if (stored) {
|
||||
const data = JSON.parse(stored)
|
||||
if (Array.isArray(data)) return data
|
||||
if (Array.isArray(data)) {
|
||||
return data.map(sanitiseServer).filter(Boolean)
|
||||
}
|
||||
}
|
||||
} catch (_e) {
|
||||
// ignore
|
||||
|
||||
31
core/http/react-ui/src/utils/url.js
vendored
Normal file
31
core/http/react-ui/src/utils/url.js
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
// safeHref returns the input URL only if its scheme is on a small allowlist
|
||||
// (http, https, mailto, tel) or if it's a relative/anchor link. Anything else
|
||||
// — most importantly `javascript:` and `data:` — collapses to '#'. Use this
|
||||
// for any <a href={...}> whose URL comes from gallery JSON, agent tool calls,
|
||||
// or any other source the operator hasn't fully vetted.
|
||||
//
|
||||
// React already escapes attribute values, so the only XSS path on a hyperlink
|
||||
// is the URI itself. javascript: in <img src> is inert in modern browsers,
|
||||
// but <a href="javascript:..."> still fires on click.
|
||||
const ALLOWED_SCHEMES = ['http:', 'https:', 'mailto:', 'tel:']
|
||||
|
||||
export function safeHref(url) {
|
||||
if (typeof url !== 'string' || url === '') return '#'
|
||||
const trimmed = url.trim()
|
||||
if (trimmed === '') return '#'
|
||||
// Relative paths, fragment links, and protocol-relative URLs are fine.
|
||||
if (trimmed.startsWith('/') || trimmed.startsWith('#') || trimmed.startsWith('?')) {
|
||||
return trimmed
|
||||
}
|
||||
if (trimmed.startsWith('//')) return trimmed
|
||||
// Heuristic: if there's no colon before the first slash, it's a relative path.
|
||||
const colonIdx = trimmed.indexOf(':')
|
||||
if (colonIdx === -1) return trimmed
|
||||
const slashIdx = trimmed.indexOf('/')
|
||||
if (slashIdx !== -1 && slashIdx < colonIdx) return trimmed
|
||||
// There is a scheme — allowlist-check it. Browsers ignore tabs/newlines
|
||||
// inside the scheme (`java\tscript:...`), so we strip control chars first.
|
||||
const scheme = trimmed.slice(0, colonIdx).toLowerCase().replace(/[\x00-\x1f]/g, '')
|
||||
if (ALLOWED_SCHEMES.includes(scheme + ':')) return trimmed
|
||||
return '#'
|
||||
}
|
||||
225
core/http/route_coverage_test.go
Normal file
225
core/http/route_coverage_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
//go:build auth
|
||||
|
||||
package http_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Every API-prefixed route registered by API() must either reject anonymous
|
||||
// traffic with 401 or appear on the explicit public allowlist below. The
|
||||
// test fails on routes that ship without an auth decision; adding a new
|
||||
// public surface should be deliberate, not a side effect.
|
||||
var _ = Describe("Route auth coverage", func() {
|
||||
var (
|
||||
app *echo.Echo
|
||||
tmpdir string
|
||||
c context.Context
|
||||
cancel context.CancelFunc
|
||||
appInst *application.Application
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tmpdir, err = os.MkdirTemp("", "route-coverage-")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
modelDir := filepath.Join(tmpdir, "models")
|
||||
Expect(os.Mkdir(modelDir, 0750)).To(Succeed())
|
||||
bDir := filepath.Join(tmpdir, "backends")
|
||||
Expect(os.Mkdir(bDir, 0750)).To(Succeed())
|
||||
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(bDir),
|
||||
system.WithModelPath(modelDir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Auth enabled, no legacy keys, no admin user pre-created. With auth
|
||||
// enabled the global middleware MUST reject anonymous API requests
|
||||
// regardless of admin presence.
|
||||
appInst, err = application.New(
|
||||
config.WithContext(c),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithAuthEnabled(true),
|
||||
config.WithAuthDatabaseURL(":memory:"),
|
||||
config.WithAuthAPIKeyHMACSecret("test-secret-for-route-coverage"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = API(appInst)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
Expect(os.RemoveAll(tmpdir)).To(Succeed())
|
||||
})
|
||||
|
||||
It("rejects anonymous traffic on every API route except the explicit allowlist", func() {
|
||||
// Routes that are intentionally reachable without authentication.
|
||||
// Each entry needs a justification comment.
|
||||
expectedPublicPrefixes := []string{
|
||||
// Auth flow itself — login, registration, OAuth callbacks
|
||||
"/api/auth/",
|
||||
|
||||
// Distributed-mode node self-service: authenticated by a
|
||||
// registration token presented in the request body, not by the
|
||||
// global auth middleware. Verified separately in node tests.
|
||||
"/api/node/register",
|
||||
|
||||
// Branding read for login screen + public branding asset server.
|
||||
// Mutating /api/branding/asset/* routes are exempt from global
|
||||
// auth but admin-gated by route-level middleware
|
||||
// (TestBrandingRoutes_AdminGatingHolds pins that contract).
|
||||
"/api/branding",
|
||||
"/branding/",
|
||||
|
||||
// Health and metadata used by orchestrators / load balancers
|
||||
"/healthz",
|
||||
"/readyz",
|
||||
|
||||
// Static asset surfaces used by the UI before login
|
||||
"/favicon.svg",
|
||||
"/static/",
|
||||
"/assets/",
|
||||
"/locales/",
|
||||
"/generated-audio/",
|
||||
"/generated-images/",
|
||||
"/generated-videos/",
|
||||
}
|
||||
expectedPublicExact := map[string]bool{
|
||||
// SPA shell + redirects — UI handles login client-side
|
||||
"/": true,
|
||||
"/app": true,
|
||||
"/browse": true,
|
||||
"/swagger": true,
|
||||
"/swagger/": true,
|
||||
"/swagger/*": true,
|
||||
"/oauth/start": true,
|
||||
}
|
||||
|
||||
// Per-route exemptions for distributed-mode node endpoints whose
|
||||
// pattern carries a path param (so they don't fit a flat prefix).
|
||||
// These authenticate via registration token at the handler layer.
|
||||
nodeSelfPattern := regexp.MustCompile(`^/api/node/[^/]+/(heartbeat|drain|deregister)$`)
|
||||
|
||||
// Concretize a route pattern into a URL suitable for httptest.
|
||||
// Echo path params come back as ":name" and wildcards as "*".
|
||||
concretize := func(pattern string) string {
|
||||
parts := strings.Split(pattern, "/")
|
||||
for i, p := range parts {
|
||||
if strings.HasPrefix(p, ":") {
|
||||
parts[i] = "test"
|
||||
} else if p == "*" {
|
||||
parts[i] = "test"
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "/")
|
||||
}
|
||||
|
||||
isAPI := func(p string) bool {
|
||||
apiPrefixes := []string{
|
||||
"/api/", "/v1/", "/models/", "/backends/", "/backend/",
|
||||
"/tts", "/vad", "/video", "/stores/", "/system",
|
||||
"/ws/", "/generated-", "/chat/", "/completions",
|
||||
"/edits", "/embeddings", "/audio/", "/images/",
|
||||
"/messages", "/responses",
|
||||
}
|
||||
if p == "/metrics" {
|
||||
return true
|
||||
}
|
||||
for _, pre := range apiPrefixes {
|
||||
if strings.HasPrefix(p, pre) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
isAllowlisted := func(p string) bool {
|
||||
if expectedPublicExact[p] {
|
||||
return true
|
||||
}
|
||||
for _, pre := range expectedPublicPrefixes {
|
||||
if strings.HasPrefix(p, pre) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if nodeSelfPattern.MatchString(p) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var leaks []string
|
||||
seen := map[string]bool{}
|
||||
for _, r := range app.Routes() {
|
||||
// Echo registers automatic HEAD routes for GETs; auth check is
|
||||
// identical, so dedupe.
|
||||
key := r.Method + " " + r.Path
|
||||
if seen[key] {
|
||||
continue
|
||||
}
|
||||
seen[key] = true
|
||||
|
||||
// Only inspect API surface — UI/static paths are intentionally
|
||||
// reachable for SPA hydration before login.
|
||||
if !isAPI(r.Path) {
|
||||
continue
|
||||
}
|
||||
if isAllowlisted(r.Path) {
|
||||
continue
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(r.Method, concretize(r.Path), nil)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
|
||||
// We accept:
|
||||
// 401: middleware rejected — the desired outcome
|
||||
// 404: route not actually reachable in this minimal config
|
||||
// (e.g. distributed-only routes); not an auth leak
|
||||
// 405: method not allowed; auth never ran but isn't a leak
|
||||
if rec.Code == http.StatusUnauthorized ||
|
||||
rec.Code == http.StatusNotFound ||
|
||||
rec.Code == http.StatusMethodNotAllowed {
|
||||
continue
|
||||
}
|
||||
|
||||
leaks = append(leaks, " "+r.Method+" "+r.Path+
|
||||
" → "+http.StatusText(rec.Code)+
|
||||
" (got "+strconv.Itoa(rec.Code)+")")
|
||||
}
|
||||
|
||||
if len(leaks) > 0 {
|
||||
Fail("Routes reachable without authentication:\n" +
|
||||
strings.Join(leaks, "\n") +
|
||||
"\n\nIf the route is intentionally public, add it to " +
|
||||
"expectedPublicPrefixes or expectedPublicExact in " +
|
||||
"core/http/route_coverage_test.go with a justification " +
|
||||
"comment. Otherwise, gate it behind the auth middleware " +
|
||||
"(automatic for /api/, /v1/, /models/, /backends/, etc.) " +
|
||||
"or RequireAdmin / RequireFeature.")
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -190,6 +190,13 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
authRL := newRateLimiter(1*time.Minute, 5)
|
||||
authRateLimitMw := rateLimitMiddleware(authRL)
|
||||
|
||||
// Separate, more permissive limiter for OAuth/OIDC callbacks. Corporate
|
||||
// SSO often funnels many real users through one outbound IP, so the 5/min
|
||||
// password-style cap is too tight here; 60/min still bounds a flood that
|
||||
// would otherwise pin token-exchange traffic to the IdP.
|
||||
oauthRL := newRateLimiter(1*time.Minute, 60)
|
||||
oauthRateLimitMw := rateLimitMiddleware(oauthRL)
|
||||
|
||||
// Start background goroutine to periodically prune stale IP entries
|
||||
go func() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
@@ -200,6 +207,7 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
return
|
||||
case <-ticker.C:
|
||||
authRL.cleanup()
|
||||
oauthRL.cleanup()
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -274,13 +282,13 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
e.GET("/api/auth/github/login", oauthMgr.LoginHandler(auth.ProviderGitHub))
|
||||
e.GET("/api/auth/github/callback", oauthMgr.CallbackHandler(
|
||||
auth.ProviderGitHub, db, appConfig.Auth.AdminEmail, appConfig.Auth.RegistrationMode, appConfig.Auth.APIKeyHMACSecret,
|
||||
))
|
||||
), oauthRateLimitMw)
|
||||
}
|
||||
if appConfig.Auth.OIDCClientID != "" {
|
||||
e.GET("/api/auth/oidc/login", oauthMgr.LoginHandler(auth.ProviderOIDC))
|
||||
e.GET("/api/auth/oidc/callback", oauthMgr.CallbackHandler(
|
||||
auth.ProviderOIDC, db, appConfig.Auth.AdminEmail, appConfig.Auth.RegistrationMode, appConfig.Auth.APIKeyHMACSecret,
|
||||
))
|
||||
), oauthRateLimitMw)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -292,10 +300,11 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Name string `json:"name"`
|
||||
InviteCode string `json:"inviteCode"`
|
||||
Email string `json:"email"`
|
||||
Password string `json:"password"`
|
||||
Name string `json:"name"`
|
||||
InviteCode string `json:"inviteCode"`
|
||||
AcknowledgeWeakPassword bool `json:"acknowledge_weak_password"`
|
||||
}
|
||||
if err := c.Bind(&body); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request"})
|
||||
@@ -311,8 +320,8 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
if _, err := mail.ParseAddress(body.Email); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid email address"})
|
||||
}
|
||||
if len(body.Password) < 8 {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "password must be at least 8 characters"})
|
||||
if err := auth.ValidatePasswordStrength(body.Password, auth.PasswordPolicy{AllowWeak: body.AcknowledgeWeakPassword}); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, auth.PasswordError(err))
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(body.Password)
|
||||
@@ -579,8 +588,9 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
}
|
||||
|
||||
var body struct {
|
||||
CurrentPassword string `json:"current_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
CurrentPassword string `json:"current_password"`
|
||||
NewPassword string `json:"new_password"`
|
||||
AcknowledgeWeakPassword bool `json:"acknowledge_weak_password"`
|
||||
}
|
||||
if err := c.Bind(&body); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request"})
|
||||
@@ -590,8 +600,8 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "current and new passwords are required"})
|
||||
}
|
||||
|
||||
if len(body.NewPassword) < 8 {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "new password must be at least 8 characters"})
|
||||
if err := auth.ValidatePasswordStrength(body.NewPassword, auth.PasswordPolicy{AllowWeak: body.AcknowledgeWeakPassword}); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, auth.PasswordError(err))
|
||||
}
|
||||
|
||||
// Verify current password
|
||||
@@ -900,14 +910,15 @@ func RegisterAuthRoutes(e *echo.Echo, app *application.Application) {
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Password string `json:"password"`
|
||||
Password string `json:"password"`
|
||||
AcknowledgeWeakPassword bool `json:"acknowledge_weak_password"`
|
||||
}
|
||||
if err := c.Bind(&body); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request"})
|
||||
}
|
||||
|
||||
if len(body.Password) < 8 {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "password must be at least 8 characters"})
|
||||
if err := auth.ValidatePasswordStrength(body.Password, auth.PasswordPolicy{AllowWeak: body.AcknowledgeWeakPassword}); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, auth.PasswordError(err))
|
||||
}
|
||||
|
||||
hash, err := auth.HashPassword(body.Password)
|
||||
|
||||
165
core/http/routes/auth_profile_test.go
Normal file
165
core/http/routes/auth_profile_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
//go:build auth
|
||||
|
||||
package routes_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
authpkg "github.com/mudler/LocalAI/core/http/auth"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// PUT /api/auth/profile must accept ONLY name and avatar_url. A user that
|
||||
// posts {"role":"admin","email":"...","status":"active","password_hash":"..."}
|
||||
// must not be able to mutate any of those fields. The current handler uses
|
||||
// an explicit local body struct + gorm Updates(map) with a column allowlist;
|
||||
// this test pins that contract so a future refactor (e.g. c.Bind(&user))
|
||||
// can't silently regress to mass-assignment.
|
||||
var _ = Describe("PUT /api/auth/profile field-tampering", func() {
|
||||
var (
|
||||
app *echo.Echo
|
||||
appCtx context.Context
|
||||
cancel context.CancelFunc
|
||||
tmpdir string
|
||||
alice authpkg.User
|
||||
appAt *application.Application
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tmpdir, err = os.MkdirTemp("", "profile-tamper-")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
modelDir := filepath.Join(tmpdir, "models")
|
||||
Expect(os.Mkdir(modelDir, 0750)).To(Succeed())
|
||||
bDir := filepath.Join(tmpdir, "backends")
|
||||
Expect(os.Mkdir(bDir, 0750)).To(Succeed())
|
||||
|
||||
appCtx, cancel = context.WithCancel(context.Background())
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(bDir),
|
||||
system.WithModelPath(modelDir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
appAt, err = application.New(
|
||||
config.WithContext(appCtx),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithAuthEnabled(true),
|
||||
config.WithAuthDatabaseURL(":memory:"),
|
||||
config.WithAuthAPIKeyHMACSecret("test-secret-profile-tamper"),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = API(appAt)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// Seed a non-admin user directly in the DB.
|
||||
alice = authpkg.User{
|
||||
ID: "alice-id",
|
||||
Email: "alice@example.com",
|
||||
Name: "Alice",
|
||||
Provider: authpkg.ProviderLocal,
|
||||
PasswordHash: "$2a$10$abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGHIJKLMNOPQRSTUV",
|
||||
Role: authpkg.RoleUser,
|
||||
Status: "active",
|
||||
}
|
||||
Expect(appAt.AuthDB().Create(&alice).Error).To(Succeed())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
Expect(os.RemoveAll(tmpdir)).To(Succeed())
|
||||
})
|
||||
|
||||
// Mint a real API key for alice once per spec — the auth middleware
|
||||
// resolves Bearer tokens against the DB, which is the same path a
|
||||
// browser session takes after extraction. This avoids forging session
|
||||
// cookies and exercises the production auth flow end-to-end.
|
||||
callProfile := func(body map[string]any) (int, map[string]any) {
|
||||
key, _, err := authpkg.CreateAPIKey(appAt.AuthDB(), alice.ID, "test", authpkg.RoleUser, "test-secret-profile-tamper", nil)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
raw, err := json.Marshal(body)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
req := httptest.NewRequest(http.MethodPut, "/api/auth/profile", bytes.NewReader(raw))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Sec-Fetch-Site", "same-origin")
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
rec := httptest.NewRecorder()
|
||||
app.ServeHTTP(rec, req)
|
||||
var out map[string]any
|
||||
_ = json.Unmarshal(rec.Body.Bytes(), &out)
|
||||
return rec.Code, out
|
||||
}
|
||||
|
||||
It("ignores attempts to set role: admin", func() {
|
||||
code, _ := callProfile(map[string]any{
|
||||
"name": "Alice 2",
|
||||
"avatar_url": "https://example.com/a.png",
|
||||
"role": "admin",
|
||||
})
|
||||
Expect(code).To(Equal(http.StatusOK))
|
||||
|
||||
var fresh authpkg.User
|
||||
Expect(appAt.AuthDB().First(&fresh, "id = ?", alice.ID).Error).To(Succeed())
|
||||
Expect(fresh.Role).To(Equal(authpkg.RoleUser),
|
||||
"role must remain RoleUser; mass-assignment of role is forbidden")
|
||||
Expect(fresh.Name).To(Equal("Alice 2"))
|
||||
Expect(fresh.AvatarURL).To(Equal("https://example.com/a.png"))
|
||||
})
|
||||
|
||||
It("ignores attempts to override email, status, password_hash, provider, and id", func() {
|
||||
body := map[string]any{
|
||||
"name": "Alice 3",
|
||||
"avatar_url": "",
|
||||
"email": "attacker@example.com",
|
||||
"status": "frozen",
|
||||
"password_hash": "$2a$10$attacker-controlled-hash",
|
||||
"provider": "github",
|
||||
"id": "different-id",
|
||||
}
|
||||
code, _ := callProfile(body)
|
||||
Expect(code).To(Equal(http.StatusOK))
|
||||
|
||||
var fresh authpkg.User
|
||||
Expect(appAt.AuthDB().First(&fresh, "id = ?", alice.ID).Error).To(Succeed())
|
||||
Expect(fresh.Email).To(Equal(alice.Email), "email must be immutable through profile update")
|
||||
Expect(fresh.Status).To(Equal(alice.Status), "status must be immutable")
|
||||
Expect(fresh.PasswordHash).To(Equal(alice.PasswordHash), "password_hash must be immutable")
|
||||
Expect(fresh.Provider).To(Equal(alice.Provider), "provider must be immutable")
|
||||
Expect(fresh.ID).To(Equal(alice.ID), "id must be immutable")
|
||||
})
|
||||
|
||||
It("requires a non-empty name", func() {
|
||||
code, body := callProfile(map[string]any{"name": ""})
|
||||
Expect(code).To(Equal(http.StatusBadRequest))
|
||||
Expect(body["error"]).To(ContainSubstring("name is required"))
|
||||
})
|
||||
|
||||
It("rejects oversized avatar_url", func() {
|
||||
long := make([]byte, 1024)
|
||||
for i := range long {
|
||||
long[i] = 'x'
|
||||
}
|
||||
code, _ := callProfile(map[string]any{
|
||||
"name": "Alice",
|
||||
"avatar_url": "https://" + string(long),
|
||||
})
|
||||
Expect(code).To(Equal(http.StatusBadRequest))
|
||||
})
|
||||
})
|
||||
@@ -672,12 +672,18 @@ func (s *AgentPoolService) ensureCollectionForUser(userID, name string) error {
|
||||
}
|
||||
|
||||
// UploadToCollectionForUser uploads to a collection for a specific user.
|
||||
// The filename arrives from a multipart upload; the vendored backend may or
|
||||
// may not sanitise it, so strip any directory components at the boundary.
|
||||
func (s *AgentPoolService) UploadToCollectionForUser(userID, collection, filename string, fileBody io.Reader) (string, error) {
|
||||
backend, err := s.CollectionsBackendForUser(userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return backend.Upload(collection, filename, fileBody)
|
||||
base := filepath.Base(filename)
|
||||
if base == "." || base == ".." || base == "/" || base == "" {
|
||||
return "", fmt.Errorf("invalid filename")
|
||||
}
|
||||
return backend.Upload(collection, base, fileBody)
|
||||
}
|
||||
|
||||
// CollectionEntryExistsForUser checks if an entry exists in a user's collection.
|
||||
|
||||
@@ -731,13 +731,19 @@ func (s *FineTuneService) setExportFailed(job *schema.FineTuneJob, message strin
|
||||
}
|
||||
|
||||
// UploadDataset handles dataset file upload and returns the local path.
|
||||
// The filename comes straight from a multipart upload, so strip directory
|
||||
// components — anything else risks a write under a sibling path.
|
||||
func (s *FineTuneService) UploadDataset(filename string, data []byte) (string, error) {
|
||||
uploadDir := filepath.Join(s.fineTuneBaseDir(), "datasets")
|
||||
if err := os.MkdirAll(uploadDir, 0750); err != nil {
|
||||
return "", fmt.Errorf("failed to create dataset directory: %w", err)
|
||||
}
|
||||
|
||||
filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+filename)
|
||||
base := filepath.Base(filename)
|
||||
if base == "." || base == ".." || base == "/" || base == "" {
|
||||
return "", fmt.Errorf("invalid filename")
|
||||
}
|
||||
filePath := filepath.Join(uploadDir, uuid.New().String()[:8]+"-"+base)
|
||||
if err := os.WriteFile(filePath, data, 0640); err != nil {
|
||||
return "", fmt.Errorf("failed to write dataset: %w", err)
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -61,6 +61,7 @@ require (
|
||||
github.com/testcontainers/testcontainers-go v0.42.0
|
||||
github.com/testcontainers/testcontainers-go/modules/nats v0.42.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.42.0
|
||||
github.com/timbutler/zxcvbn v1.0.4
|
||||
go.opentelemetry.io/otel v1.43.0
|
||||
go.opentelemetry.io/otel/exporters/prometheus v0.65.0
|
||||
go.opentelemetry.io/otel/metric v1.43.0
|
||||
|
||||
2
go.sum
2
go.sum
@@ -1144,6 +1144,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/timbutler/zxcvbn v1.0.4 h1:nTUa8UpLhIxhUBag42fQcwiC8AtTxNVbQMbmxyxLfXg=
|
||||
github.com/timbutler/zxcvbn v1.0.4/go.mod h1:Cl20mGFz9+SXvTRebBcwMUDqZUvCfSnb+XMznbTKo2U=
|
||||
github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA=
|
||||
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
|
||||
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
|
||||
|
||||
@@ -272,15 +272,11 @@ func (s URI) ResolveURL() string {
|
||||
}
|
||||
|
||||
func removePartialFile(tmpFilePath string) error {
|
||||
_, err := os.Stat(tmpFilePath)
|
||||
if err == nil {
|
||||
xlog.Debug("Removing temporary file", "file", tmpFilePath)
|
||||
err = os.Remove(tmpFilePath)
|
||||
if err != nil {
|
||||
err1 := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, err)
|
||||
xlog.Warn("failed to remove temporary download file", "error", err1)
|
||||
return err1
|
||||
}
|
||||
xlog.Debug("Removing temporary file", "file", tmpFilePath)
|
||||
if err := os.Remove(tmpFilePath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
err1 := fmt.Errorf("failed to remove temporary download file %s: %v", tmpFilePath, err)
|
||||
xlog.Warn("failed to remove temporary download file", "error", err1)
|
||||
return err1
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -578,20 +574,28 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string
|
||||
default:
|
||||
}
|
||||
|
||||
err = os.Rename(tmpFilePath, filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err)
|
||||
}
|
||||
|
||||
// Invariant: verify the streamed hash before promoting the temp file to
|
||||
// the final path. Renaming first would leave tampered content reachable
|
||||
// to subsequent readers even though we return an error.
|
||||
if sha != "" {
|
||||
// Verify SHA
|
||||
calculatedSHA := fmt.Sprintf("%x", progress.hash.Sum(nil))
|
||||
if calculatedSHA != sha {
|
||||
xlog.Debug("SHA mismatch for file", "file", filePath, "calculated", calculatedSHA, "metadata", sha)
|
||||
_ = removePartialFile(tmpFilePath)
|
||||
return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", filePath, calculatedSHA, sha)
|
||||
}
|
||||
} else {
|
||||
xlog.Debug("SHA missing. Skipping validation", "file", filePath)
|
||||
// Visible at the default log level so missing-digest configs are
|
||||
// noticed; silent acceptance was the historical bug.
|
||||
xlog.Warn("downloading without integrity check — supplied SHA is empty",
|
||||
"file", filePath,
|
||||
"url", url,
|
||||
)
|
||||
}
|
||||
|
||||
err = os.Rename(tmpFilePath, filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err)
|
||||
}
|
||||
|
||||
xlog.Info("File downloaded and verified", "file", filePath)
|
||||
|
||||
@@ -295,6 +295,49 @@ var _ = Describe("Download Test", func() {
|
||||
err = uri.DownloadFile(filePath, mockDataSha, 1, 1, func(s1, s2, s3 string, f float64) {})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
// A file that fails its SHA check must not be usable. The historical
|
||||
// implementation renamed the temp file to its final path *before*
|
||||
// verifying the hash, so a mismatch returned an error but left a
|
||||
// tampered file at the destination — the next caller (e.g. a backend
|
||||
// launcher) could pick it up and run with it.
|
||||
It("does not leave a corrupted file at the destination on SHA mismatch", func() {
|
||||
mockServer := getMockServer(true)
|
||||
defer mockServer.Close()
|
||||
uri := URI(mockServer.URL)
|
||||
|
||||
// Use a clearly-wrong expected SHA; the server will return real
|
||||
// data with a different hash.
|
||||
wrongSHA := "0000000000000000000000000000000000000000000000000000000000000000"
|
||||
err := uri.DownloadFile(filePath, wrongSHA, 1, 1, func(s1, s2, s3 string, f float64) {})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("SHA"))
|
||||
|
||||
// The file must not exist at the final destination.
|
||||
_, statErr := os.Stat(filePath)
|
||||
Expect(os.IsNotExist(statErr)).To(BeTrue(),
|
||||
"download with wrong SHA left a file at %s — a subsequent caller could load tampered content", filePath)
|
||||
})
|
||||
|
||||
// A download without an expected digest is a supply-chain footgun.
|
||||
// The downloader allows it (backend installs pass through here
|
||||
// today and don't yet ship a digest) but it is the caller's
|
||||
// responsibility to know when integrity is required. The downloader
|
||||
// emits a WARN log on every empty-digest download to make this
|
||||
// visible at the default log level.
|
||||
It("succeeds with empty SHA but emits an integrity warning", func() {
|
||||
mockServer := getMockServer(true)
|
||||
defer mockServer.Close()
|
||||
uri := URI(mockServer.URL)
|
||||
|
||||
// No assertion on logs (we don't capture xlog output here),
|
||||
// but the call must succeed so existing backend installs do
|
||||
// not regress.
|
||||
err := uri.DownloadFile(filePath, "", 1, 1, func(s1, s2, s3 string, f float64) {})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
_, statErr := os.Stat(filePath)
|
||||
Expect(statErr).ToNot(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
|
||||
@@ -49,7 +49,7 @@ func ValidateExternalURL(rawURL string) error {
|
||||
return fmt.Errorf("unable to parse resolved IP: %s", ipStr)
|
||||
}
|
||||
|
||||
if !isPublicIP(ip) {
|
||||
if !IsPublicIP(ip) {
|
||||
return fmt.Errorf("requests to internal network addresses are not allowed")
|
||||
}
|
||||
}
|
||||
@@ -57,7 +57,11 @@ func ValidateExternalURL(rawURL string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func isPublicIP(ip net.IP) bool {
|
||||
// IsPublicIP reports whether ip refers to a host on the public internet, i.e.
|
||||
// not loopback, link-local, private (RFC 1918 / RFC 4193), or unspecified.
|
||||
// Covers 0.0.0.0/8, ::/128, and IPv4-mapped IPv6 wrapping a private address —
|
||||
// holes hand-rolled CIDR lists tend to miss.
|
||||
func IsPublicIP(ip net.IP) bool {
|
||||
if ip.IsLoopback() ||
|
||||
ip.IsLinkLocalUnicast() ||
|
||||
ip.IsLinkLocalMulticast() ||
|
||||
@@ -66,7 +70,6 @@ func isPublicIP(ip net.IP) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Block IPv4-mapped IPv6 addresses that wrap private IPv4
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
return !ip4.IsLoopback() &&
|
||||
!ip4.IsLinkLocalUnicast() &&
|
||||
|
||||
Reference in New Issue
Block a user