feat: import models via URI (#7245)

* feat: initial hook to install elements directly

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* WIP: ui changes

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Move HF api client to pkg

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add simple importer for gguf files

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add opcache

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* wire importers to CLI

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add omitempty to config fields

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fix tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add MLX importer

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Small refactors to star to use HF for discovery

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add tests

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Common preferences

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Add support to bare HF repos

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(importer/llama.cpp): add support for mmproj files

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* add mmproj quants to common preferences

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Fix vlm usage in tokenizer mode with llama.cpp

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-11-12 20:48:56 +01:00
committed by GitHub
parent 87d0020c10
commit 3728552e94
40 changed files with 1970 additions and 694 deletions

View File

@@ -13,99 +13,102 @@ import (
"github.com/rs/zerolog/log"
)
// @Description GrammarConfig contains configuration for grammar parsing
type GrammarConfig struct {
// ParallelCalls enables the LLM to return multiple function calls in the same response
ParallelCalls bool `yaml:"parallel_calls"`
ParallelCalls bool `yaml:"parallel_calls,omitempty" json:"parallel_calls,omitempty"`
DisableParallelNewLines bool `yaml:"disable_parallel_new_lines"`
DisableParallelNewLines bool `yaml:"disable_parallel_new_lines,omitempty" json:"disable_parallel_new_lines,omitempty"`
// MixedMode enables the LLM to return strings and not only JSON objects
// This is useful for models to not constraining returning only JSON and also messages back to the user
MixedMode bool `yaml:"mixed_mode"`
MixedMode bool `yaml:"mixed_mode,omitempty" json:"mixed_mode,omitempty"`
// NoMixedFreeString disables the mixed mode for free strings
// In this way if the LLM selects a free string, it won't be mixed necessarily with JSON objects.
// For example, if enabled the LLM or returns a JSON object or a free string, but not a mix of both
// If disabled(default): the LLM can return a JSON object surrounded by free strings (e.g. `this is the JSON result: { "bar": "baz" } for your question`). This forces the LLM to return at least a JSON object, but its not going to be strict
NoMixedFreeString bool `yaml:"no_mixed_free_string"`
NoMixedFreeString bool `yaml:"no_mixed_free_string,omitempty" json:"no_mixed_free_string,omitempty"`
// NoGrammar disables the grammar parsing and parses the responses directly from the LLM
NoGrammar bool `yaml:"disable"`
NoGrammar bool `yaml:"disable,omitempty" json:"disable,omitempty"`
// Prefix is the suffix to append to the grammar when being generated
// This is useful when models prepend a tag before returning JSON
Prefix string `yaml:"prefix"`
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
// ExpectStringsAfterJSON enables mixed string suffix
ExpectStringsAfterJSON bool `yaml:"expect_strings_after_json"`
ExpectStringsAfterJSON bool `yaml:"expect_strings_after_json,omitempty" json:"expect_strings_after_json,omitempty"`
// PropOrder selects what order to print properties
// for instance name,arguments will make print { "name": "foo", "arguments": { "bar": "baz" } }
// instead of { "arguments": { "bar": "baz" }, "name": "foo" }
PropOrder string `yaml:"properties_order"`
PropOrder string `yaml:"properties_order,omitempty" json:"properties_order,omitempty"`
// SchemaType can be configured to use a specific schema type to force the grammar
// available : json, llama3.1
SchemaType string `yaml:"schema_type"`
SchemaType string `yaml:"schema_type,omitempty" json:"schema_type,omitempty"`
GrammarTriggers []GrammarTrigger `yaml:"triggers"`
GrammarTriggers []GrammarTrigger `yaml:"triggers,omitempty" json:"triggers,omitempty"`
}
// @Description GrammarTrigger defines a trigger word for grammar parsing
type GrammarTrigger struct {
// Trigger is the string that triggers the grammar
Word string `yaml:"word"`
Word string `yaml:"word,omitempty" json:"word,omitempty"`
}
// FunctionsConfig is the configuration for the tool/function call.
// @Description FunctionsConfig is the configuration for the tool/function call.
// It includes setting to map the function name and arguments from the response
// and, for instance, also if processing the requests with BNF grammars.
type FunctionsConfig struct {
// DisableNoAction disables the "no action" tool
// By default we inject a tool that does nothing and is used to return an answer from the LLM
DisableNoAction bool `yaml:"disable_no_action"`
DisableNoAction bool `yaml:"disable_no_action,omitempty" json:"disable_no_action,omitempty"`
// Grammar is the configuration for the grammar
GrammarConfig GrammarConfig `yaml:"grammar"`
GrammarConfig GrammarConfig `yaml:"grammar,omitempty" json:"grammar,omitempty"`
// NoActionFunctionName is the name of the function that does nothing. It defaults to "answer"
NoActionFunctionName string `yaml:"no_action_function_name"`
NoActionFunctionName string `yaml:"no_action_function_name,omitempty" json:"no_action_function_name,omitempty"`
// NoActionDescriptionName is the name of the function that returns the description of the no action function
NoActionDescriptionName string `yaml:"no_action_description_name"`
NoActionDescriptionName string `yaml:"no_action_description_name,omitempty" json:"no_action_description_name,omitempty"`
// ResponseRegex is a named regex to extract the function name and arguments from the response
ResponseRegex []string `yaml:"response_regex"`
ResponseRegex []string `yaml:"response_regex,omitempty" json:"response_regex,omitempty"`
// JSONRegexMatch is a regex to extract the JSON object from the response
JSONRegexMatch []string `yaml:"json_regex_match"`
JSONRegexMatch []string `yaml:"json_regex_match,omitempty" json:"json_regex_match,omitempty"`
// ArgumentRegex is a named regex to extract the arguments from the response. Use ArgumentRegexKey and ArgumentRegexValue to set the names of the named regex for key and value of the arguments.
ArgumentRegex []string `yaml:"argument_regex"`
ArgumentRegex []string `yaml:"argument_regex,omitempty" json:"argument_regex,omitempty"`
// ArgumentRegex named regex names for key and value extractions. default: key and value
ArgumentRegexKey string `yaml:"argument_regex_key_name"` // default: key
ArgumentRegexValue string `yaml:"argument_regex_value_name"` // default: value
ArgumentRegexKey string `yaml:"argument_regex_key_name,omitempty" json:"argument_regex_key_name,omitempty"` // default: key
ArgumentRegexValue string `yaml:"argument_regex_value_name,omitempty" json:"argument_regex_value_name,omitempty"` // default: value
// ReplaceFunctionResults allow to replace strings in the results before parsing them
ReplaceFunctionResults []ReplaceResult `yaml:"replace_function_results"`
ReplaceFunctionResults []ReplaceResult `yaml:"replace_function_results,omitempty" json:"replace_function_results,omitempty"`
// ReplaceLLMResult allow to replace strings in the results before parsing them
ReplaceLLMResult []ReplaceResult `yaml:"replace_llm_results"`
ReplaceLLMResult []ReplaceResult `yaml:"replace_llm_results,omitempty" json:"replace_llm_results,omitempty"`
// CaptureLLMResult is a regex to extract a string from the LLM response
// that is used as return string when using tools.
// This is useful for e.g. if the LLM outputs a reasoning and we want to get the reasoning as a string back
CaptureLLMResult []string `yaml:"capture_llm_results"`
CaptureLLMResult []string `yaml:"capture_llm_results,omitempty" json:"capture_llm_results,omitempty"`
// FunctionName enable the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }
// instead of { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }.
// This might be useful for certain models trained with the function name as the first token.
FunctionNameKey string `yaml:"function_name_key"`
FunctionArgumentsKey string `yaml:"function_arguments_key"`
FunctionNameKey string `yaml:"function_name_key,omitempty" json:"function_name_key,omitempty"`
FunctionArgumentsKey string `yaml:"function_arguments_key,omitempty" json:"function_arguments_key,omitempty"`
}
// @Description ReplaceResult defines a key-value replacement for function results
type ReplaceResult struct {
Key string `yaml:"key"`
Value string `yaml:"value"`
Key string `yaml:"key,omitempty" json:"key,omitempty"`
Value string `yaml:"value,omitempty" json:"value,omitempty"`
}
type FuncCallResults struct {

View File

@@ -0,0 +1,306 @@
package hfapi
import (
"encoding/json"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
)
// Model represents a model from the Hugging Face API
type Model struct {
ModelID string `json:"modelId"`
Author string `json:"author"`
Downloads int `json:"downloads"`
LastModified string `json:"lastModified"`
PipelineTag string `json:"pipelineTag"`
Private bool `json:"private"`
Tags []string `json:"tags"`
CreatedAt string `json:"createdAt"`
UpdatedAt string `json:"updatedAt"`
Sha string `json:"sha"`
Config map[string]interface{} `json:"config"`
ModelIndex string `json:"model_index"`
LibraryName string `json:"library_name"`
MaskToken string `json:"mask_token"`
TokenizerClass string `json:"tokenizer_class"`
}
// FileInfo represents file information from HuggingFace
type FileInfo struct {
Type string `json:"type"`
Oid string `json:"oid"`
Size int64 `json:"size"`
Path string `json:"path"`
LFS *LFSInfo `json:"lfs,omitempty"`
XetHash string `json:"xetHash,omitempty"`
}
// LFSInfo represents LFS (Large File Storage) information
type LFSInfo struct {
Oid string `json:"oid"`
Size int64 `json:"size"`
PointerSize int `json:"pointerSize"`
}
// ModelFile represents a file in a model repository
type ModelFile struct {
Path string
Size int64
SHA256 string
IsReadme bool
URL string
}
// ModelDetails represents detailed information about a model
type ModelDetails struct {
ModelID string
Author string
Files []ModelFile
ReadmeFile *ModelFile
ReadmeContent string
}
// SearchParams represents the parameters for searching models
type SearchParams struct {
Sort string `json:"sort"`
Direction int `json:"direction"`
Limit int `json:"limit"`
Search string `json:"search"`
}
// Client represents a Hugging Face API client
type Client struct {
baseURL string
client *http.Client
}
// NewClient creates a new Hugging Face API client
func NewClient() *Client {
return &Client{
baseURL: "https://huggingface.co/api/models",
client: &http.Client{},
}
}
// SearchModels searches for models using the Hugging Face API
func (c *Client) SearchModels(params SearchParams) ([]Model, error) {
req, err := http.NewRequest("GET", c.baseURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
// Add query parameters
q := req.URL.Query()
q.Add("sort", params.Sort)
q.Add("direction", fmt.Sprintf("%d", params.Direction))
q.Add("limit", fmt.Sprintf("%d", params.Limit))
q.Add("search", params.Search)
req.URL.RawQuery = q.Encode()
// Make the HTTP request
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode)
}
// Read the response body
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
// Parse the JSON response
var models []Model
if err := json.Unmarshal(body, &models); err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return models, nil
}
// GetLatest fetches the latest GGUF models
func (c *Client) GetLatest(searchTerm string, limit int) ([]Model, error) {
params := SearchParams{
Sort: "lastModified",
Direction: -1,
Limit: limit,
Search: searchTerm,
}
return c.SearchModels(params)
}
// BaseURL returns the current base URL
func (c *Client) BaseURL() string {
return c.baseURL
}
// SetBaseURL sets a new base URL (useful for testing)
func (c *Client) SetBaseURL(url string) {
c.baseURL = url
}
// ListFiles lists all files in a HuggingFace repository
func (c *Client) ListFiles(repoID string) ([]FileInfo, error) {
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
url := fmt.Sprintf("%s/api/models/%s/tree/main", baseURL, repoID)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
resp, err := c.client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("failed to fetch files. Status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response body: %w", err)
}
var files []FileInfo
if err := json.Unmarshal(body, &files); err != nil {
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
}
return files, nil
}
// GetFileSHA gets the SHA256 checksum for a specific file by searching through the file list
func (c *Client) GetFileSHA(repoID, fileName string) (string, error) {
files, err := c.ListFiles(repoID)
if err != nil {
return "", fmt.Errorf("failed to list files: %w", err)
}
for _, file := range files {
if filepath.Base(file.Path) == fileName {
if file.LFS != nil && file.LFS.Oid != "" {
// The LFS OID contains the SHA256 hash
return file.LFS.Oid, nil
}
// If no LFS, return the regular OID
return file.Oid, nil
}
}
return "", fmt.Errorf("file %s not found", fileName)
}
// GetModelDetails gets detailed information about a model including files and checksums
func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) {
files, err := c.ListFiles(repoID)
if err != nil {
return nil, fmt.Errorf("failed to list files: %w", err)
}
details := &ModelDetails{
ModelID: repoID,
Author: strings.Split(repoID, "/")[0],
Files: make([]ModelFile, 0, len(files)),
}
// Process each file
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
for _, file := range files {
fileName := filepath.Base(file.Path)
isReadme := strings.Contains(strings.ToLower(fileName), "readme")
// Extract SHA256 from LFS or use OID
sha256 := ""
if file.LFS != nil && file.LFS.Oid != "" {
sha256 = file.LFS.Oid
} else {
sha256 = file.Oid
}
// Construct the full URL for the file
// Use /resolve/main/ for downloading files (handles LFS properly)
fileURL := fmt.Sprintf("%s/%s/resolve/main/%s", baseURL, repoID, file.Path)
modelFile := ModelFile{
Path: file.Path,
Size: file.Size,
SHA256: sha256,
IsReadme: isReadme,
URL: fileURL,
}
details.Files = append(details.Files, modelFile)
// Set the readme file
if isReadme && details.ReadmeFile == nil {
details.ReadmeFile = &modelFile
}
}
return details, nil
}
// GetReadmeContent gets the content of a README file
func (c *Client) GetReadmeContent(repoID, readmePath string) (string, error) {
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
url := fmt.Sprintf("%s/%s/raw/main/%s", baseURL, repoID, readmePath)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
resp, err := c.client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to make request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("failed to fetch readme content. Status code: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response body: %w", err)
}
return string(body), nil
}
// FilterFilesByQuantization filters files by quantization type
func FilterFilesByQuantization(files []ModelFile, quantization string) []ModelFile {
var filtered []ModelFile
for _, file := range files {
fileName := filepath.Base(file.Path)
if strings.Contains(strings.ToLower(fileName), strings.ToLower(quantization)) {
filtered = append(filtered, file)
}
}
return filtered
}
// FindPreferredModelFile finds the preferred model file based on quantization preferences
func FindPreferredModelFile(files []ModelFile, preferences []string) *ModelFile {
for _, preference := range preferences {
for i := range files {
fileName := filepath.Base(files[i].Path)
if strings.Contains(strings.ToLower(fileName), strings.ToLower(preference)) {
return &files[i]
}
}
}
return nil
}

