mirror of
https://github.com/mudler/LocalAI.git
synced 2026-03-24 09:42:57 -04:00
feat: import models via URI (#7245)
Squashed commit messages (each signed off by Ettore Di Giacinto &lt;mudler@localai.io&gt;):
- feat: initial hook to install elements directly
- WIP: ui changes
- Move HF api client to pkg
- Add simple importer for gguf files
- Add opcache
- wire importers to CLI
- Add omitempty to config fields
- Fix tests
- Add MLX importer
- Small refactors to start to use HF for discovery
- Add tests
- Common preferences
- Add support to bare HF repos
- feat(importer/llama.cpp): add support for mmproj files
- add mmproj quants to common preferences
- Fix vlm usage in tokenizer mode with llama.cpp
This commit is contained in:
committed by
GitHub
parent
87d0020c10
commit
3728552e94
@@ -13,99 +13,102 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// @Description GrammarConfig contains configuration for grammar parsing
|
||||
type GrammarConfig struct {
|
||||
// ParallelCalls enables the LLM to return multiple function calls in the same response
|
||||
ParallelCalls bool `yaml:"parallel_calls"`
|
||||
ParallelCalls bool `yaml:"parallel_calls,omitempty" json:"parallel_calls,omitempty"`
|
||||
|
||||
DisableParallelNewLines bool `yaml:"disable_parallel_new_lines"`
|
||||
DisableParallelNewLines bool `yaml:"disable_parallel_new_lines,omitempty" json:"disable_parallel_new_lines,omitempty"`
|
||||
|
||||
// MixedMode enables the LLM to return strings and not only JSON objects
|
||||
// This is useful for models to not constraining returning only JSON and also messages back to the user
|
||||
MixedMode bool `yaml:"mixed_mode"`
|
||||
MixedMode bool `yaml:"mixed_mode,omitempty" json:"mixed_mode,omitempty"`
|
||||
|
||||
// NoMixedFreeString disables the mixed mode for free strings
|
||||
// In this way if the LLM selects a free string, it won't be mixed necessarily with JSON objects.
|
||||
// For example, if enabled the LLM or returns a JSON object or a free string, but not a mix of both
|
||||
// If disabled(default): the LLM can return a JSON object surrounded by free strings (e.g. `this is the JSON result: { "bar": "baz" } for your question`). This forces the LLM to return at least a JSON object, but its not going to be strict
|
||||
NoMixedFreeString bool `yaml:"no_mixed_free_string"`
|
||||
NoMixedFreeString bool `yaml:"no_mixed_free_string,omitempty" json:"no_mixed_free_string,omitempty"`
|
||||
|
||||
// NoGrammar disables the grammar parsing and parses the responses directly from the LLM
|
||||
NoGrammar bool `yaml:"disable"`
|
||||
NoGrammar bool `yaml:"disable,omitempty" json:"disable,omitempty"`
|
||||
|
||||
// Prefix is the suffix to append to the grammar when being generated
|
||||
// This is useful when models prepend a tag before returning JSON
|
||||
Prefix string `yaml:"prefix"`
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// ExpectStringsAfterJSON enables mixed string suffix
|
||||
ExpectStringsAfterJSON bool `yaml:"expect_strings_after_json"`
|
||||
ExpectStringsAfterJSON bool `yaml:"expect_strings_after_json,omitempty" json:"expect_strings_after_json,omitempty"`
|
||||
|
||||
// PropOrder selects what order to print properties
|
||||
// for instance name,arguments will make print { "name": "foo", "arguments": { "bar": "baz" } }
|
||||
// instead of { "arguments": { "bar": "baz" }, "name": "foo" }
|
||||
PropOrder string `yaml:"properties_order"`
|
||||
PropOrder string `yaml:"properties_order,omitempty" json:"properties_order,omitempty"`
|
||||
|
||||
// SchemaType can be configured to use a specific schema type to force the grammar
|
||||
// available : json, llama3.1
|
||||
SchemaType string `yaml:"schema_type"`
|
||||
SchemaType string `yaml:"schema_type,omitempty" json:"schema_type,omitempty"`
|
||||
|
||||
GrammarTriggers []GrammarTrigger `yaml:"triggers"`
|
||||
GrammarTriggers []GrammarTrigger `yaml:"triggers,omitempty" json:"triggers,omitempty"`
|
||||
}
|
||||
|
||||
// @Description GrammarTrigger defines a trigger word for grammar parsing
type GrammarTrigger struct {
	// Word is the string that triggers the grammar
	Word string `yaml:"word,omitempty" json:"word,omitempty"`
}
|
||||
|
||||
// FunctionsConfig is the configuration for the tool/function call.
|
||||
// @Description FunctionsConfig is the configuration for the tool/function call.
|
||||
// It includes setting to map the function name and arguments from the response
|
||||
// and, for instance, also if processing the requests with BNF grammars.
|
||||
type FunctionsConfig struct {
|
||||
// DisableNoAction disables the "no action" tool
|
||||
// By default we inject a tool that does nothing and is used to return an answer from the LLM
|
||||
DisableNoAction bool `yaml:"disable_no_action"`
|
||||
DisableNoAction bool `yaml:"disable_no_action,omitempty" json:"disable_no_action,omitempty"`
|
||||
|
||||
// Grammar is the configuration for the grammar
|
||||
GrammarConfig GrammarConfig `yaml:"grammar"`
|
||||
GrammarConfig GrammarConfig `yaml:"grammar,omitempty" json:"grammar,omitempty"`
|
||||
|
||||
// NoActionFunctionName is the name of the function that does nothing. It defaults to "answer"
|
||||
NoActionFunctionName string `yaml:"no_action_function_name"`
|
||||
NoActionFunctionName string `yaml:"no_action_function_name,omitempty" json:"no_action_function_name,omitempty"`
|
||||
|
||||
// NoActionDescriptionName is the name of the function that returns the description of the no action function
|
||||
NoActionDescriptionName string `yaml:"no_action_description_name"`
|
||||
NoActionDescriptionName string `yaml:"no_action_description_name,omitempty" json:"no_action_description_name,omitempty"`
|
||||
|
||||
// ResponseRegex is a named regex to extract the function name and arguments from the response
|
||||
ResponseRegex []string `yaml:"response_regex"`
|
||||
ResponseRegex []string `yaml:"response_regex,omitempty" json:"response_regex,omitempty"`
|
||||
|
||||
// JSONRegexMatch is a regex to extract the JSON object from the response
|
||||
JSONRegexMatch []string `yaml:"json_regex_match"`
|
||||
JSONRegexMatch []string `yaml:"json_regex_match,omitempty" json:"json_regex_match,omitempty"`
|
||||
|
||||
// ArgumentRegex is a named regex to extract the arguments from the response. Use ArgumentRegexKey and ArgumentRegexValue to set the names of the named regex for key and value of the arguments.
|
||||
ArgumentRegex []string `yaml:"argument_regex"`
|
||||
ArgumentRegex []string `yaml:"argument_regex,omitempty" json:"argument_regex,omitempty"`
|
||||
// ArgumentRegex named regex names for key and value extractions. default: key and value
|
||||
ArgumentRegexKey string `yaml:"argument_regex_key_name"` // default: key
|
||||
ArgumentRegexValue string `yaml:"argument_regex_value_name"` // default: value
|
||||
ArgumentRegexKey string `yaml:"argument_regex_key_name,omitempty" json:"argument_regex_key_name,omitempty"` // default: key
|
||||
ArgumentRegexValue string `yaml:"argument_regex_value_name,omitempty" json:"argument_regex_value_name,omitempty"` // default: value
|
||||
|
||||
// ReplaceFunctionResults allow to replace strings in the results before parsing them
|
||||
ReplaceFunctionResults []ReplaceResult `yaml:"replace_function_results"`
|
||||
ReplaceFunctionResults []ReplaceResult `yaml:"replace_function_results,omitempty" json:"replace_function_results,omitempty"`
|
||||
|
||||
// ReplaceLLMResult allow to replace strings in the results before parsing them
|
||||
ReplaceLLMResult []ReplaceResult `yaml:"replace_llm_results"`
|
||||
ReplaceLLMResult []ReplaceResult `yaml:"replace_llm_results,omitempty" json:"replace_llm_results,omitempty"`
|
||||
|
||||
// CaptureLLMResult is a regex to extract a string from the LLM response
|
||||
// that is used as return string when using tools.
|
||||
// This is useful for e.g. if the LLM outputs a reasoning and we want to get the reasoning as a string back
|
||||
CaptureLLMResult []string `yaml:"capture_llm_results"`
|
||||
CaptureLLMResult []string `yaml:"capture_llm_results,omitempty" json:"capture_llm_results,omitempty"`
|
||||
|
||||
// FunctionName enable the LLM to return { "name": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }
|
||||
// instead of { "function": "function_name", "arguments": { "arg1": "value1", "arg2": "value2" } }.
|
||||
// This might be useful for certain models trained with the function name as the first token.
|
||||
FunctionNameKey string `yaml:"function_name_key"`
|
||||
FunctionArgumentsKey string `yaml:"function_arguments_key"`
|
||||
FunctionNameKey string `yaml:"function_name_key,omitempty" json:"function_name_key,omitempty"`
|
||||
FunctionArgumentsKey string `yaml:"function_arguments_key,omitempty" json:"function_arguments_key,omitempty"`
|
||||
}
|
||||
|
||||
// @Description ReplaceResult defines a key-value replacement for function results
type ReplaceResult struct {
	// Key is the pattern to be replaced in the result text.
	Key string `yaml:"key,omitempty" json:"key,omitempty"`
	// Value is the replacement for occurrences of Key.
	Value string `yaml:"value,omitempty" json:"value,omitempty"`
}
|
||||
|
||||
type FuncCallResults struct {
|
||||
|
||||
306
pkg/huggingface-api/client.go
Normal file
306
pkg/huggingface-api/client.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package hfapi
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"path/filepath"
	"strconv"
	"strings"
	"time"
)
|
||||
|
||||
// Model represents a model entry as returned by the Hugging Face
// /api/models endpoint. Field tags match the API's JSON schema.
type Model struct {
	ModelID      string   `json:"modelId"`      // repository identifier, e.g. "org/name"
	Author       string   `json:"author"`
	Downloads    int      `json:"downloads"`
	LastModified string   `json:"lastModified"` // ISO-8601 timestamp string
	PipelineTag  string   `json:"pipelineTag"`
	Private      bool     `json:"private"`
	Tags         []string `json:"tags"`
	CreatedAt    string   `json:"createdAt"`
	UpdatedAt    string   `json:"updatedAt"`
	Sha          string   `json:"sha"`
	// Config holds the raw, schema-less model configuration object.
	Config         map[string]interface{} `json:"config"`
	ModelIndex     string                 `json:"model_index"`
	LibraryName    string                 `json:"library_name"`
	MaskToken      string                 `json:"mask_token"`
	TokenizerClass string                 `json:"tokenizer_class"`
}
|
||||
|
||||
// FileInfo represents file information from HuggingFace
|
||||
type FileInfo struct {
|
||||
Type string `json:"type"`
|
||||
Oid string `json:"oid"`
|
||||
Size int64 `json:"size"`
|
||||
Path string `json:"path"`
|
||||
LFS *LFSInfo `json:"lfs,omitempty"`
|
||||
XetHash string `json:"xetHash,omitempty"`
|
||||
}
|
||||
|
||||
// LFSInfo represents LFS (Large File Storage) information attached to a
// repository file entry.
type LFSInfo struct {
	// Oid is the LFS object id; for HF repositories it contains the
	// SHA256 of the file content.
	Oid         string `json:"oid"`
	Size        int64  `json:"size"`        // size of the actual content in bytes
	PointerSize int    `json:"pointerSize"` // size of the LFS pointer file
}
|
||||
|
||||
// ModelFile represents a file in a model repository, flattened from the
// raw tree API response into the fields this package cares about.
type ModelFile struct {
	Path     string // path relative to the repository root
	Size     int64  // size in bytes
	SHA256   string // content hash (LFS oid when available, plain oid otherwise)
	IsReadme bool   // true when the file name contains "readme" (case-insensitive)
	URL      string // full /resolve/main/ download URL
}
|
||||
|
||||
// ModelDetails represents detailed information about a model
|
||||
type ModelDetails struct {
|
||||
ModelID string
|
||||
Author string
|
||||
Files []ModelFile
|
||||
ReadmeFile *ModelFile
|
||||
ReadmeContent string
|
||||
}
|
||||
|
||||
// SearchParams represents the parameters for searching models via the
// Hugging Face /api/models endpoint.
type SearchParams struct {
	Sort      string `json:"sort"`      // field to sort by, e.g. "lastModified"
	Direction int    `json:"direction"` // -1 for descending, 1 for ascending
	Limit     int    `json:"limit"`     // maximum number of results
	Search    string `json:"search"`    // free-text search term
}
|
||||
|
||||
// Client represents a Hugging Face API client.
type Client struct {
	baseURL string       // search endpoint, default https://huggingface.co/api/models
	client  *http.Client // shared HTTP client, reused across requests
}

// NewClient creates a new Hugging Face API client pointing at the public
// huggingface.co models endpoint.
func NewClient() *Client {
	return &Client{
		baseURL: "https://huggingface.co/api/models",
		// A timeout prevents calls from hanging forever on an unresponsive
		// server; 30s is generous for the metadata endpoints used here.
		client: &http.Client{Timeout: 30 * time.Second},
	}
}
|
||||
|
||||
// SearchModels searches for models using the Hugging Face API
|
||||
func (c *Client) SearchModels(params SearchParams) ([]Model, error) {
|
||||
req, err := http.NewRequest("GET", c.baseURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
// Add query parameters
|
||||
q := req.URL.Query()
|
||||
q.Add("sort", params.Sort)
|
||||
q.Add("direction", fmt.Sprintf("%d", params.Direction))
|
||||
q.Add("limit", fmt.Sprintf("%d", params.Limit))
|
||||
q.Add("search", params.Search)
|
||||
req.URL.RawQuery = q.Encode()
|
||||
|
||||
// Make the HTTP request
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read the response body
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
// Parse the JSON response
|
||||
var models []Model
|
||||
if err := json.Unmarshal(body, &models); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// GetLatest fetches the latest GGUF models
|
||||
func (c *Client) GetLatest(searchTerm string, limit int) ([]Model, error) {
|
||||
params := SearchParams{
|
||||
Sort: "lastModified",
|
||||
Direction: -1,
|
||||
Limit: limit,
|
||||
Search: searchTerm,
|
||||
}
|
||||
|
||||
return c.SearchModels(params)
|
||||
}
|
||||
|
||||
// BaseURL returns the current base URL
|
||||
func (c *Client) BaseURL() string {
|
||||
return c.baseURL
|
||||
}
|
||||
|
||||
// SetBaseURL sets a new base URL (useful for testing)
|
||||
func (c *Client) SetBaseURL(url string) {
|
||||
c.baseURL = url
|
||||
}
|
||||
|
||||
// ListFiles lists all files in a HuggingFace repository
|
||||
func (c *Client) ListFiles(repoID string) ([]FileInfo, error) {
|
||||
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
|
||||
url := fmt.Sprintf("%s/api/models/%s/tree/main", baseURL, repoID)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("failed to fetch files. Status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
var files []FileInfo
|
||||
if err := json.Unmarshal(body, &files); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// GetFileSHA gets the SHA256 checksum for a specific file by searching through the file list
|
||||
func (c *Client) GetFileSHA(repoID, fileName string) (string, error) {
|
||||
files, err := c.ListFiles(repoID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to list files: %w", err)
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if filepath.Base(file.Path) == fileName {
|
||||
if file.LFS != nil && file.LFS.Oid != "" {
|
||||
// The LFS OID contains the SHA256 hash
|
||||
return file.LFS.Oid, nil
|
||||
}
|
||||
// If no LFS, return the regular OID
|
||||
return file.Oid, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("file %s not found", fileName)
|
||||
}
|
||||
|
||||
// GetModelDetails gets detailed information about a model including files and checksums
|
||||
func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) {
|
||||
files, err := c.ListFiles(repoID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list files: %w", err)
|
||||
}
|
||||
|
||||
details := &ModelDetails{
|
||||
ModelID: repoID,
|
||||
Author: strings.Split(repoID, "/")[0],
|
||||
Files: make([]ModelFile, 0, len(files)),
|
||||
}
|
||||
|
||||
// Process each file
|
||||
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
|
||||
for _, file := range files {
|
||||
fileName := filepath.Base(file.Path)
|
||||
isReadme := strings.Contains(strings.ToLower(fileName), "readme")
|
||||
|
||||
// Extract SHA256 from LFS or use OID
|
||||
sha256 := ""
|
||||
if file.LFS != nil && file.LFS.Oid != "" {
|
||||
sha256 = file.LFS.Oid
|
||||
} else {
|
||||
sha256 = file.Oid
|
||||
}
|
||||
|
||||
// Construct the full URL for the file
|
||||
// Use /resolve/main/ for downloading files (handles LFS properly)
|
||||
fileURL := fmt.Sprintf("%s/%s/resolve/main/%s", baseURL, repoID, file.Path)
|
||||
|
||||
modelFile := ModelFile{
|
||||
Path: file.Path,
|
||||
Size: file.Size,
|
||||
SHA256: sha256,
|
||||
IsReadme: isReadme,
|
||||
URL: fileURL,
|
||||
}
|
||||
|
||||
details.Files = append(details.Files, modelFile)
|
||||
|
||||
// Set the readme file
|
||||
if isReadme && details.ReadmeFile == nil {
|
||||
details.ReadmeFile = &modelFile
|
||||
}
|
||||
}
|
||||
|
||||
return details, nil
|
||||
}
|
||||
|
||||
// GetReadmeContent gets the content of a README file
|
||||
func (c *Client) GetReadmeContent(repoID, readmePath string) (string, error) {
|
||||
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
|
||||
url := fmt.Sprintf("%s/%s/raw/main/%s", baseURL, repoID, readmePath)
|
||||
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to make request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("failed to fetch readme content. Status code: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// FilterFilesByQuantization filters files by quantization type
|
||||
func FilterFilesByQuantization(files []ModelFile, quantization string) []ModelFile {
|
||||
var filtered []ModelFile
|
||||
for _, file := range files {
|
||||
fileName := filepath.Base(file.Path)
|
||||
if strings.Contains(strings.ToLower(fileName), strings.ToLower(quantization)) {
|
||||
filtered = append(filtered, file)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// FindPreferredModelFile finds the preferred model file based on quantization preferences
|
||||
func FindPreferredModelFile(files []ModelFile, preferences []string) *ModelFile {
|
||||
for _, preference := range preferences {
|
||||
for i := range files {
|
||||
fileName := filepath.Base(files[i].Path)
|
||||
if strings.Contains(strings.ToLower(fileName), strings.ToLower(preference)) {
|
||||
return &files[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
541
pkg/huggingface-api/client_test.go
Normal file
541
pkg/huggingface-api/client_test.go
Normal file
@@ -0,0 +1,541 @@
|
||||
package hfapi_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
)
|
||||
|
||||
var _ = Describe("HuggingFace API Client", func() {
|
||||
var (
|
||||
client *hfapi.Client
|
||||
server *httptest.Server
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
client = hfapi.NewClient()
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
if server != nil {
|
||||
server.Close()
|
||||
}
|
||||
})
|
||||
|
||||
Context("when creating a new client", func() {
|
||||
It("should initialize with correct base URL", func() {
|
||||
Expect(client).ToNot(BeNil())
|
||||
Expect(client.BaseURL()).To(Equal("https://huggingface.co/api/models"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when searching for models", func() {
|
||||
BeforeEach(func() {
|
||||
// Mock response data
|
||||
mockResponse := `[
|
||||
{
|
||||
"modelId": "test-model-1",
|
||||
"author": "test-author",
|
||||
"downloads": 1000,
|
||||
"lastModified": "2024-01-01T00:00:00.000Z",
|
||||
"pipelineTag": "text-generation",
|
||||
"private": false,
|
||||
"tags": ["gguf", "llama"],
|
||||
"createdAt": "2024-01-01T00:00:00.000Z",
|
||||
"updatedAt": "2024-01-01T00:00:00.000Z",
|
||||
"sha": "abc123",
|
||||
"config": {},
|
||||
"model_index": "test-index",
|
||||
"library_name": "transformers",
|
||||
"mask_token": null,
|
||||
"tokenizer_class": "LlamaTokenizer"
|
||||
},
|
||||
{
|
||||
"modelId": "test-model-2",
|
||||
"author": "test-author-2",
|
||||
"downloads": 2000,
|
||||
"lastModified": "2024-01-02T00:00:00.000Z",
|
||||
"pipelineTag": "text-generation",
|
||||
"private": false,
|
||||
"tags": ["gguf", "mistral"],
|
||||
"createdAt": "2024-01-02T00:00:00.000Z",
|
||||
"updatedAt": "2024-01-02T00:00:00.000Z",
|
||||
"sha": "def456",
|
||||
"config": {},
|
||||
"model_index": "test-index-2",
|
||||
"library_name": "transformers",
|
||||
"mask_token": null,
|
||||
"tokenizer_class": "MistralTokenizer"
|
||||
}
|
||||
]`
|
||||
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify request parameters
|
||||
Expect(r.URL.Query().Get("sort")).To(Equal("lastModified"))
|
||||
Expect(r.URL.Query().Get("direction")).To(Equal("-1"))
|
||||
Expect(r.URL.Query().Get("limit")).To(Equal("30"))
|
||||
Expect(r.URL.Query().Get("search")).To(Equal("GGUF"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockResponse))
|
||||
}))
|
||||
|
||||
// Override the client's base URL to use our mock server
|
||||
client.SetBaseURL(server.URL)
|
||||
})
|
||||
|
||||
It("should successfully search for models", func() {
|
||||
params := hfapi.SearchParams{
|
||||
Sort: "lastModified",
|
||||
Direction: -1,
|
||||
Limit: 30,
|
||||
Search: "GGUF",
|
||||
}
|
||||
|
||||
models, err := client.SearchModels(params)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(models).To(HaveLen(2))
|
||||
|
||||
// Verify first model
|
||||
Expect(models[0].ModelID).To(Equal("test-model-1"))
|
||||
Expect(models[0].Author).To(Equal("test-author"))
|
||||
Expect(models[0].Downloads).To(Equal(1000))
|
||||
Expect(models[0].PipelineTag).To(Equal("text-generation"))
|
||||
Expect(models[0].Private).To(BeFalse())
|
||||
Expect(models[0].Tags).To(ContainElements("gguf", "llama"))
|
||||
|
||||
// Verify second model
|
||||
Expect(models[1].ModelID).To(Equal("test-model-2"))
|
||||
Expect(models[1].Author).To(Equal("test-author-2"))
|
||||
Expect(models[1].Downloads).To(Equal(2000))
|
||||
Expect(models[1].Tags).To(ContainElements("gguf", "mistral"))
|
||||
})
|
||||
|
||||
It("should handle empty search results", func() {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("[]"))
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
|
||||
params := hfapi.SearchParams{
|
||||
Sort: "lastModified",
|
||||
Direction: -1,
|
||||
Limit: 30,
|
||||
Search: "nonexistent",
|
||||
}
|
||||
|
||||
models, err := client.SearchModels(params)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(models).To(HaveLen(0))
|
||||
})
|
||||
|
||||
It("should handle HTTP errors", func() {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
w.Write([]byte("Internal Server Error"))
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
|
||||
params := hfapi.SearchParams{
|
||||
Sort: "lastModified",
|
||||
Direction: -1,
|
||||
Limit: 30,
|
||||
Search: "GGUF",
|
||||
}
|
||||
|
||||
models, err := client.SearchModels(params)
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("Status code: 500"))
|
||||
Expect(models).To(BeNil())
|
||||
})
|
||||
|
||||
It("should handle malformed JSON response", func() {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("invalid json"))
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
|
||||
params := hfapi.SearchParams{
|
||||
Sort: "lastModified",
|
||||
Direction: -1,
|
||||
Limit: 30,
|
||||
Search: "GGUF",
|
||||
}
|
||||
|
||||
models, err := client.SearchModels(params)
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("failed to parse JSON response"))
|
||||
Expect(models).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when getting latest GGUF models", func() {
|
||||
BeforeEach(func() {
|
||||
mockResponse := `[
|
||||
{
|
||||
"modelId": "latest-gguf-model",
|
||||
"author": "gguf-author",
|
||||
"downloads": 5000,
|
||||
"lastModified": "2024-01-03T00:00:00.000Z",
|
||||
"pipelineTag": "text-generation",
|
||||
"private": false,
|
||||
"tags": ["gguf", "latest"],
|
||||
"createdAt": "2024-01-03T00:00:00.000Z",
|
||||
"updatedAt": "2024-01-03T00:00:00.000Z",
|
||||
"sha": "latest123",
|
||||
"config": {},
|
||||
"model_index": "latest-index",
|
||||
"library_name": "transformers",
|
||||
"mask_token": null,
|
||||
"tokenizer_class": "LlamaTokenizer"
|
||||
}
|
||||
]`
|
||||
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify the search parameters are correct for GGUF search
|
||||
Expect(r.URL.Query().Get("search")).To(Equal("GGUF"))
|
||||
Expect(r.URL.Query().Get("sort")).To(Equal("lastModified"))
|
||||
Expect(r.URL.Query().Get("direction")).To(Equal("-1"))
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockResponse))
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
})
|
||||
|
||||
It("should fetch latest GGUF models with correct parameters", func() {
|
||||
models, err := client.GetLatest("GGUF", 10)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(models).To(HaveLen(1))
|
||||
Expect(models[0].ModelID).To(Equal("latest-gguf-model"))
|
||||
Expect(models[0].Author).To(Equal("gguf-author"))
|
||||
Expect(models[0].Downloads).To(Equal(5000))
|
||||
Expect(models[0].Tags).To(ContainElements("gguf", "latest"))
|
||||
})
|
||||
|
||||
It("should use custom search term", func() {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
Expect(r.URL.Query().Get("search")).To(Equal("custom-search"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("[]"))
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
|
||||
models, err := client.GetLatest("custom-search", 5)
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(models).To(HaveLen(0))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when handling network errors", func() {
|
||||
It("should handle connection failures gracefully", func() {
|
||||
// Use an invalid URL to simulate connection failure
|
||||
client.SetBaseURL("http://invalid-url-that-does-not-exist")
|
||||
|
||||
params := hfapi.SearchParams{
|
||||
Sort: "lastModified",
|
||||
Direction: -1,
|
||||
Limit: 30,
|
||||
Search: "GGUF",
|
||||
}
|
||||
|
||||
models, err := client.SearchModels(params)
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("failed to make request"))
|
||||
Expect(models).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when getting file SHA on remote model", func() {
|
||||
It("should get file SHA successfully", func() {
|
||||
sha, err := client.GetFileSHA(
|
||||
"mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF", "localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sha).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when listing files", func() {
|
||||
BeforeEach(func() {
|
||||
mockFilesResponse := `[
|
||||
{
|
||||
"type": "file",
|
||||
"path": "model-Q4_K_M.gguf",
|
||||
"size": 1000000,
|
||||
"oid": "abc123",
|
||||
"lfs": {
|
||||
"oid": "def456789",
|
||||
"size": 1000000,
|
||||
"pointerSize": 135
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"path": "README.md",
|
||||
"size": 5000,
|
||||
"oid": "readme123"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"path": "config.json",
|
||||
"size": 1000,
|
||||
"oid": "config123"
|
||||
}
|
||||
]`
|
||||
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/tree/main") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockFilesResponse))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
})
|
||||
|
||||
It("should list files successfully", func() {
|
||||
files, err := client.ListFiles("test/model")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(files).To(HaveLen(3))
|
||||
|
||||
Expect(files[0].Path).To(Equal("model-Q4_K_M.gguf"))
|
||||
Expect(files[0].Size).To(Equal(int64(1000000)))
|
||||
Expect(files[0].LFS).ToNot(BeNil())
|
||||
Expect(files[0].LFS.Oid).To(Equal("def456789"))
|
||||
|
||||
Expect(files[1].Path).To(Equal("README.md"))
|
||||
Expect(files[1].Size).To(Equal(int64(5000)))
|
||||
})
|
||||
})
|
||||
|
||||
// Exercises Client.GetFileSHA: the returned SHA must come from the LFS
// metadata ("lfs.oid") when the file is LFS-backed, and fall back to the
// plain git OID for regular files.
Context("when getting file SHA", func() {
	BeforeEach(func() {
		// One LFS-backed file: lfs.oid ("def456789") should be preferred
		// over the top-level git oid ("abc123").
		mockFilesResponse := `[
			{
				"type": "file",
				"path": "model-Q4_K_M.gguf",
				"size": 1000000,
				"oid": "abc123",
				"lfs": {
					"oid": "def456789",
					"size": 1000000,
					"pointerSize": 135
				}
			}
		]`

		server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Only the repo tree listing is served; everything else is a 404.
			if !strings.Contains(r.URL.Path, "/tree/main") {
				w.WriteHeader(http.StatusNotFound)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(mockFilesResponse))
		}))

		client.SetBaseURL(server.URL)
	})

	It("should get file SHA successfully", func() {
		sha, err := client.GetFileSHA("test/model", "model-Q4_K_M.gguf")

		Expect(err).ToNot(HaveOccurred())
		Expect(sha).To(Equal("def456789"))
	})

	It("should handle missing SHA gracefully", func() {
		// BUGFIX: shut down the server started in BeforeEach before
		// replacing it, otherwise its listener (and any kept-alive
		// connections) leak for the remainder of the suite.
		server.Close()

		// File with no "lfs" object at all — only a plain git oid.
		server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !strings.Contains(r.URL.Path, "/tree/main") {
				w.WriteHeader(http.StatusNotFound)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(`[
				{
					"type": "file",
					"path": "file.txt",
					"size": 100,
					"oid": "file123"
				}
			]`))
		}))

		client.SetBaseURL(server.URL)

		sha, err := client.GetFileSHA("test/model", "file.txt")

		Expect(err).ToNot(HaveOccurred())
		// When there's no LFS, it should return the OID
		Expect(sha).To(Equal("file123"))
	})
})
|
||||
|
||||
// Exercises Client.GetModelDetails: it must combine the repo tree listing
// with per-file paths-info metadata, flag the README, and attach a
// resolve-URL to every file.
Context("when getting model details", func() {
	BeforeEach(func() {
		// Tree listing: one LFS-backed GGUF plus a README.
		mockFilesResponse := `[
			{
				"path": "model-Q4_K_M.gguf",
				"size": 1000000,
				"oid": "abc123",
				"lfs": {
					"oid": "sha256:def456",
					"size": 1000000,
					"pointer": "version https://git-lfs.github.com/spec/v1",
					"sha256": "def456789"
				}
			},
			{
				"path": "README.md",
				"size": 5000,
				"oid": "readme123"
			}
		]`

		// paths-info response for the GGUF file.
		mockFileInfoResponse := `{
			"path": "model-Q4_K_M.gguf",
			"size": 1000000,
			"oid": "abc123",
			"lfs": {
				"oid": "sha256:def456",
				"size": 1000000,
				"pointer": "version https://git-lfs.github.com/spec/v1",
				"sha256": "def456789"
			}
		}`

		server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			writeJSON := func(body string) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusOK)
				w.Write([]byte(body))
			}
			switch {
			case strings.Contains(r.URL.Path, "/tree/main"):
				writeJSON(mockFilesResponse)
			case strings.Contains(r.URL.Path, "/paths-info"):
				writeJSON(mockFileInfoResponse)
			default:
				w.WriteHeader(http.StatusNotFound)
			}
		}))

		client.SetBaseURL(server.URL)
	})

	It("should get model details successfully", func() {
		details, err := client.GetModelDetails("test/model")

		Expect(err).ToNot(HaveOccurred())
		Expect(details.ModelID).To(Equal("test/model"))
		Expect(details.Author).To(Equal("test"))
		Expect(details.Files).To(HaveLen(2))

		Expect(details.ReadmeFile).ToNot(BeNil())
		Expect(details.ReadmeFile.Path).To(Equal("README.md"))
		Expect(details.ReadmeFile.IsReadme).To(BeTrue())

		// Every file must carry a resolve/main download URL.
		// NOTE(review): TrimSuffix is a no-op here — httptest URLs have no
		// path — kept for parity with real base URLs ending in /api/models.
		baseURL := strings.TrimSuffix(server.URL, "/api/models")
		for _, file := range details.Files {
			expectedURL := fmt.Sprintf("%s/test/model/resolve/main/%s", baseURL, file.Path)
			Expect(file.URL).To(Equal(expectedURL))
		}
	})
})
|
||||
|
||||
// Exercises Client.GetReadmeContent: raw file contents are fetched from
// the repo's raw/main endpoint and returned verbatim.
Context("when getting README content", func() {
	BeforeEach(func() {
		mockReadmeContent := "# Test Model\n\nThis is a test model for demonstration purposes."

		server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Anything outside raw/main is a 404.
			if !strings.Contains(r.URL.Path, "/raw/main/") {
				w.WriteHeader(http.StatusNotFound)
				return
			}
			w.Header().Set("Content-Type", "text/plain")
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(mockReadmeContent))
		}))

		client.SetBaseURL(server.URL)
	})

	It("should get README content successfully", func() {
		content, err := client.GetReadmeContent("test/model", "README.md")

		Expect(err).ToNot(HaveOccurred())
		Expect(content).To(Equal("# Test Model\n\nThis is a test model for demonstration purposes."))
	})
})
|
||||
|
||||
// Exercises the pure file-selection helpers: quantization filtering and
// preference-ordered model-file lookup. No HTTP server is needed here.
Context("when filtering files", func() {
	It("should filter files by quantization", func() {
		candidates := []hfapi.ModelFile{
			{Path: "model-Q4_K_M.gguf"},
			{Path: "model-Q3_K_M.gguf"},
			{Path: "README.md", IsReadme: true},
		}

		matches := hfapi.FilterFilesByQuantization(candidates, "Q4_K_M")

		// Only the Q4_K_M artifact survives; other quants and the README drop out.
		Expect(matches).To(HaveLen(1))
		Expect(matches[0].Path).To(Equal("model-Q4_K_M.gguf"))
	})

	It("should find preferred model file", func() {
		candidates := []hfapi.ModelFile{
			{Path: "model-Q3_K_M.gguf"},
			{Path: "model-Q4_K_M.gguf"},
			{Path: "README.md", IsReadme: true},
		}

		// Preference order matters: Q4_K_M ranks above Q3_K_M, so it wins
		// even though Q3_K_M appears first in the slice.
		picked := hfapi.FindPreferredModelFile(candidates, []string{"Q4_K_M", "Q3_K_M"})

		Expect(picked).ToNot(BeNil())
		Expect(picked.Path).To(Equal("model-Q4_K_M.gguf"))
		Expect(picked.IsReadme).To(BeFalse())
	})

	It("should return nil if no preferred file found", func() {
		candidates := []hfapi.ModelFile{
			{Path: "model-Q2_K.gguf"},
			{Path: "README.md", IsReadme: true},
		}

		// None of the preferred quantizations exist in the repo.
		picked := hfapi.FindPreferredModelFile(candidates, []string{"Q4_K_M", "Q3_K_M"})

		Expect(picked).To(BeNil())
	})
})
|
||||
})
|
||||
15
pkg/huggingface-api/hfapi_suite_test.go
Normal file
15
pkg/huggingface-api/hfapi_suite_test.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package hfapi_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestHfapi(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "HuggingFace API Suite")
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user