mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-05 23:36:49 -04:00
Harden gallery-agent Hugging Face fetches against transient rate limiting (#10187)
* Initial plan
* fix: retry HuggingFace trending fetch on transient rate limits
* fix: handle body close/write errors in huggingface retry paths

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
This commit is contained in:
13
.github/gallery-agent/main.go
vendored
13
.github/gallery-agent/main.go
vendored
@@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -113,6 +114,17 @@ func main() {
|
||||
fmt.Println("Searching for trending models on HuggingFace...")
|
||||
rawModels, err := client.GetTrending(searchTerm, limit)
|
||||
if err != nil {
|
||||
if errors.Is(err, hfapi.ErrRateLimited) {
|
||||
fmt.Printf("HuggingFace API is rate limited after retries, skipping this run: %v\n", err)
|
||||
writeSummary(AddedModelSummary{
|
||||
SearchTerm: searchTerm,
|
||||
TotalFound: 0,
|
||||
ModelsAdded: 0,
|
||||
Quantization: quantization,
|
||||
ProcessingTime: time.Since(startTime).String(),
|
||||
})
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Error fetching models: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
@@ -277,4 +289,3 @@ func truncateString(s string, maxLen int) string {
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package hfapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
)
|
||||
@@ -88,57 +90,128 @@ type SearchParams struct {
|
||||
|
||||
// Client represents a Hugging Face API client that retries transient failures.
type Client struct {
	// baseURL is the endpoint every search request is sent to.
	baseURL string
	// client performs the HTTP round trips.
	client *http.Client
	// maxRetries bounds the attempt loop in SearchModels; it counts the first
	// attempt as well as the retries.
	maxRetries int
	// retryBackoff is the delay used before the first retry; exponentialBackoff
	// doubles it for each further attempt.
	retryBackoff time.Duration
	// maxBackoff caps every delay, whether computed or taken from Retry-After.
	maxBackoff time.Duration
	// sleepFn blocks between attempts; NewClient sets it to time.Sleep.
	sleepFn func(time.Duration)
}
|
||||
|
||||
// ErrRateLimited is the sentinel wrapped into the error SearchModels returns
// when the API answers 429 Too Many Requests; match it with errors.Is.
var ErrRateLimited = errors.New("huggingface API rate limited")
|
||||
|
||||
// NewClient creates a new Hugging Face API client
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
baseURL: "https://huggingface.co/api/models",
|
||||
client: httpclient.New(httpclient.WithFollowRedirects()),
|
||||
baseURL: "https://huggingface.co/api/models",
|
||||
client: httpclient.New(httpclient.WithFollowRedirects()),
|
||||
maxRetries: 5,
|
||||
retryBackoff: 1 * time.Second,
|
||||
maxBackoff: 30 * time.Second,
|
||||
sleepFn: time.Sleep,
|
||||
}
|
||||
}
|
||||
|
||||
// SearchModels searches for models using the Hugging Face API
|
||||
func (c *Client) SearchModels(params SearchParams) ([]Model, error) {
|
||||
req, err := http.NewRequest("GET", c.baseURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
for attempt := 1; attempt <= c.maxRetries; attempt++ {
|
||||
req, err := http.NewRequest("GET", c.baseURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Add query parameters
|
||||
q := req.URL.Query()
|
||||
q.Add("sort", params.Sort)
|
||||
q.Add("direction", fmt.Sprintf("%d", params.Direction))
|
||||
q.Add("limit", fmt.Sprintf("%d", params.Limit))
|
||||
q.Add("search", params.Search)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
if attempt < c.maxRetries {
|
||||
c.sleepFn(c.exponentialBackoff(attempt))
|
||||
continue
|
||||
}
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
return nil, fmt.Errorf("failed to close response body: %w", err)
|
||||
}
|
||||
if c.isRetryableStatus(resp.StatusCode) && attempt < c.maxRetries {
|
||||
c.sleepFn(c.retryDelay(resp, attempt))
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
return nil, fmt.Errorf("%w: failed to fetch models. Status code: %d", ErrRateLimited, resp.StatusCode)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read the response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
closeErr := resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
if closeErr != nil {
|
||||
return nil, fmt.Errorf("failed to close response body: %w", closeErr)
|
||||
}
|
||||
|
||||
// Parse the JSON response
|
||||
var models []Model
|
||||
if err := json.Unmarshal(body, &models); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// Add query parameters
|
||||
q := req.URL.Query()
|
||||
q.Add("sort", params.Sort)
|
||||
q.Add("direction", fmt.Sprintf("%d", params.Direction))
|
||||
q.Add("limit", fmt.Sprintf("%d", params.Limit))
|
||||
q.Add("search", params.Search)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
return nil, fmt.Errorf("%w: failed to fetch models. Status code: %d", ErrRateLimited, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
// Make the HTTP request
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
func (c *Client) isRetryableStatus(code int) bool {
|
||||
return code == http.StatusTooManyRequests || (code >= http.StatusInternalServerError && code <= http.StatusNetworkAuthenticationRequired)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode)
|
||||
func (c *Client) retryDelay(resp *http.Response, attempt int) time.Duration {
|
||||
if retryAfter := strings.TrimSpace(resp.Header.Get("Retry-After")); retryAfter != "" {
|
||||
if seconds, err := strconv.Atoi(retryAfter); err == nil && seconds > 0 {
|
||||
delay := time.Duration(seconds) * time.Second
|
||||
if delay > c.maxBackoff {
|
||||
return c.maxBackoff
|
||||
}
|
||||
return delay
|
||||
}
|
||||
if at, err := http.ParseTime(retryAfter); err == nil {
|
||||
delay := time.Until(at)
|
||||
if delay > 0 {
|
||||
if delay > c.maxBackoff {
|
||||
return c.maxBackoff
|
||||
}
|
||||
return delay
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read the response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
return c.exponentialBackoff(attempt)
|
||||
}
|
||||
|
||||
// Parse the JSON response
|
||||
var models []Model
|
||||
if err := json.Unmarshal(body, &models); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
func (c *Client) exponentialBackoff(attempt int) time.Duration {
|
||||
delay := c.retryBackoff
|
||||
for i := 1; i < attempt; i++ {
|
||||
delay *= 2
|
||||
if delay >= c.maxBackoff {
|
||||
return c.maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
return models, nil
|
||||
if delay > c.maxBackoff {
|
||||
return c.maxBackoff
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
// GetLatest fetches the latest GGUF models
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package hfapi_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -185,6 +187,87 @@ var _ = Describe("HuggingFace API Client", func() {
|
||||
Expect(err.Error()).To(ContainSubstring("failed to parse JSON response"))
|
||||
Expect(models).To(BeNil())
|
||||
})
|
||||
|
||||
It("should retry 429 responses and honor Retry-After", func() {
|
||||
attempts := 0
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
if attempts == 1 {
|
||||
w.Header().Set("Retry-After", "1")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write([]byte("[]"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}))
|
||||
client.SetBaseURL(server.URL)
|
||||
|
||||
params := hfapi.SearchParams{
|
||||
Sort: "lastModified",
|
||||
Direction: -1,
|
||||
Limit: 30,
|
||||
Search: "GGUF",
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
models, err := client.SearchModels(params)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(models).To(HaveLen(0))
|
||||
Expect(attempts).To(Equal(2))
|
||||
Expect(elapsed).To(BeNumerically(">=", 900*time.Millisecond))
|
||||
})
|
||||
|
||||
It("should fail fast on non-retryable 4xx responses", func() {
|
||||
attempts := 0
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
attempts++
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
client.SetBaseURL(server.URL)
|
||||
|
||||
params := hfapi.SearchParams{
|
||||
Sort: "lastModified",
|
||||
Direction: -1,
|
||||
Limit: 30,
|
||||
Search: "GGUF",
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
models, err := client.SearchModels(params)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("Status code: 400"))
|
||||
Expect(models).To(BeNil())
|
||||
Expect(attempts).To(Equal(1))
|
||||
Expect(elapsed).To(BeNumerically("<", 500*time.Millisecond))
|
||||
})
|
||||
|
||||
It("should return ErrRateLimited when 429 persists after retries", func() {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Retry-After", "1")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
}))
|
||||
client.SetBaseURL(server.URL)
|
||||
|
||||
params := hfapi.SearchParams{
|
||||
Sort: "trendingScore",
|
||||
Direction: -1,
|
||||
Limit: 15,
|
||||
Search: "GGUF",
|
||||
}
|
||||
|
||||
models, err := client.SearchModels(params)
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(errors.Is(err, hfapi.ErrRateLimited)).To(BeTrue())
|
||||
Expect(err.Error()).To(ContainSubstring("Status code: 429"))
|
||||
Expect(models).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when getting latest GGUF models", func() {
|
||||
|
||||
Reference in New Issue
Block a user