mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-30 11:36:31 -04:00
security(http): refuse redirects on outbound clients via hardened pkg/httpclient (#10087)
LocalAI's outbound HTTP clients used Go's default redirect policy, which
follows up to 10 redirects. On a cross-host redirect Go forwards custom
request headers — including credential headers such as Anthropic's
x-api-key — to the redirect target (Go strips Authorization, Cookie and
WWW-Authenticate cross-host, but NOT arbitrary custom headers). An
attacker able to elicit a redirect from an upstream (a hijacked or
spoofed upstream, DNS trickery, or a malicious upstream_url) then
harvests the operator's provider API key.
This was first reported against the cloud-proxy / MITM PII path
(GHSA-3mj3-57v2-4636); the same class affects every other outbound
client. Rather than patch each call site, add pkg/httpclient as the one
sanctioned constructor for outbound HTTP and route everything through it.
pkg/httpclient:
- New(...) refuses redirects, TLS 1.2 floor, no body
deadline (streaming/SSE safe)
- NewWithTimeout(d) simple request/response calls
- WithFollowRedirects opt-in following that still strips credential
headers on any cross-host hop; different
scheme/host/port == different origin, guarding
the curl CVE-2022-27774 port-confusion class
- WithTransport(rt) keep a custom transport (IP-pin, HTTP/2, a
credential-injecting RoundTripper)
- HardenedTransport() base transport with the TLS floor + bounded setup
- Harden(c) apply the policy to a library-supplied *http.Client
- NoRedirect the CheckRedirect policy; wraps ErrRedirectBlocked
Lint: a forbidigo rule flags http.DefaultClient and http.Get/Post/
PostForm/Head, pointing at pkg/httpclient (.golangci.yml,
.agents/coding-style.md). forbidigo cannot match the &http.Client{}
composite literal without also flagging legitimate *http.Client type
references, so that form is enforced by review.
Migrates every non-test outbound call site across core/, pkg/, cmd/, and
the Go backend (backend/go/cloud-proxy). Credential-bearing and
internal-RPC clients refuse redirects; download / CDN / registry clients
use WithFollowRedirects so they keep working while stripping secrets
cross-host. The only credential-bearing client that follows redirects is
the gated-download path (pkg/downloader/uri.go), which strips the token
on the cross-host hop to the CDN. Hardening this closes, in passing:
- MCP remote-server bearer token leaking via a redirect (the
RoundTripper re-injected Authorization on every hop)
- agent multimedia/webhook clients leaking user-supplied auth headers
- cors_proxy following redirects, bypassing its SSRF IP-pin
- downloader's authorized read path leaking the token cross-host
Fixes: GHSA-3mj3-57v2-4636 (cloud-proxy leaks operator provider API key
(x-api-key) to attacker host on cross-host redirect)
Reported-by: tonghuaroot
Assisted-by: Claude:claude-opus-4-8 [Claude Code]
Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
committed by
GitHub
parent
a7cad704b9
commit
12d1f3a697
@@ -50,6 +50,17 @@ Do not mix styles within a package. If you are extending tests in a package that
|
||||
|
||||
This is enforced by `golangci-lint` via the `forbidigo` linter (see `.golangci.yml`); calls like `t.Errorf` / `t.Fatalf` / `t.Run` / `t.Skip` / `t.Logf` are flagged. Run `make lint` locally before submitting; the same check runs in CI (`.github/workflows/lint.yml`).
|
||||
|
||||
## Outbound HTTP
|
||||
|
||||
All outbound HTTP must go through `github.com/mudler/LocalAI/pkg/httpclient` rather than the standard library's default client. Use `httpclient.New(...)` (no body deadline — safe for streaming/SSE) or `httpclient.NewWithTimeout(d, ...)` (simple request/response). Both **refuse redirects by default** and set a TLS 1.2 floor.
|
||||
|
||||
The reason is GHSA-3mj3-57v2-4636: the std default client follows redirects, and on a *cross-host* redirect Go forwards custom credential headers (e.g. Anthropic's `x-api-key`) to the redirect target, leaking the secret. `httpclient` fails closed instead.
|
||||
|
||||
- Need to follow redirects (download CDNs, registry blobs, GitHub asset URLs)? Pass `httpclient.WithFollowRedirects()` — it still strips credential headers on any cross-host hop.
|
||||
- Have a custom transport (IP-pinned dialer, HTTP/2 tuning, a credential-injecting `RoundTripper`)? Pass `httpclient.WithTransport(rt)`, basing the transport on `httpclient.HardenedTransport()` to keep the TLS floor. Handed a `*http.Client` by a library? `httpclient.Harden(c)` applies the policy in place.
|
||||
|
||||
This is enforced by `forbidigo` (see `.golangci.yml`): `http.DefaultClient` and `http.Get`/`Post`/`PostForm`/`Head` are flagged. The `&http.Client{}` composite literal can't be matched precisely by forbidigo without also flagging legitimate `*http.Client` type references, so that form is caught by review — don't construct raw clients.
|
||||
|
||||
## Documentation
|
||||
|
||||
The project documentation is located in `docs/content`. When adding new features or changing existing functionality, it is crucial to update the documentation to reflect these changes. This helps users understand how to use the new capabilities and ensures the documentation stays relevant.
|
||||
|
||||
@@ -56,6 +56,20 @@ linters:
|
||||
# are exempt — see linters.exclusions.rules below.
|
||||
- pattern: '^os\.(Getenv|LookupEnv|Environ)$'
|
||||
msg: 'Plumb config through ApplicationConfig (or the relevant CLI struct) instead of reading env directly. CLI entry points (core/cli/) bind env vars via kong''s `env:` tag — that is the only sanctioned env→struct boundary. See .agents/coding-style.md.'
|
||||
# Outbound HTTP must go through pkg/httpclient, which refuses redirects
|
||||
# by default and sets a TLS floor. The std-library default client and
|
||||
# the http.Get/Post/... convenience helpers follow redirects (up to 10)
|
||||
# and, on a cross-host redirect, forward custom credential headers such
|
||||
# as Anthropic's x-api-key to the redirect target — leaking the secret
|
||||
# (GHSA-3mj3-57v2-4636). forbidigo can't precisely match the
|
||||
# `&http.Client{}` composite literal without also flagging legitimate
|
||||
# `*http.Client` type references, so that form is enforced by
|
||||
# convention + review; these two patterns catch the implicit-default
|
||||
# client, which is the common footgun.
|
||||
- pattern: '^http\.DefaultClient$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.DefaultClient — the std client follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
- pattern: '^http\.(Get|Post|PostForm|Head)$'
|
||||
msg: 'Use pkg/httpclient (httpclient.New / NewWithTimeout) instead of http.Get/Post/PostForm/Head — these use http.DefaultClient, which follows redirects and leaks credential headers cross-host (GHSA-3mj3-57v2-4636). See .agents/coding-style.md.'
|
||||
exclusions:
|
||||
paths:
|
||||
# Upstream whisper.cpp source tree fetched by the whisper backend Makefile.
|
||||
@@ -95,3 +109,18 @@ linters:
|
||||
- path: _test\.go$
|
||||
text: 'os\.(Getenv|LookupEnv|Environ)'
|
||||
linters: [forbidigo]
|
||||
# pkg/httpclient is the sanctioned home for outbound HTTP clients; it
|
||||
# necessarily references net/http directly.
|
||||
- path: ^pkg/httpclient/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Tests drive local httptest servers where redirect/TLS hardening is
|
||||
# irrelevant; the std client is fine there.
|
||||
- path: _test\.go$
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
# Vendored upstream whisper.cpp Go bindings are a separate module and
|
||||
# cannot import pkg/httpclient.
|
||||
- path: ^backend/go/whisper/sources/
|
||||
text: 'http\.(DefaultClient|Get|Post|PostForm|Head)'
|
||||
linters: [forbidigo]
|
||||
|
||||
@@ -192,6 +192,61 @@ var _ = Describe("Forward", func() {
|
||||
Expect(<-gotAuth).To(Equal("Bearer sk-real"), "caller-supplied Basic header must be replaced")
|
||||
})
|
||||
|
||||
It("refuses to follow upstream redirects and never leaks the key to the redirect target", func() {
|
||||
// A 3xx from the configured upstream means misconfiguration or a
|
||||
// hijacked/spoofed host. Following it would replay the request —
|
||||
// and the injected API key — to the Location host. Anthropic's
|
||||
// x-api-key is NOT stripped by Go on cross-host redirects, so this
|
||||
// would be a credential leak. The proxy must refuse the redirect.
|
||||
sinkHit := make(chan string, 1)
|
||||
sink := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sinkHit <- r.Header.Get("x-api-key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer sink.Close()
|
||||
|
||||
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, sink.URL, http.StatusFound)
|
||||
}))
|
||||
defer redirector.Close()
|
||||
|
||||
GinkgoT().Setenv("CLOUD_PROXY_REDIRECT_KEY", "ant-secret")
|
||||
|
||||
cp := NewCloudProxy()
|
||||
Expect(cp.Load(&pb.ModelOptions{
|
||||
Proxy: &pb.ProxyOptions{
|
||||
UpstreamUrl: redirector.URL,
|
||||
Mode: modePassthrough,
|
||||
Provider: providerAnthropic,
|
||||
ApiKeyEnv: "CLOUD_PROXY_REDIRECT_KEY",
|
||||
},
|
||||
})).To(Succeed())
|
||||
|
||||
addr := "test://forward-no-redirect"
|
||||
grpc.Provide(addr, cp)
|
||||
c := grpc.NewClient(addr, true, nil, false)
|
||||
stream, err := c.Forward(context.Background())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(stream.Send(&pb.ForwardRequest{
|
||||
Path: "/v1/messages",
|
||||
Method: "POST",
|
||||
})).To(Succeed())
|
||||
Expect(stream.CloseSend()).To(Succeed())
|
||||
|
||||
// Drain the stream; a refused redirect surfaces as a non-EOF error.
|
||||
var streamErr error
|
||||
for {
|
||||
if _, err := stream.Recv(); err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
streamErr = err
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
Expect(streamErr).To(HaveOccurred(), "refused redirect must surface as an error")
|
||||
Expect(sinkHit).NotTo(Receive(), "the redirect target must never be contacted")
|
||||
})
|
||||
|
||||
It("handles concurrent calls without interference", func() {
|
||||
// CloudProxy explicitly omits base.SingleThread — independent
|
||||
// Forward streams must not block each other or leak state.
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/grpc/base"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Mirror of core/config.Proxy{Mode,Provider}* — backends don't
|
||||
@@ -48,10 +50,15 @@ type proxyConfig struct {
|
||||
}
|
||||
|
||||
func NewCloudProxy() *CloudProxy {
|
||||
// No Client-level Timeout — that would bound streaming SSE
|
||||
// responses too, which can legitimately last minutes. Per-request
|
||||
// deadlines come from the gRPC stream context.
|
||||
return &CloudProxy{client: &http.Client{}}
|
||||
// httpclient.New refuses redirects outright: the proxy talks to a
|
||||
// single configured upstream API (OpenAI/Anthropic/...) that answers
|
||||
// directly, so a 3xx means misconfiguration, a hijacked upstream, or
|
||||
// DNS trickery — never normal operation. Following it would replay the
|
||||
// request, including the operator's x-api-key (which Go does NOT strip
|
||||
// on cross-host redirects), to an unvetted host and leak the key
|
||||
// (GHSA-3mj3-57v2-4636). It also imposes no body deadline, so streaming
|
||||
// SSE responses that legitimately last minutes are not truncated.
|
||||
return &CloudProxy{client: httpclient.New()}
|
||||
}
|
||||
|
||||
func (c *CloudProxy) Load(opts *pb.ModelOptions) error {
|
||||
@@ -426,4 +433,3 @@ func isHopByHopHeader(name string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Release represents a LocalAI release
|
||||
@@ -67,9 +68,7 @@ func NewReleaseManager() *ReleaseManager {
|
||||
CurrentVersion: internal.PrintableVersion(),
|
||||
ChecksumsPath: checksumsPath,
|
||||
MetadataPath: metadataPath,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
HTTPClient: httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// RegistrationClient talks to the frontend's /api/node/* endpoints.
|
||||
@@ -37,7 +39,7 @@ func (c *RegistrationClient) httpTimeout() time.Duration {
|
||||
// httpClient returns the shared HTTP client, initializing it on first use.
|
||||
func (c *RegistrationClient) httpClient() *http.Client {
|
||||
c.clientOnce.Do(func() {
|
||||
c.client = &http.Client{Timeout: c.httpTimeout()}
|
||||
c.client = httpclient.NewWithTimeout(c.httpTimeout())
|
||||
})
|
||||
return c.client
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Define a struct to hold the store API client
|
||||
@@ -47,7 +49,7 @@ type FindResponse struct {
|
||||
func NewStoreClient(baseUrl string) *StoreClient {
|
||||
return &StoreClient{
|
||||
BaseURL: baseUrl,
|
||||
Client: &http.Client{},
|
||||
Client: httpclient.New(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,10 +9,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -55,7 +56,7 @@ var allowedFields = map[string]bool{
|
||||
func main() {
|
||||
fmt.Fprintf(os.Stderr, "Fetching %s ...\n", unslothURL)
|
||||
|
||||
resp, err := http.Get(unslothURL)
|
||||
resp, err := httpclient.New(httpclient.WithFollowRedirects()).Get(unslothURL)
|
||||
if err != nil {
|
||||
fatal("fetch failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
githubOAuth "golang.org/x/oauth2/github"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// providerEntry holds the OAuth2/OIDC config for a single provider.
|
||||
@@ -389,7 +391,7 @@ func fetchGitHubUserInfoAsOAuth(ctx context.Context, accessToken string) (*oauth
|
||||
}
|
||||
|
||||
func fetchGitHubUserInfo(ctx context.Context, accessToken string) (*githubUserInfo, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
client := httpclient.NewWithTimeout(10 * time.Second)
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
@@ -420,7 +422,7 @@ func fetchGitHubUserInfo(ctx context.Context, accessToken string) (*githubUserIn
|
||||
}
|
||||
|
||||
func fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, error) {
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
client := httpclient.NewWithTimeout(10 * time.Second)
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
@@ -458,7 +460,6 @@ func fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, e
|
||||
return "", fmt.Errorf("no verified email found")
|
||||
}
|
||||
|
||||
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -22,7 +24,9 @@ import (
|
||||
// decoding on the leading `data:` bytes.
|
||||
var audioDataURIPattern = regexp.MustCompile(`^data:[^,]+?;base64,`)
|
||||
|
||||
var audioDownloadClient = http.Client{Timeout: 30 * time.Second}
|
||||
// Downloading user-supplied media URLs legitimately follows redirects (CDNs);
|
||||
// WithFollowRedirects still strips any credential header on a cross-host hop.
|
||||
var audioDownloadClient = httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects())
|
||||
|
||||
// decodeAudioInput materialises a URL / data-URI / raw-base64 audio
|
||||
// payload to a temporary file and returns its path plus a cleanup
|
||||
|
||||
@@ -11,9 +11,11 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
)
|
||||
|
||||
// CORSProxyEndpoint proxies HTTP requests to external MCP servers,
|
||||
@@ -77,7 +79,7 @@ func CORSProxyEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
)
|
||||
},
|
||||
}
|
||||
client := &http.Client{Transport: transport, Timeout: 10 * time.Minute}
|
||||
client := httpclient.New(httpclient.WithTransport(transport), httpclient.WithTimeout(10*time.Minute))
|
||||
|
||||
xlog.Debug("CORS proxy request", "method", c.Request().Method, "target", targetURL)
|
||||
|
||||
|
||||
@@ -16,14 +16,16 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/xlog"
|
||||
"gorm.io/gorm"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// nodeError builds a schema.ErrorResponse for node endpoints.
|
||||
@@ -65,15 +67,15 @@ func GetNodeEndpoint(registry *nodes.NodeRegistry) echo.HandlerFunc {
|
||||
|
||||
// RegisterNodeRequest is the request body for registering a new worker node.
|
||||
type RegisterNodeRequest struct {
|
||||
Name string `json:"name"`
|
||||
NodeType string `json:"node_type,omitempty"` // "backend" (default) or "agent"
|
||||
Address string `json:"address"`
|
||||
HTTPAddress string `json:"http_address,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TotalVRAM uint64 `json:"total_vram,omitempty"`
|
||||
AvailableVRAM uint64 `json:"available_vram,omitempty"`
|
||||
TotalRAM uint64 `json:"total_ram,omitempty"`
|
||||
AvailableRAM uint64 `json:"available_ram,omitempty"`
|
||||
Name string `json:"name"`
|
||||
NodeType string `json:"node_type,omitempty"` // "backend" (default) or "agent"
|
||||
Address string `json:"address"`
|
||||
HTTPAddress string `json:"http_address,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
TotalVRAM uint64 `json:"total_vram,omitempty"`
|
||||
AvailableVRAM uint64 `json:"available_vram,omitempty"`
|
||||
TotalRAM uint64 `json:"total_ram,omitempty"`
|
||||
AvailableRAM uint64 `json:"available_ram,omitempty"`
|
||||
GPUVendor string `json:"gpu_vendor,omitempty"`
|
||||
Labels map[string]string `json:"labels,omitempty"`
|
||||
// MaxReplicasPerModel is the per-node cap on replicas of any single model.
|
||||
@@ -983,6 +985,6 @@ func proxyHTTPToWorker(httpAddress, path, token string) (*http.Response, error)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: 15 * time.Second}
|
||||
client := httpclient.NewWithTimeout(15 * time.Second)
|
||||
return client.Do(req)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -15,18 +14,23 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
var videoDownloadClient = http.Client{Timeout: 30 * time.Second}
|
||||
// Downloading user-supplied media URLs legitimately follows redirects (CDNs);
|
||||
// WithFollowRedirects still strips any credential header on a cross-host hop.
|
||||
var videoDownloadClient = httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects())
|
||||
|
||||
func downloadFile(url string) (string, error) {
|
||||
if err := utils.ValidateExternalURL(url); err != nil {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/services/messaging"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/signals"
|
||||
|
||||
"github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
@@ -180,10 +181,10 @@ func SessionsFromMCPConfig(
|
||||
for _, server := range remote.Servers {
|
||||
xlog.Debug("[MCP remote server] Configuration", "server", server)
|
||||
// Create HTTP client with custom roundtripper for bearer token injection
|
||||
httpClient := &http.Client{
|
||||
Timeout: config.DefaultMCPToolTimeout,
|
||||
Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
|
||||
}
|
||||
httpClient := httpclient.New(
|
||||
httpclient.WithTimeout(config.DefaultMCPToolTimeout),
|
||||
httpclient.WithTransport(newBearerTokenRoundTripper(server.Token, httpclient.HardenedTransport())),
|
||||
)
|
||||
|
||||
transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient}
|
||||
mcpSession, err := client.Connect(ctx, transport, nil)
|
||||
@@ -262,10 +263,10 @@ func NamedSessionsFromMCPConfig(
|
||||
|
||||
for serverName, server := range remote.Servers {
|
||||
xlog.Debug("[MCP remote server] Configuration", "name", serverName, "server", server)
|
||||
httpClient := &http.Client{
|
||||
Timeout: config.DefaultMCPToolTimeout,
|
||||
Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
|
||||
}
|
||||
httpClient := httpclient.New(
|
||||
httpclient.WithTimeout(config.DefaultMCPToolTimeout),
|
||||
httpclient.WithTransport(newBearerTokenRoundTripper(server.Token, httpclient.HardenedTransport())),
|
||||
)
|
||||
|
||||
transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient}
|
||||
mcpSession, err := client.Connect(ctx, transport, nil)
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -16,15 +15,18 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/labstack/echo/v4"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
func downloadFile(url string) (string, error) {
|
||||
@@ -33,7 +35,7 @@ func downloadFile(url string) (string, error) {
|
||||
}
|
||||
|
||||
// Get the data
|
||||
resp, err := http.Get(url)
|
||||
resp, err := httpclient.New(httpclient.WithFollowRedirects()).Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -21,17 +21,19 @@ import (
|
||||
|
||||
"github.com/Masterminds/sprig/v3"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/cogito"
|
||||
"github.com/mudler/cogito/clients"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/jobs"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/xsync"
|
||||
"github.com/mudler/cogito"
|
||||
"github.com/mudler/cogito/clients"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
// AgentJobService manages agent tasks and job execution
|
||||
@@ -647,7 +649,7 @@ func (s *AgentJobService) fetchMultimediaFromURL(url string, headers map[string]
|
||||
}
|
||||
|
||||
// Execute request
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpclient.NewWithTimeout(30 * time.Second)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch URL: %w", err)
|
||||
@@ -1249,7 +1251,7 @@ func (s *AgentJobService) sendWebhook(job schema.Job, task schema.Task, webhookC
|
||||
}
|
||||
|
||||
// Execute with retry
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
client := httpclient.NewWithTimeout(30 * time.Second)
|
||||
err = s.executeWithRetry(client, req)
|
||||
if err != nil {
|
||||
xlog.Error("Webhook delivery failed", "error", err, "job_id", job.ID, "webhook_url", webhookConfig.URL)
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
|
||||
"github.com/mudler/cogito"
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// KBSearchResult represents a search result from the knowledge base.
|
||||
@@ -61,7 +63,7 @@ func KBAutoSearchPrompt(ctx context.Context, apiURL, apiKey, collection, query s
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := httpclient.New().Do(req)
|
||||
if err != nil {
|
||||
xlog.Warn("KB auto-search: request failed", "error", err)
|
||||
return ""
|
||||
@@ -181,7 +183,7 @@ func KBStoreContent(ctx context.Context, apiURL, apiKey, collection, content, us
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := httpclient.New().Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("upload request failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -12,12 +12,14 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piiadapter"
|
||||
"github.com/mudler/xlog"
|
||||
"golang.org/x/net/http2"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// PIIHandlerOptions configures NewPIIHandler.
|
||||
@@ -87,7 +89,14 @@ func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
|
||||
}
|
||||
|
||||
d := &piiDispatcher{
|
||||
client: &http.Client{Transport: transport},
|
||||
// Refuse redirects: the MITM client forwards to the real
|
||||
// upstream over TLS, and a 3xx means the upstream (or something
|
||||
// impersonating it) is trying to bounce the request elsewhere.
|
||||
// Following it would replay caller headers — including provider
|
||||
// API keys such as Anthropic's x-api-key, which Go does NOT
|
||||
// strip on cross-host redirects — to an unvetted host. Surface
|
||||
// it as an error (handled as a 502) instead.
|
||||
client: httpclient.New(httpclient.WithTransport(transport)),
|
||||
redactor: opts.Redactor,
|
||||
store: opts.EventStore,
|
||||
patternAction: patternAction,
|
||||
|
||||
@@ -123,6 +123,25 @@ var _ = Describe("PIIHandler", func() {
|
||||
Expect(store.recorded()).NotTo(BeZero(), "no PIIEvent recorded for the email match")
|
||||
})
|
||||
|
||||
It("refuses to follow an upstream redirect", func() {
|
||||
// A 3xx from the upstream would otherwise be followed, replaying
|
||||
// the request (and its provider API key, e.g. Anthropic's
|
||||
// x-api-key which Go does NOT strip on cross-host redirects) to
|
||||
// the Location host. The refused redirect surfaces as a 502.
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "https://evil.example.com/steal", http.StatusFound)
|
||||
})
|
||||
|
||||
client, base, _, cleanup := startPIITestRig(upstream)
|
||||
defer cleanup()
|
||||
|
||||
body := `{"model":"claude-3-5-sonnet","max_tokens":100,"messages":[{"role":"user","content":"hello"}]}`
|
||||
resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body))
|
||||
Expect(err).NotTo(HaveOccurred(), "client.Post")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusBadGateway), "refused redirect must surface as 502, not be followed")
|
||||
})
|
||||
|
||||
It("blocks api key in request", func() {
|
||||
upstreamCalled := false
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -16,9 +16,11 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/storage"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/xlog"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// HTTPFileStager implements FileStager using HTTP for environments without S3.
|
||||
@@ -67,14 +69,12 @@ func NewHTTPFileStager(httpAddrFor func(nodeID string) (string, error), token st
|
||||
return &HTTPFileStager{
|
||||
httpAddrFor: httpAddrFor,
|
||||
token: token,
|
||||
client: &http.Client{
|
||||
// No Timeout set — for large uploads, http.Client.Timeout covers the
|
||||
// entire request lifecycle including the body upload. If it fires
|
||||
// mid-write, Go closes the connection causing "connection reset by peer"
|
||||
// on the server. Instead we use ResponseHeaderTimeout on the transport
|
||||
// to cover only the wait-for-server-response phase.
|
||||
Transport: transport,
|
||||
},
|
||||
// No Timeout set — for large uploads, http.Client.Timeout covers the
|
||||
// entire request lifecycle including the body upload. If it fires
|
||||
// mid-write, Go closes the connection causing "connection reset by peer"
|
||||
// on the server. Instead we use ResponseHeaderTimeout on the transport
|
||||
// to cover only the wait-for-server-response phase.
|
||||
client: httpclient.New(httpclient.WithTransport(transport)),
|
||||
responseTimeout: responseTimeout,
|
||||
maxRetries: maxRetries,
|
||||
}
|
||||
|
||||
@@ -5,8 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
type HuggingFaceScanResult struct {
|
||||
@@ -29,7 +30,7 @@ func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) {
|
||||
if len(cleanParts) <= 4 || (cleanParts[2] != "huggingface.co" && cleanParts[2] != hfHost) {
|
||||
return nil, ErrNonHuggingFaceFile
|
||||
}
|
||||
results, err := http.Get(fmt.Sprintf("%s/api/models/%s/%s/scan", HF_ENDPOINT, cleanParts[3], cleanParts[4]))
|
||||
results, err := httpclient.New(httpclient.WithFollowRedirects()).Get(fmt.Sprintf("%s/api/models/%s/%s/scan", HF_ENDPOINT, cleanParts[3], cleanParts[4]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -17,10 +17,12 @@ import (
|
||||
"github.com/google/go-containerregistry/pkg/v1/tarball"
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
"github.com/mudler/LocalAI/pkg/oci"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
"github.com/mudler/LocalAI/pkg/xio"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -171,7 +173,7 @@ func (uri URI) ReadWithAuthorizationAndCallback(ctx context.Context, basePath st
|
||||
req.Header.Add("Authorization", authorization)
|
||||
}
|
||||
|
||||
response, err := http.DefaultClient.Do(req)
|
||||
response, err := downloadClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -347,9 +349,15 @@ func calculateHashForPartialFile(file *os.File) (hash.Hash, error) {
|
||||
return hash, nil
|
||||
}
|
||||
|
||||
// downloadClient is the shared client for HTTP(S) downloads and size
|
||||
// probes. It follows redirects (model hosts and CDNs rely on them) but
|
||||
// strips credential headers on any cross-host hop, and sets no body
|
||||
// deadline so large downloads are not truncated.
|
||||
var downloadClient = httpclient.New(httpclient.WithFollowRedirects())
|
||||
|
||||
func (uri URI) checkSeverSupportsRangeHeader() (bool, error) {
|
||||
url := uri.ResolveURL()
|
||||
resp, err := http.Head(url)
|
||||
resp, err := downloadClient.Head(url)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -376,7 +384,7 @@ func (u URI) ContentLength(ctx context.Context) (int64, error) {
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := downloadClient.Do(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -395,7 +403,7 @@ func (u URI) ContentLength(ctx context.Context) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
req2.Header.Set("Range", "bytes=0-0")
|
||||
resp2, err := http.DefaultClient.Do(req2)
|
||||
resp2, err := downloadClient.Do(req2)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -584,7 +592,7 @@ func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string
|
||||
contentLength = l.Size()
|
||||
} else {
|
||||
// Start the request
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := downloadClient.Do(req)
|
||||
if err != nil {
|
||||
// Check if error is due to context cancellation
|
||||
if errors.Is(err, context.Canceled) {
|
||||
|
||||
201
pkg/httpclient/client.go
Normal file
201
pkg/httpclient/client.go
Normal file
@@ -0,0 +1,201 @@
|
||||
// Package httpclient provides hardened *http.Client constructors for all
|
||||
// outbound HTTP traffic in LocalAI.
|
||||
//
|
||||
// Direct use of net/http's default client (http.DefaultClient, http.Get,
|
||||
// http.Post, ...) or a bare http.Client{} is forbidden by lint (forbidigo).
|
||||
// The reason is GHSA-3mj3-57v2-4636: the standard client follows up to 10
|
||||
// redirects by default, and on a *cross-host* redirect Go forwards custom
|
||||
// request headers — including credential headers such as Anthropic's
|
||||
// x-api-key — to the redirect target. (Go strips Authorization, Cookie and
|
||||
// WWW-Authenticate cross-host, but NOT arbitrary custom headers.) An attacker
|
||||
// who can elicit a redirect from an upstream then harvests the credential.
|
||||
//
|
||||
// Every client built here refuses redirects by default (see NoRedirect). The
|
||||
// rare caller that genuinely must follow redirects should opt in with
|
||||
// WithFollowRedirects, which still strips credential headers on host change.
|
||||
//
|
||||
// Streaming note: New() intentionally sets NO client-level Timeout, because a
|
||||
// global timeout also bounds the response body and would truncate long-lived
|
||||
// SSE streams (chat completions can stream for minutes). Per-request deadlines
|
||||
// belong on the request context. Use NewWithTimeout for simple, non-streaming
|
||||
// request/response calls.
|
||||
package httpclient
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// Transport-level bounds. These cap connection setup, NOT the response
|
||||
// body, so they are safe for streaming responses.
|
||||
dialTimeout = 30 * time.Second
|
||||
dialKeepAlive = 30 * time.Second
|
||||
tlsHandshakeTimeout = 10 * time.Second
|
||||
idleConnTimeout = 90 * time.Second
|
||||
expectContinueTimeout = 1 * time.Second
|
||||
maxIdleConns = 100
|
||||
|
||||
// maxRedirects bounds WithFollowRedirects chains (mirrors the net/http
|
||||
// default) so an opt-in follower can't be spun forever by a redirect loop.
|
||||
maxRedirects = 10
|
||||
)
|
||||
|
||||
// sensitiveHeaders are credential-bearing request headers that must never be
|
||||
// replayed to a different host on a redirect. Go already drops the first three
|
||||
// cross-host; the rest are custom headers Go does not know about. Compared
|
||||
// case-insensitively via http.Header canonicalisation.
|
||||
var sensitiveHeaders = []string{
|
||||
"Authorization",
|
||||
"Www-Authenticate",
|
||||
"Cookie",
|
||||
"Proxy-Authorization",
|
||||
"X-Api-Key", // Anthropic, and many OpenAI-compatible providers
|
||||
"Api-Key", // Azure OpenAI
|
||||
"X-Auth-Token", // common custom scheme
|
||||
"X-Goog-Api-Key", // Google
|
||||
}
|
||||
|
||||
// ErrRedirectBlocked is wrapped by the error NoRedirect returns, so callers can
|
||||
// distinguish "the upstream tried to redirect us" from other transport errors
|
||||
// via errors.Is.
|
||||
var ErrRedirectBlocked = errors.New("httpclient: redirect blocked")
|
||||
|
||||
// NoRedirect is an http.Client.CheckRedirect policy that refuses to follow any
|
||||
// redirect, surfacing it as an error instead. This is the default for clients
|
||||
// built by New/NewWithTimeout. The error uses URL.Redacted() so userinfo in
|
||||
// the target URL is not written to logs.
|
||||
func NoRedirect(req *http.Request, _ []*http.Request) error {
|
||||
return fmt.Errorf("%w: refusing to follow redirect to %s (set httpclient.WithFollowRedirects to opt in)", ErrRedirectBlocked, req.URL.Redacted())
|
||||
}
|
||||
|
||||
// stripAuthOnRedirect follows redirects but deletes credential headers whenever
|
||||
// the redirect crosses to a different host, closing the cross-host credential
|
||||
// leak while still allowing same-host or non-authenticated redirect chains.
|
||||
func stripAuthOnRedirect(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= maxRedirects {
|
||||
return fmt.Errorf("httpclient: stopped after %d redirects", maxRedirects)
|
||||
}
|
||||
prev := via[len(via)-1]
|
||||
if !sameOrigin(prev.URL, req.URL) {
|
||||
for _, h := range sensitiveHeaders {
|
||||
req.Header.Del(h)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sameOrigin reports whether two URLs share scheme AND host (including port).
|
||||
// Deliberately strict: a different port or scheme is treated as a different
|
||||
// origin so credential headers are stripped. This avoids the curl
|
||||
// CVE-2022-27774 class of bug where ports were ignored and credentials leaked
|
||||
// to a different service on the same hostname.
|
||||
func sameOrigin(a, b *url.URL) bool {
|
||||
return strings.EqualFold(a.Scheme, b.Scheme) && strings.EqualFold(a.Host, b.Host)
|
||||
}
|
||||
|
||||
// HardenedTransport returns a fresh *http.Transport with a TLS 1.2 floor and
|
||||
// bounded connection setup. Callers that need to wrap or extend the transport
|
||||
// (e.g. a credential-injecting RoundTripper) should base it on this rather than
|
||||
// http.DefaultTransport so the TLS floor and timeouts are preserved.
|
||||
func HardenedTransport() *http.Transport {
|
||||
return &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: dialTimeout,
|
||||
KeepAlive: dialKeepAlive,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
IdleConnTimeout: idleConnTimeout,
|
||||
TLSHandshakeTimeout: tlsHandshakeTimeout,
|
||||
ExpectContinueTimeout: expectContinueTimeout,
|
||||
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||
}
|
||||
}
|
||||
|
||||
type options struct {
|
||||
timeout time.Duration
|
||||
transport http.RoundTripper
|
||||
followRedirects bool
|
||||
}
|
||||
|
||||
// Option configures a client built by New.
|
||||
type Option func(*options)
|
||||
|
||||
// WithTimeout sets an overall client Timeout (covers the entire exchange
|
||||
// including reading the body). Do NOT use this for streaming endpoints; prefer
|
||||
// a per-request context deadline there. Equivalent to NewWithTimeout.
|
||||
func WithTimeout(d time.Duration) Option { return func(o *options) { o.timeout = d } }
|
||||
|
||||
// WithTransport supplies a custom RoundTripper (e.g. an IP-pinned dialer or a
|
||||
// credential-injecting wrapper). The caller is responsible for the transport's
|
||||
// TLS configuration; base it on HardenedTransport to keep the TLS floor.
|
||||
func WithTransport(rt http.RoundTripper) Option { return func(o *options) { o.transport = rt } }
|
||||
|
||||
// WithFollowRedirects opts into following redirects, while still stripping
|
||||
// credential headers on any cross-host hop. Use only when an endpoint legitimately
|
||||
// redirects (e.g. some download CDNs) and the request carries a secret.
|
||||
func WithFollowRedirects() Option { return func(o *options) { o.followRedirects = true } }
|
||||
|
||||
// New returns a hardened *http.Client. By default it refuses redirects, sets a
|
||||
// TLS 1.2 floor, bounds connection setup, and imposes no body deadline (safe
|
||||
// for streaming). Apply Options to adjust.
|
||||
func New(opts ...Option) *http.Client {
|
||||
o := options{}
|
||||
for _, fn := range opts {
|
||||
fn(&o)
|
||||
}
|
||||
|
||||
rt := o.transport
|
||||
if rt == nil {
|
||||
rt = HardenedTransport()
|
||||
}
|
||||
|
||||
check := NoRedirect
|
||||
if o.followRedirects {
|
||||
check = stripAuthOnRedirect
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: rt,
|
||||
Timeout: o.timeout, // zero == no overall deadline (streaming-safe)
|
||||
CheckRedirect: check,
|
||||
}
|
||||
}
|
||||
|
||||
// NewWithTimeout returns a hardened client with an overall Timeout. Use for
|
||||
// simple request/response calls; for streaming, use New with a context deadline.
|
||||
func NewWithTimeout(timeout time.Duration, opts ...Option) *http.Client {
|
||||
return New(append([]Option{WithTimeout(timeout)}, opts...)...)
|
||||
}
|
||||
|
||||
// Harden applies the default hardening (refuse redirects, TLS 1.2 floor) to an
|
||||
// existing client in place, for the cases where a third-party library hands us
|
||||
// a *http.Client to configure rather than letting us construct one. It returns
|
||||
// the same client for convenience. A nil client is left nil.
|
||||
func Harden(c *http.Client) *http.Client {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if c.CheckRedirect == nil {
|
||||
c.CheckRedirect = NoRedirect
|
||||
}
|
||||
switch t := c.Transport.(type) {
|
||||
case nil:
|
||||
c.Transport = HardenedTransport()
|
||||
case *http.Transport:
|
||||
if t.TLSClientConfig == nil {
|
||||
t.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
} else if t.TLSClientConfig.MinVersion == 0 {
|
||||
t.TLSClientConfig.MinVersion = tls.VersionTLS12
|
||||
}
|
||||
}
|
||||
return c
|
||||
}
|
||||
132
pkg/httpclient/client_test.go
Normal file
132
pkg/httpclient/client_test.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package httpclient_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestHTTPClient(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "httpclient suite")
|
||||
}
|
||||
|
||||
var _ = Describe("httpclient", func() {
|
||||
Describe("New (default)", func() {
|
||||
It("refuses to follow redirects and never reaches the redirect target", func() {
|
||||
sinkHit := make(chan string, 1)
|
||||
sink := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sinkHit <- r.Header.Get("X-Api-Key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer sink.Close()
|
||||
|
||||
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, sink.URL, http.StatusFound)
|
||||
}))
|
||||
defer redirector.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, redirector.URL, nil)
|
||||
req.Header.Set("X-Api-Key", "secret")
|
||||
|
||||
_, err := httpclient.New().Do(req)
|
||||
Expect(err).To(HaveOccurred(), "redirect must surface as an error")
|
||||
Expect(errors.Is(err, httpclient.ErrRedirectBlocked)).To(BeTrue(), "error should wrap ErrRedirectBlocked")
|
||||
Expect(sinkHit).NotTo(Receive(), "the redirect target must never be contacted")
|
||||
})
|
||||
|
||||
It("sets no overall timeout (streaming-safe) by default", func() {
|
||||
Expect(httpclient.New().Timeout).To(BeZero())
|
||||
})
|
||||
|
||||
It("sets a TLS 1.2 floor on the default transport", func() {
|
||||
c := httpclient.New()
|
||||
t, ok := c.Transport.(*http.Transport)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(t.TLSClientConfig).NotTo(BeNil())
|
||||
Expect(t.TLSClientConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("NewWithTimeout", func() {
|
||||
It("applies the overall timeout", func() {
|
||||
Expect(httpclient.NewWithTimeout(5 * time.Second).Timeout).To(Equal(5 * time.Second))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("WithFollowRedirects", func() {
|
||||
It("follows same-host redirects keeping the credential header", func() {
|
||||
got := make(chan string, 2)
|
||||
var srv *httptest.Server
|
||||
srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/start" {
|
||||
http.Redirect(w, r, srv.URL+"/end", http.StatusFound)
|
||||
return
|
||||
}
|
||||
got <- r.Header.Get("X-Api-Key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, srv.URL+"/start", nil)
|
||||
req.Header.Set("X-Api-Key", "secret")
|
||||
|
||||
resp, err := httpclient.New(httpclient.WithFollowRedirects()).Do(req)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_ = resp.Body.Close()
|
||||
Expect(<-got).To(Equal("secret"), "same-host redirect should preserve the header")
|
||||
})
|
||||
|
||||
It("strips credential headers on a cross-host redirect", func() {
|
||||
sinkKey := make(chan string, 1)
|
||||
sink := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sinkKey <- r.Header.Get("X-Api-Key")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer sink.Close()
|
||||
|
||||
redirector := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, sink.URL, http.StatusFound)
|
||||
}))
|
||||
defer redirector.Close()
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, redirector.URL, nil)
|
||||
req.Header.Set("X-Api-Key", "secret")
|
||||
|
||||
resp, err := httpclient.New(httpclient.WithFollowRedirects()).Do(req)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
_ = resp.Body.Close()
|
||||
Expect(<-sinkKey).To(BeEmpty(), "x-api-key must be stripped crossing to a different host")
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Harden", func() {
|
||||
It("adds NoRedirect and a TLS floor to a bare client without clobbering existing config", func() {
|
||||
c := httpclient.Harden(&http.Client{})
|
||||
Expect(c.CheckRedirect).NotTo(BeNil())
|
||||
t, ok := c.Transport.(*http.Transport)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(t.TLSClientConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
|
||||
})
|
||||
|
||||
It("returns nil for a nil client", func() {
|
||||
Expect(httpclient.Harden(nil)).To(BeNil())
|
||||
})
|
||||
|
||||
It("preserves a caller-supplied CheckRedirect", func() {
|
||||
sentinel := errors.New("mine")
|
||||
c := httpclient.Harden(&http.Client{
|
||||
CheckRedirect: func(*http.Request, []*http.Request) error { return sentinel },
|
||||
})
|
||||
Expect(c.CheckRedirect(nil, nil)).To(Equal(sentinel))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Model represents a model from the Hugging Face API
|
||||
@@ -94,7 +96,7 @@ type Client struct {
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
baseURL: "https://huggingface.co/api/models",
|
||||
client: &http.Client{},
|
||||
client: httpclient.New(httpclient.WithFollowRedirects()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/modeladmin"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
localaitools "github.com/mudler/LocalAI/pkg/mcp/localaitools"
|
||||
"github.com/mudler/LocalAI/pkg/vram"
|
||||
)
|
||||
@@ -36,11 +37,9 @@ type Client struct {
|
||||
// New returns a Client targeting baseURL with an optional bearer token.
|
||||
func New(baseURL, apiKey string) *Client {
|
||||
return &Client{
|
||||
BaseURL: strings.TrimRight(baseURL, "/"),
|
||||
APIKey: apiKey,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
BaseURL: strings.TrimRight(baseURL, "/"),
|
||||
APIKey: apiKey,
|
||||
HTTPClient: httpclient.NewWithTimeout(60 * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,8 +393,8 @@ func (c *Client) UpgradeBackend(ctx context.Context, name string) (string, error
|
||||
|
||||
func (c *Client) SystemInfo(ctx context.Context) (*localaitools.SystemInfo, error) {
|
||||
var welcome struct {
|
||||
Version string `json:"Version"`
|
||||
LoadedModels []any `json:"LoadedModels"`
|
||||
Version string `json:"Version"`
|
||||
LoadedModels []any `json:"LoadedModels"`
|
||||
InstalledBackends map[string]bool `json:"InstalledBackends"`
|
||||
}
|
||||
if err := c.do(ctx, http.MethodGet, routeWelcome, nil, &welcome); err != nil {
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"net/http"
|
||||
|
||||
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
// Define the main struct for the JSON data
|
||||
@@ -45,7 +47,7 @@ func OllamaModelManifest(image string) (*Manifest, error) {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
|
||||
client := &http.Client{}
|
||||
client := httpclient.New(httpclient.WithFollowRedirects())
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -4,17 +4,16 @@ import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
|
||||
var base64DownloadClient http.Client = http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
var base64DownloadClient = httpclient.NewWithTimeout(30*time.Second, httpclient.WithFollowRedirects())
|
||||
|
||||
// Match `data:<mime>[;param=value...];base64,` — browser-produced data URIs
|
||||
// often carry codec/charset params between the mime type and `;base64,`
|
||||
|
||||
Reference in New Issue
Block a user