View File

@@ -0,0 +1,541 @@
package hfapi_test
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)
var _ = Describe("HuggingFace API Client", func() {
var (
client *hfapi.Client
server *httptest.Server
)
BeforeEach(func() {
client = hfapi.NewClient()
})
AfterEach(func() {
if server != nil {
server.Close()
}
})
Context("when creating a new client", func() {
It("should initialize with correct base URL", func() {
Expect(client).ToNot(BeNil())
Expect(client.BaseURL()).To(Equal("https://huggingface.co/api/models"))
})
})
Context("when searching for models", func() {
BeforeEach(func() {
// Mock response data
mockResponse := `[
{
"modelId": "test-model-1",
"author": "test-author",
"downloads": 1000,
"lastModified": "2024-01-01T00:00:00.000Z",
"pipelineTag": "text-generation",
"private": false,
"tags": ["gguf", "llama"],
"createdAt": "2024-01-01T00:00:00.000Z",
"updatedAt": "2024-01-01T00:00:00.000Z",
"sha": "abc123",
"config": {},
"model_index": "test-index",
"library_name": "transformers",
"mask_token": null,
"tokenizer_class": "LlamaTokenizer"
},
{
"modelId": "test-model-2",
"author": "test-author-2",
"downloads": 2000,
"lastModified": "2024-01-02T00:00:00.000Z",
"pipelineTag": "text-generation",
"private": false,
"tags": ["gguf", "mistral"],
"createdAt": "2024-01-02T00:00:00.000Z",
"updatedAt": "2024-01-02T00:00:00.000Z",
"sha": "def456",
"config": {},
"model_index": "test-index-2",
"library_name": "transformers",
"mask_token": null,
"tokenizer_class": "MistralTokenizer"
}
]`
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify request parameters
Expect(r.URL.Query().Get("sort")).To(Equal("lastModified"))
Expect(r.URL.Query().Get("direction")).To(Equal("-1"))
Expect(r.URL.Query().Get("limit")).To(Equal("30"))
Expect(r.URL.Query().Get("search")).To(Equal("GGUF"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockResponse))
}))
// Override the client's base URL to use our mock server
client.SetBaseURL(server.URL)
})
It("should successfully search for models", func() {
params := hfapi.SearchParams{
Sort: "lastModified",
Direction: -1,
Limit: 30,
Search: "GGUF",
}
models, err := client.SearchModels(params)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(HaveLen(2))
// Verify first model
Expect(models[0].ModelID).To(Equal("test-model-1"))
Expect(models[0].Author).To(Equal("test-author"))
Expect(models[0].Downloads).To(Equal(1000))
Expect(models[0].PipelineTag).To(Equal("text-generation"))
Expect(models[0].Private).To(BeFalse())
Expect(models[0].Tags).To(ContainElements("gguf", "llama"))
// Verify second model
Expect(models[1].ModelID).To(Equal("test-model-2"))
Expect(models[1].Author).To(Equal("test-author-2"))
Expect(models[1].Downloads).To(Equal(2000))
Expect(models[1].Tags).To(ContainElements("gguf", "mistral"))
})
It("should handle empty search results", func() {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("[]"))
}))
client.SetBaseURL(server.URL)
params := hfapi.SearchParams{
Sort: "lastModified",
Direction: -1,
Limit: 30,
Search: "nonexistent",
}
models, err := client.SearchModels(params)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(HaveLen(0))
})
It("should handle HTTP errors", func() {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("Internal Server Error"))
}))
client.SetBaseURL(server.URL)
params := hfapi.SearchParams{
Sort: "lastModified",
Direction: -1,
Limit: 30,
Search: "GGUF",
}
models, err := client.SearchModels(params)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("Status code: 500"))
Expect(models).To(BeNil())
})
It("should handle malformed JSON response", func() {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("invalid json"))
}))
client.SetBaseURL(server.URL)
params := hfapi.SearchParams{
Sort: "lastModified",
Direction: -1,
Limit: 30,
Search: "GGUF",
}
models, err := client.SearchModels(params)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("failed to parse JSON response"))
Expect(models).To(BeNil())
})
})
Context("when getting latest GGUF models", func() {
BeforeEach(func() {
mockResponse := `[
{
"modelId": "latest-gguf-model",
"author": "gguf-author",
"downloads": 5000,
"lastModified": "2024-01-03T00:00:00.000Z",
"pipelineTag": "text-generation",
"private": false,
"tags": ["gguf", "latest"],
"createdAt": "2024-01-03T00:00:00.000Z",
"updatedAt": "2024-01-03T00:00:00.000Z",
"sha": "latest123",
"config": {},
"model_index": "latest-index",
"library_name": "transformers",
"mask_token": null,
"tokenizer_class": "LlamaTokenizer"
}
]`
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify the search parameters are correct for GGUF search
Expect(r.URL.Query().Get("search")).To(Equal("GGUF"))
Expect(r.URL.Query().Get("sort")).To(Equal("lastModified"))
Expect(r.URL.Query().Get("direction")).To(Equal("-1"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockResponse))
}))
client.SetBaseURL(server.URL)
})
It("should fetch latest GGUF models with correct parameters", func() {
models, err := client.GetLatest("GGUF", 10)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(HaveLen(1))
Expect(models[0].ModelID).To(Equal("latest-gguf-model"))
Expect(models[0].Author).To(Equal("gguf-author"))
Expect(models[0].Downloads).To(Equal(5000))
Expect(models[0].Tags).To(ContainElements("gguf", "latest"))
})
It("should use custom search term", func() {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expect(r.URL.Query().Get("search")).To(Equal("custom-search"))
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("[]"))
}))
client.SetBaseURL(server.URL)
models, err := client.GetLatest("custom-search", 5)
Expect(err).ToNot(HaveOccurred())
Expect(models).To(HaveLen(0))
})
})
Context("when handling network errors", func() {
It("should handle connection failures gracefully", func() {
// Use an invalid URL to simulate connection failure
client.SetBaseURL("http://invalid-url-that-does-not-exist")
params := hfapi.SearchParams{
Sort: "lastModified",
Direction: -1,
Limit: 30,
Search: "GGUF",
}
models, err := client.SearchModels(params)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("failed to make request"))
Expect(models).To(BeNil())
})
})
Context("when getting file SHA on remote model", func() {
It("should get file SHA successfully", func() {
sha, err := client.GetFileSHA(
"mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF", "localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf")
Expect(err).ToNot(HaveOccurred())
Expect(sha).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"))
})
})
Context("when listing files", func() {
BeforeEach(func() {
mockFilesResponse := `[
{
"type": "file",
"path": "model-Q4_K_M.gguf",
"size": 1000000,
"oid": "abc123",
"lfs": {
"oid": "def456789",
"size": 1000000,
"pointerSize": 135
}
},
{
"type": "file",
"path": "README.md",
"size": 5000,
"oid": "readme123"
},
{
"type": "file",
"path": "config.json",
"size": 1000,
"oid": "config123"
}
]`
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/tree/main") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockFilesResponse))
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
client.SetBaseURL(server.URL)
})
It("should list files successfully", func() {
files, err := client.ListFiles("test/model")
Expect(err).ToNot(HaveOccurred())
Expect(files).To(HaveLen(3))
Expect(files[0].Path).To(Equal("model-Q4_K_M.gguf"))
Expect(files[0].Size).To(Equal(int64(1000000)))
Expect(files[0].LFS).ToNot(BeNil())
Expect(files[0].LFS.Oid).To(Equal("def456789"))
Expect(files[1].Path).To(Equal("README.md"))
Expect(files[1].Size).To(Equal(int64(5000)))
})
})
Context("when getting file SHA", func() {
BeforeEach(func() {
mockFilesResponse := `[
{
"type": "file",
"path": "model-Q4_K_M.gguf",
"size": 1000000,
"oid": "abc123",
"lfs": {
"oid": "def456789",
"size": 1000000,
"pointerSize": 135
}
}
]`
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/tree/main") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockFilesResponse))
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
client.SetBaseURL(server.URL)
})
It("should get file SHA successfully", func() {
sha, err := client.GetFileSHA("test/model", "model-Q4_K_M.gguf")
Expect(err).ToNot(HaveOccurred())
Expect(sha).To(Equal("def456789"))
})
It("should handle missing SHA gracefully", func() {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/tree/main") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`[
{
"type": "file",
"path": "file.txt",
"size": 100,
"oid": "file123"
}
]`))
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
client.SetBaseURL(server.URL)
sha, err := client.GetFileSHA("test/model", "file.txt")
Expect(err).ToNot(HaveOccurred())
// When there's no LFS, it should return the OID
Expect(sha).To(Equal("file123"))
})
})
Context("when getting model details", func() {
BeforeEach(func() {
mockFilesResponse := `[
{
"path": "model-Q4_K_M.gguf",
"size": 1000000,
"oid": "abc123",
"lfs": {
"oid": "sha256:def456",
"size": 1000000,
"pointer": "version https://git-lfs.github.com/spec/v1",
"sha256": "def456789"
}
},
{
"path": "README.md",
"size": 5000,
"oid": "readme123"
}
]`
mockFileInfoResponse := `{
"path": "model-Q4_K_M.gguf",
"size": 1000000,
"oid": "abc123",
"lfs": {
"oid": "sha256:def456",
"size": 1000000,
"pointer": "version https://git-lfs.github.com/spec/v1",
"sha256": "def456789"
}
}`
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/tree/main") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockFilesResponse))
} else if strings.Contains(r.URL.Path, "/paths-info") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockFileInfoResponse))
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
client.SetBaseURL(server.URL)
})
It("should get model details successfully", func() {
details, err := client.GetModelDetails("test/model")
Expect(err).ToNot(HaveOccurred())
Expect(details.ModelID).To(Equal("test/model"))
Expect(details.Author).To(Equal("test"))
Expect(details.Files).To(HaveLen(2))
Expect(details.ReadmeFile).ToNot(BeNil())
Expect(details.ReadmeFile.Path).To(Equal("README.md"))
Expect(details.ReadmeFile.IsReadme).To(BeTrue())
// Verify URLs are set for all files
baseURL := strings.TrimSuffix(server.URL, "/api/models")
for _, file := range details.Files {
expectedURL := fmt.Sprintf("%s/test/model/resolve/main/%s", baseURL, file.Path)
Expect(file.URL).To(Equal(expectedURL))
}
})
})
Context("when getting README content", func() {
BeforeEach(func() {
mockReadmeContent := "# Test Model\n\nThis is a test model for demonstration purposes."
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/raw/main/") {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockReadmeContent))
} else {
w.WriteHeader(http.StatusNotFound)
}
}))
client.SetBaseURL(server.URL)
})
It("should get README content successfully", func() {
content, err := client.GetReadmeContent("test/model", "README.md")
Expect(err).ToNot(HaveOccurred())
Expect(content).To(Equal("# Test Model\n\nThis is a test model for demonstration purposes."))
})
})
Context("when filtering files", func() {
It("should filter files by quantization", func() {
files := []hfapi.ModelFile{
{Path: "model-Q4_K_M.gguf"},
{Path: "model-Q3_K_M.gguf"},
{Path: "README.md", IsReadme: true},
}
filtered := hfapi.FilterFilesByQuantization(files, "Q4_K_M")
Expect(filtered).To(HaveLen(1))
Expect(filtered[0].Path).To(Equal("model-Q4_K_M.gguf"))
})
It("should find preferred model file", func() {
files := []hfapi.ModelFile{
{Path: "model-Q3_K_M.gguf"},
{Path: "model-Q4_K_M.gguf"},
{Path: "README.md", IsReadme: true},
}
preferences := []string{"Q4_K_M", "Q3_K_M"}
preferred := hfapi.FindPreferredModelFile(files, preferences)
Expect(preferred).ToNot(BeNil())
Expect(preferred.Path).To(Equal("model-Q4_K_M.gguf"))
Expect(preferred.IsReadme).To(BeFalse())
})
It("should return nil if no preferred file found", func() {
files := []hfapi.ModelFile{
{Path: "model-Q2_K.gguf"},
{Path: "README.md", IsReadme: true},
}
preferences := []string{"Q4_K_M", "Q3_K_M"}
preferred := hfapi.FindPreferredModelFile(files, preferences)
Expect(preferred).To(BeNil())
})
})
})

View File

@@ -0,0 +1,15 @@
package hfapi_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestHfapi(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "HuggingFace API Suite")
}