diff --git a/.github/gallery-agent/main.go b/.github/gallery-agent/main.go index d9c6449f7..57431a672 100644 --- a/.github/gallery-agent/main.go +++ b/.github/gallery-agent/main.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "os" "strconv" @@ -113,6 +114,17 @@ func main() { fmt.Println("Searching for trending models on HuggingFace...") rawModels, err := client.GetTrending(searchTerm, limit) if err != nil { + if errors.Is(err, hfapi.ErrRateLimited) { + fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err) + writeSummary(AddedModelSummary{ + SearchTerm: searchTerm, + TotalFound: 0, + ModelsAdded: 0, + Quantization: quantization, + ProcessingTime: time.Since(startTime).String(), + }) + return + } fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err) os.Exit(1) } @@ -277,4 +289,3 @@ func truncateString(s string, maxLen int) string { } return s[:maxLen] + "..." } - diff --git a/pkg/huggingface-api/client.go b/pkg/huggingface-api/client.go index 840dcf90c..175c38574 100644 --- a/pkg/huggingface-api/client.go +++ b/pkg/huggingface-api/client.go @@ -2,6 +2,7 @@ package hfapi import ( "encoding/json" + "errors" "fmt" "io" "net/http" @@ -10,6 +11,7 @@ import ( "sort" "strconv" "strings" + "time" "github.com/mudler/LocalAI/pkg/httpclient" ) @@ -88,57 +90,128 @@ type SearchParams struct { // Client represents a Hugging Face API client type Client struct { - baseURL string - client *http.Client + baseURL string + client *http.Client + maxRetries int + retryBackoff time.Duration + maxBackoff time.Duration + sleepFn func(time.Duration) } +var ErrRateLimited = errors.New("huggingface API rate limited") + // NewClient creates a new Hugging Face API client func NewClient() *Client { return &Client{ - baseURL: "https://huggingface.co/api/models", - client: httpclient.New(httpclient.WithFollowRedirects()), + baseURL: "https://huggingface.co/api/models", + client: httpclient.New(httpclient.WithFollowRedirects()), + maxRetries: 5, + retryBackoff: 1 * time.Second, + maxBackoff: 30 * time.Second, + sleepFn: time.Sleep, } } // SearchModels searches for models using the Hugging Face API func (c *Client) SearchModels(params SearchParams) ([]Model, error) { - req, err := http.NewRequest("GET", c.baseURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) + for attempt := 1; attempt <= c.maxRetries; attempt++ { + req, err := http.NewRequest("GET", c.baseURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Add query parameters + q := req.URL.Query() + q.Add("sort", params.Sort) + q.Add("direction", fmt.Sprintf("%d", params.Direction)) + q.Add("limit", fmt.Sprintf("%d", params.Limit)) + q.Add("search", params.Search) + req.URL.RawQuery = q.Encode() + + resp, err := c.client.Do(req) + if err != nil { + if attempt < c.maxRetries { + c.sleepFn(c.exponentialBackoff(attempt)) + continue + } + return nil, fmt.Errorf("failed to make request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + if err := resp.Body.Close(); err != nil { + return nil, fmt.Errorf("failed to close response body: %w", err) + } + if c.isRetryableStatus(resp.StatusCode) && attempt < c.maxRetries { + c.sleepFn(c.retryDelay(resp, attempt)) + continue + } + if resp.StatusCode == http.StatusTooManyRequests { + return nil, fmt.Errorf("%w: failed to fetch models. Status code: %d", ErrRateLimited, resp.StatusCode) + } + return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode) + } + + // Read the response body + body, err := io.ReadAll(resp.Body) + closeErr := resp.Body.Close() + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + if closeErr != nil { + return nil, fmt.Errorf("failed to close response body: %w", closeErr) + } + + // Parse the JSON response + var models []Model + if err := json.Unmarshal(body, &models); err != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", err) + } + + return models, nil } - // Add query parameters - q := req.URL.Query() - q.Add("sort", params.Sort) - q.Add("direction", fmt.Sprintf("%d", params.Direction)) - q.Add("limit", fmt.Sprintf("%d", params.Limit)) - q.Add("search", params.Search) - req.URL.RawQuery = q.Encode() + return nil, fmt.Errorf("%w: failed to fetch models. Status code: %d", ErrRateLimited, http.StatusTooManyRequests) +} - // Make the HTTP request - resp, err := c.client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to make request: %w", err) - } - defer resp.Body.Close() +func (c *Client) isRetryableStatus(code int) bool { + return code == http.StatusTooManyRequests || (code >= http.StatusInternalServerError && code <= http.StatusNetworkAuthenticationRequired) +} - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode) +func (c *Client) retryDelay(resp *http.Response, attempt int) time.Duration { + if retryAfter := strings.TrimSpace(resp.Header.Get("Retry-After")); retryAfter != "" { + if seconds, err := strconv.Atoi(retryAfter); err == nil && seconds > 0 { + delay := time.Duration(seconds) * time.Second + if delay > c.maxBackoff { + return c.maxBackoff + } + return delay + } + if at, err := http.ParseTime(retryAfter); err == nil { + delay := time.Until(at) + if delay > 0 { + if delay > c.maxBackoff { + return c.maxBackoff + } + return delay + } + } } - // Read the response body - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) - } + return c.exponentialBackoff(attempt) +} - // Parse the JSON response - var models []Model - if err := json.Unmarshal(body, &models); err != nil { - return nil, fmt.Errorf("failed to parse JSON response: %w", err) +func (c *Client) exponentialBackoff(attempt int) time.Duration { + delay := c.retryBackoff + for i := 1; i < attempt; i++ { + delay *= 2 + if delay >= c.maxBackoff { + return c.maxBackoff + } } - - return models, nil + if delay > c.maxBackoff { + return c.maxBackoff + } + return delay } // GetLatest fetches the latest GGUF models diff --git a/pkg/huggingface-api/client_test.go b/pkg/huggingface-api/client_test.go index b90df571f..feac4dba3 100644 --- a/pkg/huggingface-api/client_test.go +++ b/pkg/huggingface-api/client_test.go @@ -1,10 +1,12 @@ package hfapi_test import ( + "errors" "fmt" "net/http" "net/http/httptest" "strings" + "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -185,6 +187,87 @@ var _ = Describe("HuggingFace API Client", func() { Expect(err.Error()).To(ContainSubstring("failed to parse JSON response")) Expect(models).To(BeNil()) }) + + It("should retry 429 responses and honor Retry-After", func() { + attempts := 0 + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("[]")) + Expect(err).ToNot(HaveOccurred()) + })) + client.SetBaseURL(server.URL) + + params := hfapi.SearchParams{ + Sort: "lastModified", + Direction: -1, + Limit: 30, + Search: "GGUF", + } + + start := time.Now() + models, err := client.SearchModels(params) + elapsed := time.Since(start) + + Expect(err).ToNot(HaveOccurred()) + Expect(models).To(HaveLen(0)) + Expect(attempts).To(Equal(2)) + Expect(elapsed).To(BeNumerically(">=", 900*time.Millisecond)) + }) + + It("should fail fast on non-retryable 4xx responses", func() { + attempts := 0 + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + w.WriteHeader(http.StatusBadRequest) + })) + client.SetBaseURL(server.URL) + + params := hfapi.SearchParams{ + Sort: "lastModified", + Direction: -1, + Limit: 30, + Search: "GGUF", + } + + start := time.Now() + models, err := client.SearchModels(params) + elapsed := time.Since(start) + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("Status code: 400")) + Expect(models).To(BeNil()) + Expect(attempts).To(Equal(1)) + Expect(elapsed).To(BeNumerically("<", 500*time.Millisecond)) + }) + + It("should return ErrRateLimited when 429 persists after retries", func() { + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + })) + client.SetBaseURL(server.URL) + + params := hfapi.SearchParams{ + Sort: "trendingScore", + Direction: -1, + Limit: 15, + Search: "GGUF", + } + + models, err := client.SearchModels(params) + + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, hfapi.ErrRateLimited)).To(BeTrue()) + Expect(err.Error()).To(ContainSubstring("Status code: 429")) + Expect(models).To(BeNil()) + }) }) Context("when getting latest GGUF models", func() {