diff --git a/pkg/downloader/huggingface.go b/pkg/downloader/huggingface.go index 9d7f1657f..1c3faa086 100644 --- a/pkg/downloader/huggingface.go +++ b/pkg/downloader/huggingface.go @@ -23,7 +23,10 @@ var ErrUnsafeFilesFound = errors.New("unsafe files found") func HuggingFaceScan(uri URI) (*HuggingFaceScanResult, error) { cleanParts := strings.Split(uri.ResolveURL(), "/") - if len(cleanParts) <= 4 || cleanParts[2] != "huggingface.co" && cleanParts[2] != HF_ENDPOINT { + // cleanParts[2] is the hostname from the URL (e.g. "huggingface.co" or "hf-mirror.com"). + // Extract the hostname from HF_ENDPOINT for comparison, since HF_ENDPOINT includes the scheme. + hfHost := strings.TrimPrefix(strings.TrimPrefix(HF_ENDPOINT, "https://"), "http://") + if len(cleanParts) <= 4 || (cleanParts[2] != "huggingface.co" && cleanParts[2] != hfHost) { return nil, ErrNonHuggingFaceFile } results, err := http.Get(fmt.Sprintf("%s/api/models/%s/%s/scan", HF_ENDPOINT, cleanParts[3], cleanParts[4])) diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go index 58764bb13..8c526399a 100644 --- a/pkg/downloader/uri.go +++ b/pkg/downloader/uri.go @@ -261,6 +261,13 @@ func (s URI) ResolveURL() string { return fmt.Sprintf("%s/%s/%s/resolve/%s/%s", HF_ENDPOINT, owner, repo, branch, filepath) } + // If a HuggingFace mirror is configured, rewrite direct https://huggingface.co/ URLs + // to use the mirror. This ensures gallery entries with hardcoded URLs also benefit + // from the mirror setting. + if HF_ENDPOINT != "https://huggingface.co" && strings.HasPrefix(string(s), "https://huggingface.co/") { + return HF_ENDPOINT + strings.TrimPrefix(string(s), "https://huggingface.co") + } + return string(s) } diff --git a/pkg/downloader/uri_test.go b/pkg/downloader/uri_test.go index bd895defc..3d2d3cbcb 100644 --- a/pkg/downloader/uri_test.go +++ b/pkg/downloader/uri_test.go @@ -49,6 +49,42 @@ var _ = Describe("Gallery API tests", func() { ).ToNot(HaveOccurred()) }) }) + + Context("HuggingFace mirror", func() { + var originalEndpoint string + + BeforeEach(func() { + originalEndpoint = HF_ENDPOINT + }) + + AfterEach(func() { + HF_ENDPOINT = originalEndpoint + }) + + It("rewrites direct https://huggingface.co URLs when mirror is set", func() { + HF_ENDPOINT = "https://hf-mirror.com" + uri := URI("https://huggingface.co/TheBloke/model-GGUF/resolve/main/model.Q4_K_M.gguf") + Expect(uri.ResolveURL()).To(Equal("https://hf-mirror.com/TheBloke/model-GGUF/resolve/main/model.Q4_K_M.gguf")) + }) + + It("does not rewrite direct https://huggingface.co URLs when no mirror is set", func() { + HF_ENDPOINT = "https://huggingface.co" + uri := URI("https://huggingface.co/TheBloke/model-GGUF/resolve/main/model.Q4_K_M.gguf") + Expect(uri.ResolveURL()).To(Equal("https://huggingface.co/TheBloke/model-GGUF/resolve/main/model.Q4_K_M.gguf")) + }) + + It("rewrites hf:// URIs when mirror is set", func() { + HF_ENDPOINT = "https://hf-mirror.com" + uri := URI("hf://TheBloke/model-GGUF/model.Q4_K_M.gguf") + Expect(uri.ResolveURL()).To(Equal("https://hf-mirror.com/TheBloke/model-GGUF/resolve/main/model.Q4_K_M.gguf")) + }) + + It("does not rewrite non-huggingface URLs", func() { + HF_ENDPOINT = "https://hf-mirror.com" + uri := URI("https://example.com/some/file.gguf") + Expect(uri.ResolveURL()).To(Equal("https://example.com/some/file.gguf")) + }) + }) }) var _ = Describe("ContentLength", func() {