Mirror of https://github.com/mudler/LocalAI.git (synced 2026-02-03 03:02:38 -05:00)

Compare commits: 242 commits
Commits in this range (SHA1 only; the author and date columns did not survive extraction):

15da375c4d, ba1b8e7757, 79b68fdc25, a946cb08b5, d95d4992fe, e13cb8346d, 615c56503e, 79a8edd8b9,
8d138dd68f, 2b33844562, 63e6721c2f, 4859d809aa, be027b1ccd, 3ecadeeb93, 4af3348f91, dde08845bf,
76d1ba168d, 80605e4f66, 5b99584a31, fc134b18fe, c2006273c5, 5343889098, c42afc56d9, 53f44dac89,
0468456fad, df899ee26a, 93fe25468f, 238aad666e, 4408ed4f88, 5df1f59a3c, 8225697139, 0c0186d866,
ce2f8828f9, 7a8565a45e, 192589a17f, 28ab73d4a1, ed4ac0b61e, e41d8b65ce, c28e5b39d6, b66bd2706f,
fa7a9d96f8, 61d972a2ef, fffdbc31c6, 32c0ab3a7f, 24ce79a67c, bfa8530088, 4278144dd5, 79fa4d691e,
7a3d9ee5c1, 22923d3b23, d32a459209, 47b2a502dd, b85f339eb4, 8821865eac, 4b30846d57, 7a35986407,
ee34aa7bd5, 40cf798dfe, 18810038f5, 8fb79bc6f6, 4b5ad1405f, 4493078cdd, 7f68c89cbe, 69adc46936,
d22439918f, 103d4e87e5, 8c5ba9e0d7, f1b713df08, f94b89c1b5, a1b056737a, a22f6a499d, e5bf2a9a11,
05aba5a311, 354bf5debb, 7f88abb3b1, 36b3a538f8, e293b65ad9, cce185b345, 03ed4382c7, 1c73e10676,
4ade65f959, c54f5cdf12, 33c48164d7, 7aed3b3bac, 9e349c715e, 639ecb59b3, bfb0794f87, 05f1e9e757,
1ca6f6dada, bc5397bcfc, f452a027a2, 7bac49fb87, 02300cfbd1, 17c5c732c7, 10a66938f9, f0245fa36c,
83534f8e00, 75eaf8c853, 03096154d4, 22c9e8c09e, da16727ad6, ad44df6d83, 276c552583, 9109e5c149,
71a84b91e3, 209d40be71, bfd76805e8, 561aa5e443, b0eb1ab2a1, 1208fb6fa1, f98fe85c42, 167c183c84,
244e47e1e0, 9680a0b0fe, acbd10a661, c6b989be13, 670103705c, cb90bd226e, df9b2abf84, 582114bda9,
91ffe5ac38, 8a58d76254, c3442fe574, 1087bd217e, 7ed3666d2e, 2e2e89e499, 13c9c20f42, b3d3988d85,
0529c7d0a0, af31a77061, 2d8956167f, 509f85f82c, bb2b377b18, 48917889ce, ef754259b0, 7e26f28113,
d7c8129549, 3a8fbb698e, b1ef34ef9f, b7822250fe, 05055f7e95, c856d7dc73, 69d565e55d, fa6bbd9fa2,
3f767121d2, e963e16bc5, 1e9b115251, cd1e1124ea, 81b31b4283, d763bce46d, 4aac0ef42e, 7a36e8d967,
dc2be93412, 69a2b91495, 791bc769c1, a15a1f07e3, c6f0b44228, cb0ed55d89, 2fe97110fd, fa8037b21d,
99a72a4b11, 1a52ce1bd4, 925d752f8d, c0b9d00f35, fcf8d41a00, 27c4161401, 459b6ab86d, 336257cc3c,
df46a438b8, 5e1d809904, a9c7ce7275, 8c47c8c8ed, 8e8d427549, ee251115f4, 661e66090c, c38564e22c,
20f1e842b3, aa8965b634, 35c676188b, 183559bb98, 1123a5c49c, 6f17c260a7, da6278aae9, 2e51871ad5,
8067d25710, cb2df6c5bf, 07e1519b3f, 8fc41673fa, fff0e5911b, 09346bdc06, d4d42740c8, 5de7a43319,
85e27ec74c, 698205a2f3, 3ed582b091, 752e33f676, 930553ef60, fc8d5c9198, 60b6472fa0, 6b2c8277c2,
6d5d3ebcf6, 530c174fd3, 8fb95686af, 4132085c01, c14f1ffcfd, 07cca4b69a, dd927c36f6, 052f42e926,
30d43588ab, d21ec22f74, 04fecd634a, 33c14198db, 967c2727e3, f41f30ad92, e77340e8a5, d51a3090f7,
1bf3bc932c, 564a47da4e, c37ee93ff2, f4b65db4e7, f5fa8e6649, 570e39bdcf, 2ebe37b671, dca685f784,
84ebf2a2c9, ce5662ba90, 9878f27813, f2b9452ec4, 585da99c52, fd4f432079, 238c68c57b, 04fbf5cb82,
c85d559919, b5efc4f89e, 3f9c09a4c5, 4a84660475, 737248256e, 0ae334fc62, 36c373b7c9, 6afcb932b7,
357bf571a3, e74ade9ebb
.github/gallery-agent/agent.go (new vendored file, 288 lines)
@@ -0,0 +1,288 @@
```go
package main

import (
	"context"
	"fmt"
	"os"
	"slices"
	"strings"

	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
	"github.com/mudler/cogito"

	"github.com/mudler/cogito/structures"
	"github.com/sashabaranov/go-openai/jsonschema"
)

var (
	openAIModel      = os.Getenv("OPENAI_MODEL")
	openAIKey        = os.Getenv("OPENAI_KEY")
	openAIBaseURL    = os.Getenv("OPENAI_BASE_URL")
	galleryIndexPath = os.Getenv("GALLERY_INDEX_PATH")
	// default client
	llm = cogito.NewOpenAILLM(openAIModel, openAIKey, openAIBaseURL)
)

// cleanTextContent removes trailing spaces, tabs, and normalizes line endings
// to prevent YAML linting issues like trailing spaces and multiple empty lines
func cleanTextContent(text string) string {
	lines := strings.Split(text, "\n")
	var cleanedLines []string
	var prevEmpty bool
	for _, line := range lines {
		// Remove all trailing whitespace (spaces, tabs, etc.)
		trimmed := strings.TrimRight(line, " \t\r")
		// Avoid multiple consecutive empty lines
		if trimmed == "" {
			if !prevEmpty {
				cleanedLines = append(cleanedLines, "")
			}
			prevEmpty = true
		} else {
			cleanedLines = append(cleanedLines, trimmed)
			prevEmpty = false
		}
	}
	// Remove trailing empty lines from the result
	result := strings.Join(cleanedLines, "\n")
	return strings.TrimRight(result, "\n")
}

// isModelExisting checks if a specific model ID exists in the gallery using text search
func isModelExisting(modelID string) (bool, error) {
	indexPath := getGalleryIndexPath()
	content, err := os.ReadFile(indexPath)
	if err != nil {
		return false, fmt.Errorf("failed to read %s: %w", indexPath, err)
	}

	contentStr := string(content)
	// Simple text search - if the model ID appears anywhere in the file, it exists
	return strings.Contains(contentStr, modelID), nil
}

// filterExistingModels removes models that already exist in the gallery
func filterExistingModels(models []ProcessedModel) ([]ProcessedModel, error) {
	var filteredModels []ProcessedModel
	for _, model := range models {
		exists, err := isModelExisting(model.ModelID)
		if err != nil {
			fmt.Printf("Error checking if model %s exists: %v, skipping\n", model.ModelID, err)
			continue
		}

		if !exists {
			filteredModels = append(filteredModels, model)
		} else {
			fmt.Printf("Skipping existing model: %s\n", model.ModelID)
		}
	}

	fmt.Printf("Filtered out %d existing models, %d new models remaining\n",
		len(models)-len(filteredModels), len(filteredModels))

	return filteredModels, nil
}

// getGalleryIndexPath returns the gallery index file path, with a default fallback
func getGalleryIndexPath() string {
	if galleryIndexPath != "" {
		return galleryIndexPath
	}
	return "gallery/index.yaml"
}

func getRealReadme(ctx context.Context, repository string) (string, error) {
	// Create a conversation fragment
	fragment := cogito.NewEmptyFragment().
		AddMessage("user",
			`Your task is to get a clear description of a large language model from huggingface by using the provided tool. I will share with you a repository that might be quantized, and as such probably not by the original model author. We need to get the real description of the model, and not the one that might be quantized. You will have to call the tool to get the readme more than once by figuring out from the quantized readme which is the base model readme. This is the repository: `+repository)

	// Execute with tools
	result, err := cogito.ExecuteTools(llm, fragment,
		cogito.WithIterations(3),
		cogito.WithMaxAttempts(3),
		cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
	if err != nil {
		return "", err
	}

	result = result.AddMessage("user", "Describe the model in a clear and concise way that can be shared in a model gallery.")

	// Get a response
	newFragment, err := llm.Ask(ctx, result)
	if err != nil {
		return "", err
	}

	content := newFragment.LastMessage().Content
	return cleanTextContent(content), nil
}

func selectMostInterestingModels(ctx context.Context, searchResult *SearchResult) ([]ProcessedModel, error) {
	// Create a conversation fragment
	fragment := cogito.NewEmptyFragment().
		AddMessage("user",
			`Your task is to analyze a list of AI models and select the most interesting ones for a model gallery. You will be given detailed information about multiple models including their metadata, file information, and README content.

Consider the following criteria when selecting models:
1. Model popularity (download count)
2. Model recency (last modified date)
3. Model completeness (has preferred model file, README, etc.)
4. Model uniqueness (not duplicates or very similar models)
5. Model quality (based on README content and description)
6. Model utility (practical applications)

You should select models that would be most valuable for users browsing a model gallery. Prioritize models that are:
- Well-documented with clear READMEs
- Recently updated
- Popular (high download count)
- Have the preferred quantization format available
- Offer unique capabilities or are from reputable authors

Return your analysis and selection reasoning.`)

	// Add the search results as context
	modelsInfo := fmt.Sprintf("Found %d models matching '%s' with quantization preference '%s':\n\n",
		searchResult.TotalModelsFound, searchResult.SearchTerm, searchResult.Quantization)

	for i, model := range searchResult.Models {
		modelsInfo += fmt.Sprintf("Model %d:\n", i+1)
		modelsInfo += fmt.Sprintf("  ID: %s\n", model.ModelID)
		modelsInfo += fmt.Sprintf("  Author: %s\n", model.Author)
		modelsInfo += fmt.Sprintf("  Downloads: %d\n", model.Downloads)
		modelsInfo += fmt.Sprintf("  Last Modified: %s\n", model.LastModified)
		modelsInfo += fmt.Sprintf("  Files: %d files\n", len(model.Files))

		if model.PreferredModelFile != nil {
			modelsInfo += fmt.Sprintf("  Preferred Model File: %s (%d bytes)\n",
				model.PreferredModelFile.Path, model.PreferredModelFile.Size)
		} else {
			modelsInfo += "  No preferred model file found\n"
		}

		if model.ReadmeContent != "" {
			modelsInfo += fmt.Sprintf("  README: %s\n", model.ReadmeContent)
		}

		if model.ProcessingError != "" {
			modelsInfo += fmt.Sprintf("  Processing Error: %s\n", model.ProcessingError)
		}

		modelsInfo += "\n"
	}

	fragment = fragment.AddMessage("user", modelsInfo)

	fragment = fragment.AddMessage("user", "Based on your analysis, select the top 5 most interesting models and provide a brief explanation for each selection. Also, create a filtered SearchResult with only the selected models. Return just a list of repositories IDs, you will later be asked to output it as a JSON array with the json tool.")

	// Get a response
	newFragment, err := llm.Ask(ctx, fragment)
	if err != nil {
		return nil, err
	}

	fmt.Println(newFragment.LastMessage().Content)
	repositories := struct {
		Repositories []string `json:"repositories"`
	}{}

	s := structures.Structure{
		Schema: jsonschema.Definition{
			Type:                 jsonschema.Object,
			AdditionalProperties: false,
			Properties: map[string]jsonschema.Definition{
				"repositories": {
					Type:        jsonschema.Array,
					Items:       &jsonschema.Definition{Type: jsonschema.String},
					Description: "The trending repositories IDs",
				},
			},
			Required: []string{"repositories"},
		},
		Object: &repositories,
	}

	err = newFragment.ExtractStructure(ctx, llm, s)
	if err != nil {
		return nil, err
	}

	filteredModels := []ProcessedModel{}
	for _, m := range searchResult.Models {
		if slices.Contains(repositories.Repositories, m.ModelID) {
			filteredModels = append(filteredModels, m)
		}
	}

	return filteredModels, nil
}

// ModelFamily represents a YAML anchor/family
type ModelFamily struct {
	Anchor string `json:"anchor"`
	Name   string `json:"name"`
}

// selectModelFamily selects the appropriate model family/anchor for a given model
func selectModelFamily(ctx context.Context, model ProcessedModel, availableFamilies []ModelFamily) (string, error) {
	// Create a conversation fragment
	fragment := cogito.NewEmptyFragment().
		AddMessage("user",
			`Your task is to select the most appropriate model family/anchor for a given AI model. You will be provided with:
1. Information about the model (name, description, etc.)
2. A list of available model families/anchors

You need to select the family that best matches the model's architecture, capabilities, or characteristics. Consider:
- Model architecture (e.g., Llama, Qwen, Mistral, etc.)
- Model capabilities (e.g., vision, coding, chat, etc.)
- Model size/type (e.g., small, medium, large)
- Model purpose (e.g., general purpose, specialized, etc.)

Return the anchor name that best fits the model.`)

	// Add model information
	modelInfo := "Model Information:\n"
	modelInfo += fmt.Sprintf("  ID: %s\n", model.ModelID)
	modelInfo += fmt.Sprintf("  Author: %s\n", model.Author)
	modelInfo += fmt.Sprintf("  Downloads: %d\n", model.Downloads)
	modelInfo += fmt.Sprintf("  Description: %s\n", model.ReadmeContentPreview)

	fragment = fragment.AddMessage("user", modelInfo)

	// Add available families
	familiesInfo := "Available Model Families:\n"
	for _, family := range availableFamilies {
		familiesInfo += fmt.Sprintf("  - %s (%s)\n", family.Anchor, family.Name)
	}

	fragment = fragment.AddMessage("user", familiesInfo)
	fragment = fragment.AddMessage("user", "Select the most appropriate family anchor for this model. Return just the anchor name.")

	// Get a response
	newFragment, err := llm.Ask(ctx, fragment)
	if err != nil {
		return "", err
	}

	// Extract the selected family
	selectedFamily := strings.TrimSpace(newFragment.LastMessage().Content)

	// Validate that the selected family exists in our list
	for _, family := range availableFamilies {
		if family.Anchor == selectedFamily {
			return selectedFamily, nil
		}
	}

	// If no exact match, try to find a close match
	for _, family := range availableFamilies {
		if strings.Contains(strings.ToLower(family.Anchor), strings.ToLower(selectedFamily)) ||
			strings.Contains(strings.ToLower(selectedFamily), strings.ToLower(family.Anchor)) {
			return family.Anchor, nil
		}
	}

	// Default fallback
	return "llama3", nil
}
```
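To make the normalization contract of `cleanTextContent` concrete, here is a minimal sketch (hypothetical input, written as if it lived in the same package): trailing spaces and tabs are stripped, runs of blank lines collapse to one, and trailing blank lines are dropped.

```go
package main

import "fmt"

// exampleCleanTextContent demonstrates cleanTextContent on a
// hypothetical README fragment: trailing whitespace is removed,
// the run of blank lines collapses to a single one, and the
// trailing blank lines disappear entirely.
func exampleCleanTextContent() {
	in := "A quantized model.  \n\n\n\nSee the base repo.\t\n\n"
	out := cleanTextContent(in)
	fmt.Printf("%q\n", out) // "A quantized model.\n\nSee the base repo."
}
```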
.github/gallery-agent/gallery.go (new vendored file, 203 lines)
@@ -0,0 +1,203 @@
```go
package main

import (
	"context"
	"fmt"
	"os"
	"strings"
)

// generateYAMLEntry generates a YAML entry for a model using the specified anchor
func generateYAMLEntry(model ProcessedModel, familyAnchor string) string {
	// Extract model name from ModelID
	parts := strings.Split(model.ModelID, "/")
	modelName := model.ModelID
	if len(parts) > 0 {
		modelName = strings.ToLower(parts[len(parts)-1])
	}
	// Remove common suffixes
	modelName = strings.ReplaceAll(modelName, "-gguf", "")
	modelName = strings.ReplaceAll(modelName, "-q4_k_m", "")
	modelName = strings.ReplaceAll(modelName, "-q4_k_s", "")
	modelName = strings.ReplaceAll(modelName, "-q3_k_m", "")
	modelName = strings.ReplaceAll(modelName, "-q2_k", "")

	fileName := ""
	checksum := ""
	if model.PreferredModelFile != nil {
		fileParts := strings.Split(model.PreferredModelFile.Path, "/")
		if len(fileParts) > 0 {
			fileName = fileParts[len(fileParts)-1]
		}
		checksum = model.PreferredModelFile.SHA256
	} else {
		fileName = model.ModelID
	}

	description := model.ReadmeContent
	if description == "" {
		description = fmt.Sprintf("AI model: %s", modelName)
	}

	// Clean up description to prevent YAML linting issues
	description = cleanTextContent(description)

	// Format description for YAML (indent each line and ensure no trailing spaces)
	lines := strings.Split(description, "\n")
	var formattedLines []string
	for _, line := range lines {
		if strings.TrimSpace(line) == "" {
			// Keep empty lines as empty (no indentation)
			formattedLines = append(formattedLines, "")
		} else {
			// Add indentation to non-empty lines
			formattedLines = append(formattedLines, "    "+line)
		}
	}
	formattedDescription := strings.Join(formattedLines, "\n")
	// Remove any trailing spaces from the formatted description
	formattedDescription = strings.TrimRight(formattedDescription, " \t")
	yamlTemplate := ""
	if checksum != "" {
		yamlTemplate = `- !!merge <<: *%s
  name: "%s"
  urls:
    - https://huggingface.co/%s
  description: |
%s
  overrides:
    parameters:
      model: %s
  files:
    - filename: %s
      sha256: %s
      uri: huggingface://%s/%s`
		return fmt.Sprintf(yamlTemplate,
			familyAnchor,
			modelName,
			model.ModelID,
			formattedDescription,
			fileName,
			fileName,
			checksum,
			model.ModelID,
			fileName,
		)
	} else {
		yamlTemplate = `- !!merge <<: *%s
  name: "%s"
  urls:
    - https://huggingface.co/%s
  description: |
%s
  overrides:
    parameters:
      model: %s`
		return fmt.Sprintf(yamlTemplate,
			familyAnchor,
			modelName,
			model.ModelID,
			formattedDescription,
			fileName,
		)
	}
}

// extractModelFamilies extracts all YAML anchors from the gallery index.yaml file
func extractModelFamilies() ([]ModelFamily, error) {
	// Read the index.yaml file
	indexPath := getGalleryIndexPath()
	content, err := os.ReadFile(indexPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read %s: %w", indexPath, err)
	}

	lines := strings.Split(string(content), "\n")
	var families []ModelFamily

	for _, line := range lines {
		line = strings.TrimSpace(line)
		// Look for YAML anchors (lines starting with "- &")
		if strings.HasPrefix(line, "- &") {
			// Extract the anchor name (everything after "- &")
			anchor := strings.TrimPrefix(line, "- &")
			// Remove any trailing colon or other characters
			anchor = strings.Split(anchor, ":")[0]
			anchor = strings.Split(anchor, " ")[0]

			if anchor != "" {
				families = append(families, ModelFamily{
					Anchor: anchor,
					Name:   anchor, // Use anchor as name for now
				})
			}
		}
	}

	return families, nil
}

// generateYAMLForModels generates YAML entries for selected models and appends to index.yaml
func generateYAMLForModels(ctx context.Context, models []ProcessedModel) error {
	// Extract available model families
	families, err := extractModelFamilies()
	if err != nil {
		return fmt.Errorf("failed to extract model families: %w", err)
	}

	fmt.Printf("Found %d model families: %v\n", len(families),
		func() []string {
			var names []string
			for _, f := range families {
				names = append(names, f.Anchor)
			}
			return names
		}())

	// Generate YAML entries for each model
	var yamlEntries []string
	for _, model := range models {
		fmt.Printf("Selecting family for model: %s\n", model.ModelID)

		// Select appropriate family for this model
		familyAnchor, err := selectModelFamily(ctx, model, families)
		if err != nil {
			fmt.Printf("Error selecting family for %s: %v, using default\n", model.ModelID, err)
			familyAnchor = "llama3" // Default fallback
		}

		fmt.Printf("Selected family '%s' for model %s\n", familyAnchor, model.ModelID)

		// Generate YAML entry
		yamlEntry := generateYAMLEntry(model, familyAnchor)
		yamlEntries = append(yamlEntries, yamlEntry)
	}

	// Append to index.yaml
	if len(yamlEntries) > 0 {
		indexPath := getGalleryIndexPath()
		fmt.Printf("Appending YAML entries to %s...\n", indexPath)

		// Read current content
		content, err := os.ReadFile(indexPath)
		if err != nil {
			return fmt.Errorf("failed to read %s: %w", indexPath, err)
		}

		// Append new entries
		// Remove trailing whitespace from existing content and join entries without extra newlines
		existingContent := strings.TrimRight(string(content), " \t\n\r")
		yamlBlock := strings.Join(yamlEntries, "\n")
		newContent := existingContent + "\n" + yamlBlock + "\n"

		// Write back to file
		err = os.WriteFile(indexPath, []byte(newContent), 0644)
		if err != nil {
			return fmt.Errorf("failed to write %s: %w", indexPath, err)
		}

		fmt.Printf("Successfully added %d models to %s\n", len(yamlEntries), indexPath)
	}

	return nil
}
```
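As a sanity check on the template above, this sketch shows the shape of entry `generateYAMLEntry` appends when a checksummed preferred file exists. The model data is hypothetical, and it assumes `ProcessedModel` (defined elsewhere in this PR) carries its preferred file as an `*hfapi.ModelFile`; the exact indentation of the emitted YAML was lost in extraction and is reconstructed here.

```go
package main

import "fmt"

// exampleGenerateYAMLEntry: hypothetical model data; real values
// come from the HuggingFace API. Assumes ProcessedModel's
// PreferredModelFile field is an *hfapi.ModelFile.
func exampleGenerateYAMLEntry() {
	model := ProcessedModel{
		ModelID:       "some-author/example-7b-GGUF", // lowercased, "-gguf" stripped -> name "example-7b"
		ReadmeContent: "An example 7B chat model.",
		PreferredModelFile: &hfapi.ModelFile{
			Path:   "example-7b.Q4_K_M.gguf",
			SHA256: "abc123...",
		},
	}
	fmt.Println(generateYAMLEntry(model, "llama3"))
	// Prints roughly:
	// - !!merge <<: *llama3
	//   name: "example-7b"
	//   urls:
	//     - https://huggingface.co/some-author/example-7b-GGUF
	//   description: |
	//     An example 7B chat model.
	//   overrides:
	//     parameters:
	//       model: example-7b.Q4_K_M.gguf
	//   files:
	//     - filename: example-7b.Q4_K_M.gguf
	//       sha256: abc123...
	//       uri: huggingface://some-author/example-7b-GGUF/example-7b.Q4_K_M.gguf
}
```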
.github/gallery-agent/go.mod (new vendored file, 39 lines)
@@ -0,0 +1,39 @@
```
module github.com/go-skynet/LocalAI/.github/gallery-agent

go 1.24.1

require (
	github.com/mudler/cogito v0.3.0
	github.com/onsi/ginkgo/v2 v2.25.3
	github.com/onsi/gomega v1.38.2
	github.com/sashabaranov/go-openai v1.41.2
	github.com/tmc/langchaingo v0.1.13
	gopkg.in/yaml.v3 v3.0.1
)

require (
	dario.cat/mergo v1.0.1 // indirect
	github.com/Masterminds/goutils v1.1.1 // indirect
	github.com/Masterminds/semver/v3 v3.4.0 // indirect
	github.com/Masterminds/sprig/v3 v3.3.0 // indirect
	github.com/go-logr/logr v1.4.3 // indirect
	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
	github.com/google/go-cmp v0.7.0 // indirect
	github.com/google/jsonschema-go v0.3.0 // indirect
	github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
	github.com/google/uuid v1.6.0 // indirect
	github.com/huandu/xstrings v1.5.0 // indirect
	github.com/mitchellh/copystructure v1.2.0 // indirect
	github.com/mitchellh/reflectwalk v1.0.2 // indirect
	github.com/modelcontextprotocol/go-sdk v1.0.0 // indirect
	github.com/shopspring/decimal v1.4.0 // indirect
	github.com/spf13/cast v1.7.0 // indirect
	github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
	go.uber.org/automaxprocs v1.6.0 // indirect
	go.yaml.in/yaml/v3 v3.0.4 // indirect
	golang.org/x/crypto v0.41.0 // indirect
	golang.org/x/net v0.43.0 // indirect
	golang.org/x/sys v0.35.0 // indirect
	golang.org/x/text v0.28.0 // indirect
	golang.org/x/tools v0.36.0 // indirect
)
```
.github/gallery-agent/go.sum (new vendored file, 168 lines)
@@ -0,0 +1,168 @@
```
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs=
github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.2.2+incompatible h1:CjwRSksz8Yo4+RmQ339Dp/D2tGO5JxwYeqtMOEe0LDw=
github.com/docker/docker v28.2.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ=
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modelcontextprotocol/go-sdk v1.0.0 h1:Z4MSjLi38bTgLrd/LjSmofqRqyBiVKRyQSJgw8q8V74=
github.com/modelcontextprotocol/go-sdk v1.0.0/go.mod h1:nYtYQroQ2KQiM0/SbyEPUWQ6xs4B95gJjEalc9AQyOs=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/mudler/cogito v0.3.0 h1:NbVAO3bLkK5oGSY0xq87jlz8C9OIsLW55s+8Hfzeu9s=
github.com/mudler/cogito v0.3.0/go.mod h1:abMwl+CUjCp87IufA2quZdZt0bbLaHHN79o17HbUKxU=
github.com/onsi/ginkgo/v2 v2.25.3 h1:Ty8+Yi/ayDAGtk4XxmmfUy4GabvM+MegeB4cDLRi6nw=
github.com/onsi/ginkgo/v2 v2.25.3/go.mod h1:43uiyQC4Ed2tkOzLsEYm7hnrb7UJTWHYNsuy3bG/snE=
github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A=
github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM=
github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/shirou/gopsutil/v4 v4.25.5 h1:rtd9piuSMGeU8g1RMXjZs9y9luK5BwtnG7dZaQUJAsc=
github.com/shirou/gopsutil/v4 v4.25.5/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/testcontainers/testcontainers-go v0.38.0 h1:d7uEapLcv2P8AvH8ahLqDMMxda2W9gQN1nRbHS28HBw=
github.com/testcontainers/testcontainers-go v0.38.0/go.mod h1:C52c9MoHpWO+C4aqmgSU+hxlR5jlEayWtgYrb8Pzz1w=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA=
github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
```
.github/gallery-agent/hfapi/client.go (new vendored file, 299 lines)
@@ -0,0 +1,299 @@
```go
package hfapi

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"path/filepath"
	"strings"
)

// Model represents a model from the Hugging Face API
type Model struct {
	ModelID        string                 `json:"modelId"`
	Author         string                 `json:"author"`
	Downloads      int                    `json:"downloads"`
	LastModified   string                 `json:"lastModified"`
	PipelineTag    string                 `json:"pipelineTag"`
	Private        bool                   `json:"private"`
	Tags           []string               `json:"tags"`
	CreatedAt      string                 `json:"createdAt"`
	UpdatedAt      string                 `json:"updatedAt"`
	Sha            string                 `json:"sha"`
	Config         map[string]interface{} `json:"config"`
	ModelIndex     string                 `json:"model_index"`
	LibraryName    string                 `json:"library_name"`
	MaskToken      string                 `json:"mask_token"`
	TokenizerClass string                 `json:"tokenizer_class"`
}

// FileInfo represents file information from HuggingFace
type FileInfo struct {
	Type    string   `json:"type"`
	Oid     string   `json:"oid"`
	Size    int64    `json:"size"`
	Path    string   `json:"path"`
	LFS     *LFSInfo `json:"lfs,omitempty"`
	XetHash string   `json:"xetHash,omitempty"`
}

// LFSInfo represents LFS (Large File Storage) information
type LFSInfo struct {
	Oid         string `json:"oid"`
	Size        int64  `json:"size"`
	PointerSize int    `json:"pointerSize"`
}

// ModelFile represents a file in a model repository
type ModelFile struct {
	Path     string
	Size     int64
	SHA256   string
	IsReadme bool
}

// ModelDetails represents detailed information about a model
type ModelDetails struct {
	ModelID       string
	Author        string
	Files         []ModelFile
	ReadmeFile    *ModelFile
	ReadmeContent string
}

// SearchParams represents the parameters for searching models
type SearchParams struct {
	Sort      string `json:"sort"`
	Direction int    `json:"direction"`
	Limit     int    `json:"limit"`
	Search    string `json:"search"`
}

// Client represents a Hugging Face API client
type Client struct {
	baseURL string
	client  *http.Client
}

// NewClient creates a new Hugging Face API client
func NewClient() *Client {
	return &Client{
		baseURL: "https://huggingface.co/api/models",
		client:  &http.Client{},
	}
}

// SearchModels searches for models using the Hugging Face API
func (c *Client) SearchModels(params SearchParams) ([]Model, error) {
	req, err := http.NewRequest("GET", c.baseURL, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	// Add query parameters
	q := req.URL.Query()
	q.Add("sort", params.Sort)
	q.Add("direction", fmt.Sprintf("%d", params.Direction))
	q.Add("limit", fmt.Sprintf("%d", params.Limit))
	q.Add("search", params.Search)
	req.URL.RawQuery = q.Encode()

	// Make the HTTP request
	resp, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode)
	}

	// Read the response body
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	// Parse the JSON response
	var models []Model
	if err := json.Unmarshal(body, &models); err != nil {
		return nil, fmt.Errorf("failed to parse JSON response: %w", err)
	}

	return models, nil
}

// GetLatest fetches the latest GGUF models
func (c *Client) GetLatest(searchTerm string, limit int) ([]Model, error) {
	params := SearchParams{
		Sort:      "lastModified",
		Direction: -1,
		Limit:     limit,
		Search:    searchTerm,
	}

	return c.SearchModels(params)
}

// BaseURL returns the current base URL
func (c *Client) BaseURL() string {
	return c.baseURL
}

// SetBaseURL sets a new base URL (useful for testing)
func (c *Client) SetBaseURL(url string) {
	c.baseURL = url
}

// ListFiles lists all files in a HuggingFace repository
func (c *Client) ListFiles(repoID string) ([]FileInfo, error) {
	baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
	url := fmt.Sprintf("%s/api/models/%s/tree/main", baseURL, repoID)

	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("failed to fetch files. Status code: %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	var files []FileInfo
	if err := json.Unmarshal(body, &files); err != nil {
		return nil, fmt.Errorf("failed to parse JSON response: %w", err)
	}

	return files, nil
}

// GetFileSHA gets the SHA256 checksum for a specific file by searching through the file list
func (c *Client) GetFileSHA(repoID, fileName string) (string, error) {
	files, err := c.ListFiles(repoID)
	if err != nil {
		return "", fmt.Errorf("failed to list files: %w", err)
	}

	for _, file := range files {
		if filepath.Base(file.Path) == fileName {
			if file.LFS != nil && file.LFS.Oid != "" {
				// The LFS OID contains the SHA256 hash
				return file.LFS.Oid, nil
			}
			// If no LFS, return the regular OID
			return file.Oid, nil
		}
	}

	return "", fmt.Errorf("file %s not found", fileName)
}

// GetModelDetails gets detailed information about a model including files and checksums
func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) {
	files, err := c.ListFiles(repoID)
	if err != nil {
		return nil, fmt.Errorf("failed to list files: %w", err)
	}

	details := &ModelDetails{
		ModelID: repoID,
		Author:  strings.Split(repoID, "/")[0],
		Files:   make([]ModelFile, 0, len(files)),
	}

	// Process each file
	for _, file := range files {
		fileName := filepath.Base(file.Path)
		isReadme := strings.Contains(strings.ToLower(fileName), "readme")

		// Extract SHA256 from LFS or use OID
		sha256 := ""
		if file.LFS != nil && file.LFS.Oid != "" {
			sha256 = file.LFS.Oid
		} else {
			sha256 = file.Oid
		}

		modelFile := ModelFile{
			Path:     file.Path,
			Size:     file.Size,
			SHA256:   sha256,
			IsReadme: isReadme,
		}

		details.Files = append(details.Files, modelFile)

		// Set the readme file
		if isReadme && details.ReadmeFile == nil {
			details.ReadmeFile = &modelFile
		}
	}

	return details, nil
}

// GetReadmeContent gets the content of a README file
func (c *Client) GetReadmeContent(repoID, readmePath string) (string, error) {
	baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
	url := fmt.Sprintf("%s/%s/raw/main/%s", baseURL, repoID, readmePath)

	req, err := http.NewRequest("GET", url, nil)
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to fetch readme content. Status code: %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("failed to read response body: %w", err)
	}

	return string(body), nil
}

// FilterFilesByQuantization filters files by quantization type
func FilterFilesByQuantization(files []ModelFile, quantization string) []ModelFile {
	var filtered []ModelFile
	for _, file := range files {
		fileName := filepath.Base(file.Path)
		if strings.Contains(strings.ToLower(fileName), strings.ToLower(quantization)) {
			filtered = append(filtered, file)
		}
	}
	return filtered
}

// FindPreferredModelFile finds the preferred model file based on quantization preferences
func FindPreferredModelFile(files []ModelFile, preferences []string) *ModelFile {
	for _, preference := range preferences {
		for i := range files {
			fileName := filepath.Base(files[i].Path)
			if strings.Contains(strings.ToLower(fileName), strings.ToLower(preference)) {
				return &files[i]
			}
		}
	}
	return nil
}
```
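Taken together, the client covers the flow the agent needs: search for recent repositories, inspect their file trees, and pick a quantized artifact with its checksum. A minimal, self-contained usage sketch (the preference order shown is illustrative, not the agent's actual configuration):

```go
package main

import (
	"fmt"
	"log"

	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
)

func main() {
	client := hfapi.NewClient()

	// Fetch the five most recently modified models matching "GGUF".
	models, err := client.GetLatest("GGUF", 5)
	if err != nil {
		log.Fatal(err)
	}

	for _, m := range models {
		// List files and extract SHA256 checksums (from LFS OIDs).
		details, err := client.GetModelDetails(m.ModelID)
		if err != nil {
			log.Printf("skipping %s: %v", m.ModelID, err)
			continue
		}
		// Pick the first file matching the quantization preference order.
		if f := hfapi.FindPreferredModelFile(details.Files, []string{"q4_k_m", "q4_k_s"}); f != nil {
			fmt.Printf("%s -> %s (sha256 %s)\n", m.ModelID, f.Path, f.SHA256)
		}
	}
}
```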
.github/gallery-agent/hfapi/client_test.go (new vendored file, 511 lines; listing truncated below)
@@ -0,0 +1,511 @@
```go
package hfapi_test

import (
	"net/http"
	"net/http/httptest"
	"strings"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
)

var _ = Describe("HuggingFace API Client", func() {
	var (
		client *hfapi.Client
		server *httptest.Server
	)

	BeforeEach(func() {
		client = hfapi.NewClient()
	})

	AfterEach(func() {
		if server != nil {
			server.Close()
		}
	})

	Context("when creating a new client", func() {
		It("should initialize with correct base URL", func() {
			Expect(client).ToNot(BeNil())
			Expect(client.BaseURL()).To(Equal("https://huggingface.co/api/models"))
		})
	})

	Context("when searching for models", func() {
		BeforeEach(func() {
			// Mock response data
			mockResponse := `[
				{
					"modelId": "test-model-1",
					"author": "test-author",
					"downloads": 1000,
					"lastModified": "2024-01-01T00:00:00.000Z",
					"pipelineTag": "text-generation",
					"private": false,
					"tags": ["gguf", "llama"],
					"createdAt": "2024-01-01T00:00:00.000Z",
					"updatedAt": "2024-01-01T00:00:00.000Z",
					"sha": "abc123",
					"config": {},
					"model_index": "test-index",
					"library_name": "transformers",
					"mask_token": null,
					"tokenizer_class": "LlamaTokenizer"
				},
				{
					"modelId": "test-model-2",
					"author": "test-author-2",
					"downloads": 2000,
					"lastModified": "2024-01-02T00:00:00.000Z",
					"pipelineTag": "text-generation",
					"private": false,
					"tags": ["gguf", "mistral"],
					"createdAt": "2024-01-02T00:00:00.000Z",
					"updatedAt": "2024-01-02T00:00:00.000Z",
					"sha": "def456",
					"config": {},
					"model_index": "test-index-2",
					"library_name": "transformers",
					"mask_token": null,
					"tokenizer_class": "MistralTokenizer"
				}
			]`

			server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Verify request parameters
				Expect(r.URL.Query().Get("sort")).To(Equal("lastModified"))
				Expect(r.URL.Query().Get("direction")).To(Equal("-1"))
				Expect(r.URL.Query().Get("limit")).To(Equal("30"))
				Expect(r.URL.Query().Get("search")).To(Equal("GGUF"))

				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusOK)
				w.Write([]byte(mockResponse))
			}))

			// Override the client's base URL to use our mock server
			client.SetBaseURL(server.URL)
		})

		It("should successfully search for models", func() {
			params := hfapi.SearchParams{
				Sort:      "lastModified",
				Direction: -1,
				Limit:     30,
				Search:    "GGUF",
			}

			models, err := client.SearchModels(params)

			Expect(err).ToNot(HaveOccurred())
			Expect(models).To(HaveLen(2))

			// Verify first model
			Expect(models[0].ModelID).To(Equal("test-model-1"))
			Expect(models[0].Author).To(Equal("test-author"))
			Expect(models[0].Downloads).To(Equal(1000))
			Expect(models[0].PipelineTag).To(Equal("text-generation"))
			Expect(models[0].Private).To(BeFalse())
			Expect(models[0].Tags).To(ContainElements("gguf", "llama"))

			// Verify second model
			Expect(models[1].ModelID).To(Equal("test-model-2"))
			Expect(models[1].Author).To(Equal("test-author-2"))
			Expect(models[1].Downloads).To(Equal(2000))
			Expect(models[1].Tags).To(ContainElements("gguf", "mistral"))
		})

		It("should handle empty search results", func() {
			server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("[]"))
			}))

			client.SetBaseURL(server.URL)

			params := hfapi.SearchParams{
				Sort:      "lastModified",
				Direction: -1,
				Limit:     30,
				Search:    "nonexistent",
			}

			models, err := client.SearchModels(params)

			Expect(err).ToNot(HaveOccurred())
			Expect(models).To(HaveLen(0))
		})

		It("should handle HTTP errors", func() {
			server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusInternalServerError)
				w.Write([]byte("Internal Server Error"))
			}))

			client.SetBaseURL(server.URL)

			params := hfapi.SearchParams{
				Sort:      "lastModified",
				Direction: -1,
				Limit:     30,
				Search:    "GGUF",
			}

			models, err := client.SearchModels(params)

			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("Status code: 500"))
			Expect(models).To(BeNil())
		})

		It("should handle malformed JSON response", func() {
			server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("invalid json"))
			}))

			client.SetBaseURL(server.URL)

			params := hfapi.SearchParams{
				Sort:      "lastModified",
				Direction: -1,
				Limit:     30,
				Search:    "GGUF",
			}

			models, err := client.SearchModels(params)

			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("failed to parse JSON response"))
			Expect(models).To(BeNil())
		})
	})

	Context("when getting latest GGUF models", func() {
		BeforeEach(func() {
			mockResponse := `[
				{
					"modelId": "latest-gguf-model",
					"author": "gguf-author",
					"downloads": 5000,
					"lastModified": "2024-01-03T00:00:00.000Z",
					"pipelineTag": "text-generation",
					"private": false,
					"tags": ["gguf", "latest"],
					"createdAt": "2024-01-03T00:00:00.000Z",
					"updatedAt": "2024-01-03T00:00:00.000Z",
					"sha": "latest123",
					"config": {},
					"model_index": "latest-index",
					"library_name": "transformers",
					"mask_token": null,
					"tokenizer_class": "LlamaTokenizer"
				}
			]`

			server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Verify the search parameters are correct for GGUF search
				Expect(r.URL.Query().Get("search")).To(Equal("GGUF"))
				Expect(r.URL.Query().Get("sort")).To(Equal("lastModified"))
				Expect(r.URL.Query().Get("direction")).To(Equal("-1"))

				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusOK)
				w.Write([]byte(mockResponse))
			}))

			client.SetBaseURL(server.URL)
		})

		It("should fetch latest GGUF models with correct parameters", func() {
			models, err := client.GetLatest("GGUF", 10)

			Expect(err).ToNot(HaveOccurred())
			Expect(models).To(HaveLen(1))
			Expect(models[0].ModelID).To(Equal("latest-gguf-model"))
			Expect(models[0].Author).To(Equal("gguf-author"))
			Expect(models[0].Downloads).To(Equal(5000))
			Expect(models[0].Tags).To(ContainElements("gguf", "latest"))
		})

		It("should use custom search term", func() {
			server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				Expect(r.URL.Query().Get("search")).To(Equal("custom-search"))
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusOK)
				w.Write([]byte("[]"))
			}))

			client.SetBaseURL(server.URL)

			models, err := client.GetLatest("custom-search", 5)

			Expect(err).ToNot(HaveOccurred())
			Expect(models).To(HaveLen(0))
		})
	})

	Context("when handling network errors", func() {
		It("should handle connection failures gracefully", func() {
			// Use an invalid URL to simulate connection failure
			client.SetBaseURL("http://invalid-url-that-does-not-exist")

			params := hfapi.SearchParams{
				Sort:      "lastModified",
				Direction: -1,
				Limit:     30,
				Search:    "GGUF",
			}

			models, err := client.SearchModels(params)

			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("failed to make request"))
			Expect(models).To(BeNil())
		})
	})

	Context("when listing files", func() {
		BeforeEach(func() {
			mockFilesResponse := `[
```
|
||||
{
|
||||
"type": "file",
|
||||
"path": "model-Q4_K_M.gguf",
|
||||
"size": 1000000,
|
||||
"oid": "abc123",
|
||||
"lfs": {
|
||||
"oid": "def456789",
|
||||
"size": 1000000,
|
||||
"pointerSize": 135
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"path": "README.md",
|
||||
"size": 5000,
|
||||
"oid": "readme123"
|
||||
},
|
||||
{
|
||||
"type": "file",
|
||||
"path": "config.json",
|
||||
"size": 1000,
|
||||
"oid": "config123"
|
||||
}
|
||||
]`
|
||||
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/tree/main") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockFilesResponse))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
})
|
||||
|
||||
It("should list files successfully", func() {
|
||||
files, err := client.ListFiles("test/model")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(files).To(HaveLen(3))
|
||||
|
||||
Expect(files[0].Path).To(Equal("model-Q4_K_M.gguf"))
|
||||
Expect(files[0].Size).To(Equal(int64(1000000)))
|
||||
Expect(files[0].LFS).ToNot(BeNil())
|
||||
Expect(files[0].LFS.Oid).To(Equal("def456789"))
|
||||
|
||||
Expect(files[1].Path).To(Equal("README.md"))
|
||||
Expect(files[1].Size).To(Equal(int64(5000)))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when getting file SHA", func() {
|
||||
BeforeEach(func() {
|
||||
mockFileInfoResponse := `{
|
||||
"path": "model-Q4_K_M.gguf",
|
||||
"size": 1000000,
|
||||
"oid": "abc123",
|
||||
"lfs": {
|
||||
"oid": "sha256:def456",
|
||||
"size": 1000000,
|
||||
"pointer": "version https://git-lfs.github.com/spec/v1",
|
||||
"sha256": "def456789"
|
||||
}
|
||||
}`
|
||||
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/paths-info") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockFileInfoResponse))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
})
|
||||
|
||||
It("should get file SHA successfully", func() {
|
||||
sha, err := client.GetFileSHA("test/model", "model-Q4_K_M.gguf")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(sha).To(Equal("def456789"))
|
||||
})
|
||||
|
||||
It("should handle missing SHA gracefully", func() {
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"path": "file.txt", "size": 100}`))
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
|
||||
sha, err := client.GetFileSHA("test/model", "file.txt")
|
||||
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(ContainSubstring("no SHA256 found"))
|
||||
Expect(sha).To(Equal(""))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when getting model details", func() {
|
||||
BeforeEach(func() {
|
||||
mockFilesResponse := `[
|
||||
{
|
||||
"path": "model-Q4_K_M.gguf",
|
||||
"size": 1000000,
|
||||
"oid": "abc123",
|
||||
"lfs": {
|
||||
"oid": "sha256:def456",
|
||||
"size": 1000000,
|
||||
"pointer": "version https://git-lfs.github.com/spec/v1",
|
||||
"sha256": "def456789"
|
||||
}
|
||||
},
|
||||
{
|
||||
"path": "README.md",
|
||||
"size": 5000,
|
||||
"oid": "readme123"
|
||||
}
|
||||
]`
|
||||
|
||||
mockFileInfoResponse := `{
|
||||
"path": "model-Q4_K_M.gguf",
|
||||
"size": 1000000,
|
||||
"oid": "abc123",
|
||||
"lfs": {
|
||||
"oid": "sha256:def456",
|
||||
"size": 1000000,
|
||||
"pointer": "version https://git-lfs.github.com/spec/v1",
|
||||
"sha256": "def456789"
|
||||
}
|
||||
}`
|
||||
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/tree/main") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockFilesResponse))
|
||||
} else if strings.Contains(r.URL.Path, "/paths-info") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockFileInfoResponse))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
})
|
||||
|
||||
It("should get model details successfully", func() {
|
||||
details, err := client.GetModelDetails("test/model")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(details.ModelID).To(Equal("test/model"))
|
||||
Expect(details.Author).To(Equal("test"))
|
||||
Expect(details.Files).To(HaveLen(2))
|
||||
|
||||
Expect(details.ReadmeFile).ToNot(BeNil())
|
||||
Expect(details.ReadmeFile.Path).To(Equal("README.md"))
|
||||
Expect(details.ReadmeFile.IsReadme).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("when getting README content", func() {
|
||||
BeforeEach(func() {
|
||||
mockReadmeContent := "# Test Model\n\nThis is a test model for demonstration purposes."
|
||||
|
||||
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/raw/main/") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(mockReadmeContent))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
|
||||
client.SetBaseURL(server.URL)
|
||||
})
|
||||
|
||||
It("should get README content successfully", func() {
|
||||
content, err := client.GetReadmeContent("test/model", "README.md")
|
||||
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(content).To(Equal("# Test Model\n\nThis is a test model for demonstration purposes."))
|
||||
})
|
||||
})
|
||||
|
||||
Context("when filtering files", func() {
|
||||
It("should filter files by quantization", func() {
|
||||
files := []hfapi.ModelFile{
|
||||
{Path: "model-Q4_K_M.gguf"},
|
||||
{Path: "model-Q3_K_M.gguf"},
|
||||
{Path: "README.md", IsReadme: true},
|
||||
}
|
||||
|
||||
filtered := hfapi.FilterFilesByQuantization(files, "Q4_K_M")
|
||||
|
||||
Expect(filtered).To(HaveLen(1))
|
||||
Expect(filtered[0].Path).To(Equal("model-Q4_K_M.gguf"))
|
||||
})
|
||||
|
||||
It("should find preferred model file", func() {
|
||||
files := []hfapi.ModelFile{
|
||||
{Path: "model-Q3_K_M.gguf"},
|
||||
{Path: "model-Q4_K_M.gguf"},
|
||||
{Path: "README.md", IsReadme: true},
|
||||
}
|
||||
|
||||
preferences := []string{"Q4_K_M", "Q3_K_M"}
|
||||
preferred := hfapi.FindPreferredModelFile(files, preferences)
|
||||
|
||||
Expect(preferred).ToNot(BeNil())
|
||||
Expect(preferred.Path).To(Equal("model-Q4_K_M.gguf"))
|
||||
Expect(preferred.IsReadme).To(BeFalse())
|
||||
})
|
||||
|
||||
It("should return nil if no preferred file found", func() {
|
||||
files := []hfapi.ModelFile{
|
||||
{Path: "model-Q2_K.gguf"},
|
||||
{Path: "README.md", IsReadme: true},
|
||||
}
|
||||
|
||||
preferences := []string{"Q4_K_M", "Q3_K_M"}
|
||||
preferred := hfapi.FindPreferredModelFile(files, preferences)
|
||||
|
||||
Expect(preferred).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
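For orientation, here is the client surface these specs exercise. The hfapi client implementation itself is not part of this diff excerpt, so the struct shapes below are inferred from the assertions above and should be read as a sketch, not the actual source:

```go
// Sketch only: inferred from the test assertions, not the real hfapi/client code.
package hfapi

// SearchParams mirrors the query parameters the tests expect on the wire
// (sort, direction, limit, search).
type SearchParams struct {
	Sort      string
	Direction int
	Limit     int
	Search    string
}

// Methods the suite calls on *Client:
//   NewClient() *Client
//   (c *Client) SetBaseURL(url string)
//   (c *Client) SearchModels(p SearchParams) ([]Model, error)
//   (c *Client) GetLatest(search string, limit int) ([]Model, error)
//   (c *Client) ListFiles(repo string) ([]ModelFile, error)
//   (c *Client) GetFileSHA(repo, path string) (string, error)
//   (c *Client) GetModelDetails(repo string) (*ModelDetails, error)
//   (c *Client) GetReadmeContent(repo, path string) (string, error)
//
// Package-level helpers the suite calls:
//   FilterFilesByQuantization(files []ModelFile, quant string) []ModelFile
//   FindPreferredModelFile(files []ModelFile, preferences []string) *ModelFile
```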
.github/gallery-agent/hfapi/hfapi_suite_test.go (vendored, new file, 13 lines)
@@ -0,0 +1,13 @@
package hfapi_test

import (
	"testing"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

func TestHfapi(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "HuggingFace API Suite")
}
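Since `RunSpecs` hooks Ginkgo into the standard testing harness, the suite runs with plain `go test`; for example:

```bash
go test ./.github/gallery-agent/hfapi/...
```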
.github/gallery-agent/main.go (vendored, new file, 351 lines)
@@ -0,0 +1,351 @@
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
)

// ProcessedModelFile represents a processed model file with additional metadata
type ProcessedModelFile struct {
	Path     string `json:"path"`
	Size     int64  `json:"size"`
	SHA256   string `json:"sha256"`
	IsReadme bool   `json:"is_readme"`
	FileType string `json:"file_type"` // "model", "readme", "other"
}

// ProcessedModel represents a processed model with all gathered metadata
type ProcessedModel struct {
	ModelID                 string               `json:"model_id"`
	Author                  string               `json:"author"`
	Downloads               int                  `json:"downloads"`
	LastModified            string               `json:"last_modified"`
	Files                   []ProcessedModelFile `json:"files"`
	PreferredModelFile      *ProcessedModelFile  `json:"preferred_model_file,omitempty"`
	ReadmeFile              *ProcessedModelFile  `json:"readme_file,omitempty"`
	ReadmeContent           string               `json:"readme_content,omitempty"`
	ReadmeContentPreview    string               `json:"readme_content_preview,omitempty"`
	QuantizationPreferences []string             `json:"quantization_preferences"`
	ProcessingError         string               `json:"processing_error,omitempty"`
}

// SearchResult represents the complete result of searching and processing models
type SearchResult struct {
	SearchTerm       string           `json:"search_term"`
	Limit            int              `json:"limit"`
	Quantization     string           `json:"quantization"`
	TotalModelsFound int              `json:"total_models_found"`
	Models           []ProcessedModel `json:"models"`
	FormattedOutput  string           `json:"formatted_output"`
}

// AddedModelSummary represents a summary of models added to the gallery
type AddedModelSummary struct {
	SearchTerm     string   `json:"search_term"`
	TotalFound     int      `json:"total_found"`
	ModelsAdded    int      `json:"models_added"`
	AddedModelIDs  []string `json:"added_model_ids"`
	AddedModelURLs []string `json:"added_model_urls"`
	Quantization   string   `json:"quantization"`
	ProcessingTime string   `json:"processing_time"`
}

func main() {
	startTime := time.Now()

	// Check for synthetic mode
	syntheticMode := os.Getenv("SYNTHETIC_MODE")
	if syntheticMode == "true" || syntheticMode == "1" {
		fmt.Println("Running in SYNTHETIC MODE - generating random test data")
		err := runSyntheticMode()
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error in synthetic mode: %v\n", err)
			os.Exit(1)
		}
		return
	}

	// Get configuration from environment variables
	searchTerm := os.Getenv("SEARCH_TERM")
	if searchTerm == "" {
		searchTerm = "GGUF"
	}

	limitStr := os.Getenv("LIMIT")
	if limitStr == "" {
		limitStr = "5"
	}
	limit, err := strconv.Atoi(limitStr)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error parsing LIMIT: %v\n", err)
		os.Exit(1)
	}

	quantization := os.Getenv("QUANTIZATION")

	maxModels := os.Getenv("MAX_MODELS")
	if maxModels == "" {
		maxModels = "1"
	}
	maxModelsInt, err := strconv.Atoi(maxModels)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error parsing MAX_MODELS: %v\n", err)
		os.Exit(1)
	}

	// Print configuration
	fmt.Printf("Gallery Agent Configuration:\n")
	fmt.Printf(" Search Term: %s\n", searchTerm)
	fmt.Printf(" Limit: %d\n", limit)
	fmt.Printf(" Quantization: %s\n", quantization)
	fmt.Printf(" Max Models to Add: %d\n", maxModelsInt)
	fmt.Printf(" Gallery Index Path: %s\n", os.Getenv("GALLERY_INDEX_PATH"))
	fmt.Println()

	result, err := searchAndProcessModels(searchTerm, limit, quantization)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		os.Exit(1)
	}

	fmt.Println(result.FormattedOutput)

	// Use AI agent to select the most interesting models
	fmt.Println("Using AI agent to select the most interesting models...")
	models, err := selectMostInterestingModels(context.Background(), result)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error in model selection: %v\n", err)
		// Continue with original result if selection fails
		models = result.Models
	}

	fmt.Print(models)

	// Filter out models that already exist in the gallery
	fmt.Println("Filtering out existing models...")
	models, err = filterExistingModels(models)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error filtering existing models: %v\n", err)
		os.Exit(1)
	}

	// Limit to maxModelsInt after filtering
	if len(models) > maxModelsInt {
		models = models[:maxModelsInt]
	}

	// Track added models for summary
	var addedModelIDs []string
	var addedModelURLs []string

	// Generate YAML entries and append to gallery/index.yaml
	if len(models) > 0 {
		for _, model := range models {
			addedModelIDs = append(addedModelIDs, model.ModelID)
			// Generate Hugging Face URL for the model
			modelURL := fmt.Sprintf("https://huggingface.co/%s", model.ModelID)
			addedModelURLs = append(addedModelURLs, modelURL)
		}
		fmt.Println("Generating YAML entries for selected models...")
		err = generateYAMLForModels(context.Background(), models)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error generating YAML entries: %v\n", err)
			os.Exit(1)
		}
	} else {
		fmt.Println("No new models to add to the gallery.")
	}

	// Create and write summary
	processingTime := time.Since(startTime).String()
	summary := AddedModelSummary{
		SearchTerm:     searchTerm,
		TotalFound:     result.TotalModelsFound,
		ModelsAdded:    len(addedModelIDs),
		AddedModelIDs:  addedModelIDs,
		AddedModelURLs: addedModelURLs,
		Quantization:   quantization,
		ProcessingTime: processingTime,
	}

	// Write summary to file
	summaryData, err := json.MarshalIndent(summary, "", " ")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error marshaling summary: %v\n", err)
	} else {
		err = os.WriteFile("gallery-agent-summary.json", summaryData, 0644)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error writing summary file: %v\n", err)
		} else {
			fmt.Printf("Summary written to gallery-agent-summary.json\n")
		}
	}
}

func searchAndProcessModels(searchTerm string, limit int, quantization string) (*SearchResult, error) {
	client := hfapi.NewClient()
	var outputBuilder strings.Builder

	fmt.Println("Searching for models...")
	// Initialize the result struct
	result := &SearchResult{
		SearchTerm:   searchTerm,
		Limit:        limit,
		Quantization: quantization,
		Models:       []ProcessedModel{},
	}

	models, err := client.GetLatest(searchTerm, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch models: %w", err)
	}

	fmt.Println("Models found:", len(models))
	result.TotalModelsFound = len(models)

	if len(models) == 0 {
		outputBuilder.WriteString("No models found.\n")
		result.FormattedOutput = outputBuilder.String()
		return result, nil
	}

	outputBuilder.WriteString(fmt.Sprintf("Found %d models matching '%s':\n\n", len(models), searchTerm))

	// Process each model
	for i, model := range models {
		outputBuilder.WriteString(fmt.Sprintf("%d. Processing Model: %s\n", i+1, model.ModelID))
		outputBuilder.WriteString(fmt.Sprintf(" Author: %s\n", model.Author))
		outputBuilder.WriteString(fmt.Sprintf(" Downloads: %d\n", model.Downloads))
		outputBuilder.WriteString(fmt.Sprintf(" Last Modified: %s\n", model.LastModified))

		// Initialize processed model struct
		processedModel := ProcessedModel{
			ModelID:                 model.ModelID,
			Author:                  model.Author,
			Downloads:               model.Downloads,
			LastModified:            model.LastModified,
			QuantizationPreferences: []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"},
		}

		// Get detailed model information
		details, err := client.GetModelDetails(model.ModelID)
		if err != nil {
			errorMsg := fmt.Sprintf(" Error getting model details: %v\n", err)
			outputBuilder.WriteString(errorMsg)
			processedModel.ProcessingError = err.Error()
			result.Models = append(result.Models, processedModel)
			continue
		}

		// Define quantization preferences (in order of preference)
		quantizationPreferences := []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"}

		// Find preferred model file
		preferredModelFile := hfapi.FindPreferredModelFile(details.Files, quantizationPreferences)

		// Process files
		processedFiles := make([]ProcessedModelFile, len(details.Files))
		for j, file := range details.Files {
			fileType := "other"
			if file.IsReadme {
				fileType = "readme"
			} else if preferredModelFile != nil && file.Path == preferredModelFile.Path {
				fileType = "model"
			}

			processedFiles[j] = ProcessedModelFile{
				Path:     file.Path,
				Size:     file.Size,
				SHA256:   file.SHA256,
				IsReadme: file.IsReadme,
				FileType: fileType,
			}
		}

		processedModel.Files = processedFiles

		// Set preferred model file
		if preferredModelFile != nil {
			for _, file := range processedFiles {
				if file.Path == preferredModelFile.Path {
					processedModel.PreferredModelFile = &file
					break
				}
			}
		}

		// Print file information
		outputBuilder.WriteString(fmt.Sprintf(" Files found: %d\n", len(details.Files)))

		if preferredModelFile != nil {
			outputBuilder.WriteString(fmt.Sprintf(" Preferred Model File: %s (SHA256: %s)\n",
				preferredModelFile.Path,
				preferredModelFile.SHA256))
		} else {
			outputBuilder.WriteString(fmt.Sprintf(" No model file found with quantization preferences: %v\n", quantizationPreferences))
		}

		if details.ReadmeFile != nil {
			outputBuilder.WriteString(fmt.Sprintf(" README File: %s\n", details.ReadmeFile.Path))

			// Find and set readme file
			for _, file := range processedFiles {
				if file.IsReadme {
					processedModel.ReadmeFile = &file
					break
				}
			}

			fmt.Println("Getting real readme for", model.ModelID, "waiting...")
			// Use agent to get the real readme and prepare the model description
			readmeContent, err := getRealReadme(context.Background(), model.ModelID)
			if err == nil {
				processedModel.ReadmeContent = readmeContent
				processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
				outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
					processedModel.ReadmeContentPreview))
			} else {
				continue
			}
			fmt.Println("Fetched real readme:", readmeContent)
			// Get README content
			// readmeContent, err := client.GetReadmeContent(model.ModelID, details.ReadmeFile.Path)
			// if err == nil {
			// 	processedModel.ReadmeContent = readmeContent
			// 	processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
			// 	outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
			// 		processedModel.ReadmeContentPreview))
			// }
		}

		// Print all files with their checksums
		outputBuilder.WriteString(" All Files:\n")
		for _, file := range processedFiles {
			outputBuilder.WriteString(fmt.Sprintf(" - %s (%s, %d bytes", file.Path, file.FileType, file.Size))
			if file.SHA256 != "" {
				outputBuilder.WriteString(fmt.Sprintf(", SHA256: %s", file.SHA256))
			}
			outputBuilder.WriteString(")\n")
		}

		outputBuilder.WriteString("\n")
		result.Models = append(result.Models, processedModel)
	}

	result.FormattedOutput = outputBuilder.String()
	return result, nil
}

func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	return s[:maxLen] + "..."
}
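Given the `AddedModelSummary` struct and its JSON tags, the gallery-agent-summary.json file written above would look roughly like this (illustrative values only; the model ID and timings are made up):

```json
{
  "search_term": "GGUF",
  "total_found": 15,
  "models_added": 1,
  "added_model_ids": ["example-author/example-model-GGUF"],
  "added_model_urls": ["https://huggingface.co/example-author/example-model-GGUF"],
  "quantization": "Q4_K_M",
  "processing_time": "1m32.5s"
}
```

This is the file the gallery-agent workflow later parses with jq to build the pull-request description.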
.github/gallery-agent/testing.go (vendored, new file, 190 lines)
@@ -0,0 +1,190 @@
package main

import (
	"context"
	"fmt"
	"math/rand"
	"strings"
	"time"
)

// runSyntheticMode generates synthetic test data and appends it to the gallery
func runSyntheticMode() error {
	generator := NewSyntheticDataGenerator()

	// Generate a random number of synthetic models (1-3)
	numModels := generator.rand.Intn(3) + 1
	fmt.Printf("Generating %d synthetic models for testing...\n", numModels)

	var models []ProcessedModel
	for i := 0; i < numModels; i++ {
		model := generator.GenerateProcessedModel()
		models = append(models, model)
		fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
	}

	// Generate YAML entries and append to gallery/index.yaml
	fmt.Println("Generating YAML entries for synthetic models...")
	err := generateYAMLForModels(context.Background(), models)
	if err != nil {
		return fmt.Errorf("error generating YAML entries: %w", err)
	}

	fmt.Printf("Successfully added %d synthetic models to the gallery for testing!\n", len(models))
	return nil
}

// SyntheticDataGenerator provides methods to generate synthetic test data
type SyntheticDataGenerator struct {
	rand *rand.Rand
}

// NewSyntheticDataGenerator creates a new synthetic data generator
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
	return &SyntheticDataGenerator{
		rand: rand.New(rand.NewSource(time.Now().UnixNano())),
	}
}

// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
	fileTypes := []string{"model", "readme", "other"}
	fileType := fileTypes[g.rand.Intn(len(fileTypes))]

	var path string
	var isReadme bool

	switch fileType {
	case "model":
		path = fmt.Sprintf("model-%s.gguf", g.randomString(8))
		isReadme = false
	case "readme":
		path = "README.md"
		isReadme = true
	default:
		path = fmt.Sprintf("file-%s.txt", g.randomString(6))
		isReadme = false
	}

	return ProcessedModelFile{
		Path:     path,
		Size:     int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB
		SHA256:   g.randomSHA256(),
		IsReadme: isReadme,
		FileType: fileType,
	}
}

// GenerateProcessedModel creates a synthetic ProcessedModel
func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
	authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
	modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}

	author := authors[g.rand.Intn(len(authors))]
	modelName := modelNames[g.rand.Intn(len(modelNames))]
	modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))

	// Generate files
	numFiles := g.rand.Intn(5) + 2 // 2-6 files
	files := make([]ProcessedModelFile, numFiles)

	// Ensure at least one model file and one readme
	hasModelFile := false
	hasReadme := false

	for i := 0; i < numFiles; i++ {
		files[i] = g.GenerateProcessedModelFile()
		if files[i].FileType == "model" {
			hasModelFile = true
		}
		if files[i].FileType == "readme" {
			hasReadme = true
		}
	}

	// Add required files if missing
	if !hasModelFile {
		modelFile := g.GenerateProcessedModelFile()
		modelFile.FileType = "model"
		modelFile.Path = fmt.Sprintf("%s-Q4_K_M.gguf", modelName)
		files = append(files, modelFile)
	}

	if !hasReadme {
		readmeFile := g.GenerateProcessedModelFile()
		readmeFile.FileType = "readme"
		readmeFile.Path = "README.md"
		readmeFile.IsReadme = true
		files = append(files, readmeFile)
	}

	// Find preferred model file
	var preferredModelFile *ProcessedModelFile
	for i := range files {
		if files[i].FileType == "model" {
			preferredModelFile = &files[i]
			break
		}
	}

	// Find readme file
	var readmeFile *ProcessedModelFile
	for i := range files {
		if files[i].FileType == "readme" {
			readmeFile = &files[i]
			break
		}
	}

	readmeContent := g.generateReadmeContent(modelName, author)

	return ProcessedModel{
		ModelID:                 modelID,
		Author:                  author,
		Downloads:               g.rand.Intn(1000000) + 1000,
		LastModified:            g.randomDate(),
		Files:                   files,
		PreferredModelFile:      preferredModelFile,
		ReadmeFile:              readmeFile,
		ReadmeContent:           readmeContent,
		ReadmeContentPreview:    truncateString(readmeContent, 200),
		QuantizationPreferences: []string{"Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"},
		ProcessingError:         "",
	}
}

// Helper methods for synthetic data generation
func (g *SyntheticDataGenerator) randomString(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
	b := make([]byte, length)
	for i := range b {
		b[i] = charset[g.rand.Intn(len(charset))]
	}
	return string(b)
}

func (g *SyntheticDataGenerator) randomSHA256() string {
	const charset = "0123456789abcdef"
	b := make([]byte, 64)
	for i := range b {
		b[i] = charset[g.rand.Intn(len(charset))]
	}
	return string(b)
}

func (g *SyntheticDataGenerator) randomDate() string {
	now := time.Now()
	daysAgo := g.rand.Intn(365) // Random date within last year
	pastDate := now.AddDate(0, 0, -daysAgo)
	return pastDate.Format("2006-01-02T15:04:05.000Z")
}

func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string) string {
	templates := []string{
		fmt.Sprintf("# %s Model\n\nThis is a %s model developed by %s. It's designed for various natural language processing tasks including text generation, question answering, and conversation.\n\n## Features\n\n- High-quality text generation\n- Efficient inference\n- Multiple quantization options\n- Easy to use with LocalAI\n\n## Usage\n\nUse this model with LocalAI for various AI tasks.", strings.Title(modelName), modelName, author),
		fmt.Sprintf("# %s\n\nA powerful language model from %s. This model excels at understanding and generating human-like text across multiple domains.\n\n## Capabilities\n\n- Text completion\n- Code generation\n- Creative writing\n- Technical documentation\n\n## Model Details\n\n- Architecture: Transformer-based\n- Training: Large-scale supervised learning\n- Quantization: Available in multiple formats", strings.Title(modelName), author),
		fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
	}

	return templates[g.rand.Intn(len(templates))]
}
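Synthetic mode makes it possible to exercise the YAML-generation path locally without hitting the Hugging Face API or an LLM. A plausible local invocation, based on the environment variables `main.go` reads (the relative index path here is an assumption about your checkout layout):

```bash
cd .github/gallery-agent
go build -o gallery-agent .
# SYNTHETIC_MODE may be "true" or "1"; GALLERY_INDEX_PATH should point at gallery/index.yaml
SYNTHETIC_MODE=true GALLERY_INDEX_PATH=$PWD/../../gallery/index.yaml ./gallery-agent
```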
.github/gallery-agent/tools.go (vendored, new file, 46 lines)
@@ -0,0 +1,46 @@
package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
	"github.com/sashabaranov/go-openai"
	"github.com/tmc/langchaingo/jsonschema"
)

// Get repository README from HF
type HFReadmeTool struct {
	client *hfapi.Client
}

func (s *HFReadmeTool) Run(args map[string]any) (string, error) {
	q, ok := args["repository"].(string)
	if !ok {
		return "", fmt.Errorf("no query")
	}
	readme, err := s.client.GetReadmeContent(q, "README.md")
	if err != nil {
		return "", err
	}
	return readme, nil
}

func (s *HFReadmeTool) Tool() openai.Tool {
	return openai.Tool{
		Type: openai.ToolTypeFunction,
		Function: &openai.FunctionDefinition{
			Name:        "hf_readme",
			Description: "A tool to get the README content of a huggingface repository",
			Parameters: jsonschema.Definition{
				Type: jsonschema.Object,
				Properties: map[string]jsonschema.Definition{
					"repository": {
						Type:        jsonschema.String,
						Description: "The huggingface repository to get the README content of",
					},
				},
				Required: []string{"repository"},
			},
		},
	}
}
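tools.go only declares the tool; the loop that feeds model tool calls back into `Run` lives in agent code that is not part of this excerpt. A minimal sketch of how such a call could be routed with go-openai, assuming the same package — hypothetical wiring, not the repository's actual dispatcher:

```go
package main

import (
	"encoding/json"
	"fmt"

	"github.com/sashabaranov/go-openai"
)

// dispatchToolCall is a hypothetical helper: it unmarshals the JSON
// arguments from a model tool call and routes it to HFReadmeTool.Run.
func dispatchToolCall(tool *HFReadmeTool, call openai.ToolCall) (string, error) {
	if call.Function.Name != "hf_readme" {
		return "", fmt.Errorf("unknown tool: %s", call.Function.Name)
	}
	var args map[string]any
	if err := json.Unmarshal([]byte(call.Function.Arguments), &args); err != nil {
		return "", err
	}
	return tool.Run(args)
}
```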
.github/workflows/backend.yml (vendored, 83 changed lines)
@@ -489,6 +489,18 @@ jobs:
            backend: "diffusers"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
          - build-type: 'l4t'
            cuda-major-version: "12"
            cuda-minor-version: "0"
            platforms: 'linux/arm64'
            tag-latest: 'auto'
            tag-suffix: '-gpu-nvidia-l4t-kokoro'
            runs-on: 'ubuntu-24.04-arm'
            base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
            skip-drivers: 'true'
            backend: "kokoro"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
          # SYCL additional backends
          - build-type: 'intel'
            cuda-major-version: ""
@@ -870,7 +882,7 @@ jobs:
            backend: "rfdetr"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
-          - build-type: 'cublas'
+          - build-type: 'l4t'
            cuda-major-version: "12"
            cuda-minor-version: "0"
            platforms: 'linux/arm64'
@@ -943,6 +955,18 @@
            backend: "exllama2"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
          - build-type: 'l4t'
            cuda-major-version: "12"
            cuda-minor-version: "0"
            platforms: 'linux/arm64'
            skip-drivers: 'true'
            tag-latest: 'auto'
            tag-suffix: '-gpu-nvidia-l4t-arm64-chatterbox'
            base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
            runs-on: 'ubuntu-24.04-arm'
            backend: "chatterbox"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
          # runs out of space on the runner
          # - build-type: 'hipblas'
          #   cuda-major-version: ""
@@ -969,6 +993,55 @@
            backend: "kitten-tts"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
          # neutts
          - build-type: ''
            cuda-major-version: ""
            cuda-minor-version: ""
            platforms: 'linux/amd64,linux/arm64'
            tag-latest: 'auto'
            tag-suffix: '-cpu-neutts'
            runs-on: 'ubuntu-latest'
            base-image: "ubuntu:22.04"
            skip-drivers: 'false'
            backend: "neutts"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
          - build-type: 'cublas'
            cuda-major-version: "12"
            cuda-minor-version: "0"
            platforms: 'linux/amd64'
            tag-latest: 'auto'
            tag-suffix: '-gpu-nvidia-cuda-12-neutts'
            runs-on: 'ubuntu-latest'
            base-image: "ubuntu:22.04"
            skip-drivers: 'false'
            backend: "neutts"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
          - build-type: 'hipblas'
            cuda-major-version: ""
            cuda-minor-version: ""
            platforms: 'linux/amd64'
            tag-latest: 'auto'
            tag-suffix: '-gpu-rocm-hipblas-neutts'
            runs-on: 'arc-runner-set'
            base-image: "rocm/dev-ubuntu-22.04:6.4.3"
            skip-drivers: 'false'
            backend: "neutts"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
          - build-type: 'l4t'
            cuda-major-version: "12"
            cuda-minor-version: "0"
            platforms: 'linux/arm64'
            skip-drivers: 'true'
            tag-latest: 'auto'
            tag-suffix: '-nvidia-l4t-arm64-neutts'
            base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
            runs-on: 'ubuntu-24.04-arm'
            backend: "neutts"
            dockerfile: "./backend/Dockerfile.python"
            context: "./backend"
  backend-jobs-darwin:
    uses: ./.github/workflows/backend_build_darwin.yml
    strategy:
@@ -1036,7 +1109,7 @@ jobs:
          make protogen-go
          make backends/llama-cpp-darwin
      - name: Upload llama-cpp.tar
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
        with:
          name: llama-cpp-tar
          path: backend-images/llama-cpp.tar
@@ -1046,7 +1119,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Download llama-cpp.tar
-        uses: actions/download-artifact@v5
+        uses: actions/download-artifact@v6
        with:
          name: llama-cpp-tar
          path: .
@@ -1124,7 +1197,7 @@ jobs:
          export PLATFORMARCH=darwin/amd64
          make backends/llama-cpp-darwin
      - name: Upload llama-cpp.tar
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
        with:
          name: llama-cpp-tar-x86
          path: backend-images/llama-cpp.tar
@@ -1134,7 +1207,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Download llama-cpp.tar
-        uses: actions/download-artifact@v5
+        uses: actions/download-artifact@v6
        with:
          name: llama-cpp-tar-x86
          path: .
.github/workflows/backend_build_darwin.yml (vendored, 4 changed lines)
@@ -74,7 +74,7 @@ jobs:
          BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-${{ inputs.lang }}-backend

      - name: Upload ${{ inputs.backend }}.tar
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
        with:
          name: ${{ inputs.backend }}-tar
          path: backend-images/${{ inputs.backend }}.tar
@@ -85,7 +85,7 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Download ${{ inputs.backend }}.tar
-        uses: actions/download-artifact@v5
+        uses: actions/download-artifact@v6
        with:
          name: ${{ inputs.backend }}-tar
          path: .
.github/workflows/build-test.yaml (vendored, 10 changed lines)
@@ -17,7 +17,7 @@ jobs:
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
-          go-version: 1.23
+          go-version: 1.25
      - name: Run GoReleaser
        run: |
          make dev-dist
@@ -31,13 +31,13 @@ jobs:
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
-          go-version: 1.23
+          go-version: 1.25
      - name: Build launcher for macOS ARM64
        run: |
          make build-launcher-darwin
          ls -liah dist
      - name: Upload macOS launcher artifacts
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
        with:
          name: launcher-macos
          path: dist/
@@ -53,14 +53,14 @@ jobs:
      - name: Set up Go
        uses: actions/setup-go@v5
        with:
-          go-version: 1.23
+          go-version: 1.25
      - name: Build launcher for Linux
        run: |
          sudo apt-get update
          sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
          make build-launcher-linux
      - name: Upload Linux launcher artifacts
-        uses: actions/upload-artifact@v4
+        uses: actions/upload-artifact@v5
        with:
          name: launcher-linux
          path: local-ai-launcher-linux.tar.xz
.github/workflows/gallery-agent.yaml (vendored, new file, 126 lines)
@@ -0,0 +1,126 @@
name: Gallery Agent
on:

  schedule:
    - cron: '0 */1 * * *' # Run every hour
  workflow_dispatch:
    inputs:
      search_term:
        description: 'Search term for models'
        required: false
        default: 'GGUF'
        type: string
      limit:
        description: 'Maximum number of models to process'
        required: false
        default: '15'
        type: string
      quantization:
        description: 'Preferred quantization format'
        required: false
        default: 'Q4_K_M'
        type: string
      max_models:
        description: 'Maximum number of models to add to the gallery'
        required: false
        default: '1'
        type: string
jobs:
  gallery-agent:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v5
        with:
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Set up Go
        uses: actions/setup-go@v5
        with:
          go-version: '1.21'

      - name: Build gallery agent
        run: |
          cd .github/gallery-agent
          go mod download
          go build -o gallery-agent .

      - name: Run gallery agent
        env:
          OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
          OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
          OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
          SEARCH_TERM: ${{ github.event.inputs.search_term || 'GGUF' }}
          LIMIT: ${{ github.event.inputs.limit || '15' }}
          QUANTIZATION: ${{ github.event.inputs.quantization || 'Q4_K_M' }}
          MAX_MODELS: ${{ github.event.inputs.max_models || '1' }}
        run: |
          export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml
          cd .github/gallery-agent
          ./gallery-agent
          rm -rf gallery-agent

      - name: Check for changes
        id: check_changes
        run: |
          if git diff --quiet gallery/index.yaml; then
            echo "changes=false" >> $GITHUB_OUTPUT
            echo "No changes detected in gallery/index.yaml"
          else
            echo "changes=true" >> $GITHUB_OUTPUT
            echo "Changes detected in gallery/index.yaml"
            git diff gallery/index.yaml
          fi

      - name: Read gallery agent summary
        id: read_summary
        if: steps.check_changes.outputs.changes == 'true'
        run: |
          if [ -f ".github/gallery-agent/gallery-agent-summary.json" ]; then
            echo "summary_exists=true" >> $GITHUB_OUTPUT
            # Extract summary data using jq
            echo "search_term=$(jq -r '.search_term' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT
            echo "total_found=$(jq -r '.total_found' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT
            echo "models_added=$(jq -r '.models_added' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT
            echo "quantization=$(jq -r '.quantization' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT
            echo "processing_time=$(jq -r '.processing_time' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT

            # Create a formatted list of added models with URLs
            added_models=$(jq -r 'range(0; .added_model_ids | length) as $i | "- [\(.added_model_ids[$i])](\(.added_model_urls[$i]))"' .github/gallery-agent/gallery-agent-summary.json | tr '\n' '\n')
            echo "added_models<<EOF" >> $GITHUB_OUTPUT
            echo "$added_models" >> $GITHUB_OUTPUT
            echo "EOF" >> $GITHUB_OUTPUT
            rm -f .github/gallery-agent/gallery-agent-summary.json
          else
            echo "summary_exists=false" >> $GITHUB_OUTPUT
          fi

      - name: Create Pull Request
        if: steps.check_changes.outputs.changes == 'true'
        uses: peter-evans/create-pull-request@v7
        with:
          token: ${{ secrets.UPDATE_BOT_TOKEN }}
          push-to-fork: ci-forks/LocalAI
          commit-message: 'chore(model gallery): :robot: add new models via gallery agent'
          title: 'chore(model gallery): :robot: add ${{ steps.read_summary.outputs.models_added || 0 }} new models via gallery agent'
          # The branch has to be unique so PRs don't override each other
          branch-suffix: timestamp
          body: |
            This PR was automatically created by the gallery agent workflow.

            **Summary:**
            - **Search Term:** ${{ steps.read_summary.outputs.search_term || github.event.inputs.search_term || 'GGUF' }}
            - **Models Found:** ${{ steps.read_summary.outputs.total_found || 'N/A' }}
            - **Models Added:** ${{ steps.read_summary.outputs.models_added || '0' }}
            - **Quantization:** ${{ steps.read_summary.outputs.quantization || github.event.inputs.quantization || 'Q4_K_M' }}
            - **Processing Time:** ${{ steps.read_summary.outputs.processing_time || 'N/A' }}

            **Added Models:**
            ${{ steps.read_summary.outputs.added_models || '- No models added' }}

            **Workflow Details:**
            - Triggered by: `${{ github.event_name }}`
            - Run ID: `${{ github.run_id }}`
            - Commit: `${{ github.sha }}`
          signoff: true
          delete-branch: true
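Besides the hourly schedule, the workflow can be kicked off by hand with the `workflow_dispatch` inputs defined above; for example, with the GitHub CLI (assuming `gh` is installed and authenticated):

```bash
gh workflow run gallery-agent.yaml \
  -f search_term=GGUF -f limit=10 -f quantization=Q4_K_M -f max_models=2
```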
.github/workflows/localaibot_automerge.yml (vendored, 2 changed lines)
@@ -11,7 +11,7 @@ permissions:
jobs:
  dependabot:
    runs-on: ubuntu-latest
-    if: ${{ github.actor == 'localai-bot' }}
+    if: ${{ github.actor == 'localai-bot' && !contains(github.event.pull_request.title, 'chore(model gallery):') }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v5
.github/workflows/notify-models.yaml (vendored, 18 changed lines)
@@ -1,22 +1,27 @@
name: Notifications for new models
on:
-  pull_request:
+  pull_request_target:
    types:
      - closed

permissions:
  contents: read
  pull-requests: read

jobs:
  notify-discord:
    if: ${{ (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) }}
    env:
-      MODEL_NAME: gemma-3-12b-it
+      MODEL_NAME: gemma-3-12b-it-qat
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v5
        with:
          fetch-depth: 0 # needed to checkout all branches for this Action to work
          ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes
      - uses: mudler/localai-github-action@v1
        with:
-          model: 'gemma-3-12b-it' # Any from models.localai.io, or from huggingface.com with: "huggingface://<repository>/file"
+          model: 'gemma-3-12b-it-qat' # Any from models.localai.io, or from huggingface.com with: "huggingface://<repository>/file"
      # Check the PR diff using the current branch and the base branch of the PR
      - uses: GrantBirki/git-diff-action@v2.8.1
        id: git-diff-action
@@ -79,7 +84,7 @@ jobs:
          args: ${{ steps.summarize.outputs.message }}
      - name: Setup tmate session if fails
        if: ${{ failure() }}
-        uses: mxschmitt/action-tmate@v3.22
+        uses: mxschmitt/action-tmate@v3.23
        with:
          detached: true
          connect-timeout-seconds: 180
@@ -87,12 +92,13 @@ jobs:
  notify-twitter:
    if: ${{ (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) }}
    env:
-      MODEL_NAME: gemma-3-12b-it
+      MODEL_NAME: gemma-3-12b-it-qat
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v5
        with:
          fetch-depth: 0 # needed to checkout all branches for this Action to work
          ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes
      - name: Start LocalAI
        run: |
          echo "Starting LocalAI..."
@@ -161,7 +167,7 @@ jobs:
          TWITTER_ACCESS_TOKEN_SECRET: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }}
      - name: Setup tmate session if fails
        if: ${{ failure() }}
-        uses: mxschmitt/action-tmate@v3.22
+        uses: mxschmitt/action-tmate@v3.23
        with:
          detached: true
          connect-timeout-seconds: 180
.github/workflows/notify-releases.yaml (vendored, 3 changed lines)
@@ -11,10 +11,11 @@ jobs:
      RELEASE_BODY: ${{ github.event.release.body }}
      RELEASE_TITLE: ${{ github.event.release.name }}
      RELEASE_TAG_NAME: ${{ github.event.release.tag_name }}
      MODEL_NAME: gemma-3-12b-it-qat
    steps:
      - uses: mudler/localai-github-action@v1
        with:
-          model: 'gemma-3-12b-it' # Any from models.localai.io, or from huggingface.com with: "huggingface://<repository>/file"
+          model: 'gemma-3-12b-it-qat' # Any from models.localai.io, or from huggingface.com with: "huggingface://<repository>/file"
      - name: Summarize
        id: summarize
        run: |
.github/workflows/secscan.yaml (vendored, 4 changed lines)
@@ -18,13 +18,13 @@ jobs:
        if: ${{ github.actor != 'dependabot[bot]' }}
      - name: Run Gosec Security Scanner
        if: ${{ github.actor != 'dependabot[bot]' }}
-        uses: securego/gosec@v2.22.8
+        uses: securego/gosec@v2.22.9
        with:
          # we let the report content trigger a failure using the GitHub Security features.
          args: '-no-fail -fmt sarif -out results.sarif ./...'
      - name: Upload SARIF file
        if: ${{ github.actor != 'dependabot[bot]' }}
-        uses: github/codeql-action/upload-sarif@v3
+        uses: github/codeql-action/upload-sarif@v4
        with:
          # Path to SARIF file relative to the root of the repository
          sarif_file: results.sarif
.github/workflows/stalebot.yml (vendored, 2 changed lines)
@@ -10,7 +10,7 @@ jobs:
  stale:
    runs-on: ubuntu-latest
    steps:
-      - uses: actions/stale@3a9db7e6a41a89f618792c92c0e97cc736e1b13f # v9
+      - uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v9
        with:
          stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
          stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'
.github/workflows/test.yml (vendored, 10 changed lines)
@@ -21,7 +21,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
-        go-version: ['1.21.x']
+        go-version: ['1.25.x']
    steps:
      - name: Free Disk Space (Ubuntu)
        uses: jlumbroso/free-disk-space@main
@@ -124,7 +124,7 @@ jobs:
          PATH="$PATH:/root/go/bin" GO_TAGS="tts" make --jobs 5 --output-sync=target test
      - name: Setup tmate session if tests fail
        if: ${{ failure() }}
-        uses: mxschmitt/action-tmate@v3.22
+        uses: mxschmitt/action-tmate@v3.23
        with:
          detached: true
          connect-timeout-seconds: 180
@@ -183,7 +183,7 @@ jobs:
          PATH="$PATH:$HOME/go/bin" make backends/local-store backends/silero-vad backends/llama-cpp backends/whisper backends/piper backends/stablediffusion-ggml docker-build-aio e2e-aio
      - name: Setup tmate session if tests fail
        if: ${{ failure() }}
-        uses: mxschmitt/action-tmate@v3.22
+        uses: mxschmitt/action-tmate@v3.23
        with:
          detached: true
          connect-timeout-seconds: 180
@@ -193,7 +193,7 @@ jobs:
    runs-on: macOS-14
    strategy:
      matrix:
-        go-version: ['1.21.x']
+        go-version: ['1.25.x']
    steps:
      - name: Clone
        uses: actions/checkout@v5
@@ -226,7 +226,7 @@ jobs:
          PATH="$PATH:$HOME/go/bin" BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DGGML_F16C=OFF -DGGML_AVX512=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF" make --jobs 4 --output-sync=target test
      - name: Setup tmate session if tests fail
        if: ${{ failure() }}
-        uses: mxschmitt/action-tmate@v3.22
+        uses: mxschmitt/action-tmate@v3.23
        with:
          detached: true
          connect-timeout-seconds: 180
Dockerfile (10 changed lines)
@@ -78,6 +78,16 @@ RUN <<EOT bash
    fi
EOT

# https://github.com/NVIDIA/Isaac-GR00T/issues/343
RUN <<EOT bash
    if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
        wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu2204-0.6.0_0.6.0-1_arm64.deb && \
        dpkg -i cudss-local-tegra-repo-ubuntu2204-0.6.0_0.6.0-1_arm64.deb && \
        cp /var/cudss-local-tegra-repo-ubuntu2204-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
        apt-get update && apt-get -y install cudss
    fi
EOT

# If we are building with clblas support, we need the libraries for the builds
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
    apt-get update && \
Makefile (12 changed lines)
@@ -376,6 +376,9 @@ backends/llama-cpp-darwin: build
	bash ./scripts/build/llama-cpp-darwin.sh
	./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)"

backends/neutts: docker-build-neutts docker-save-neutts build
	./local-ai backends install "ocifile://$(abspath ./backend-images/neutts.tar)"

build-darwin-python-backend: build
	bash ./scripts/build/python-darwin.sh

@@ -429,6 +432,15 @@ docker-build-kitten-tts:
docker-save-kitten-tts: backend-images
	docker save local-ai-backend:kitten-tts -o backend-images/kitten-tts.tar

docker-save-chatterbox: backend-images
	docker save local-ai-backend:chatterbox -o backend-images/chatterbox.tar

docker-build-neutts:
	docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:neutts -f backend/Dockerfile.python --build-arg BACKEND=neutts ./backend

docker-save-neutts: backend-images
	docker save local-ai-backend:neutts -o backend-images/neutts.tar

docker-build-kokoro:
	docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:kokoro -f backend/Dockerfile.python --build-arg BACKEND=kokoro ./backend
README.md (20 changed lines)
@@ -118,6 +118,13 @@ For more installation options, see [Installer Options](https://localai.io/docs/a
|
||||
|
||||
Or run with docker:
|
||||
|
||||
> **💡 Docker Run vs Docker Start**
|
||||
>
|
||||
> - `docker run` creates and starts a new container. If a container with the same name already exists, this command will fail.
|
||||
> - `docker start` starts an existing container that was previously created with `docker run`.
|
||||
>
|
||||
> If you've already run LocalAI before and want to start it again, use: `docker start -i local-ai`
|
||||
|
||||
### CPU only image:
|
||||
|
||||
```bash
|
||||
@@ -197,6 +204,8 @@ For more information, see [💻 Getting started](https://localai.io/basics/getti

## 📰 Latest project news

- October 2025: 🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support added for agentic capabilities with external tools
- September 2025: New Launcher application for MacOS and Linux, extended support to many backends for Mac and Nvidia L4T devices. Models: added MLX-Audio, WAN 2.2. WebUI improvements; Python-based backends now ship portable Python environments.
- August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips (with `development` suffix in the gallery): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060
- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
@@ -235,7 +244,7 @@ Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3A
- 🔍 [Object Detection](https://localai.io/features/object-detection/)
- 📈 [Reranker API](https://localai.io/features/reranker/)
- 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
- [Agentic capabilities](https://github.com/mudler/LocalAGI)
- 🆕🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - Agentic capabilities with external tools and [LocalAGI's Agentic capabilities](https://github.com/mudler/LocalAGI)
- 🔊 Voice activity detection (Silero-VAD support)
- 🌍 Integrated WebUI!
@@ -266,6 +275,7 @@ LocalAI supports a comprehensive range of AI backends with multiple acceleration
| **piper** | Fast neural TTS system | CPU |
| **kitten-tts** | Kitten TTS models | CPU |
| **silero-vad** | Voice Activity Detection | CPU |
| **neutts** | Text-to-speech with voice cloning | CUDA 12, ROCm, CPU |

### Image & Video Generation
| Backend | Description | Acceleration Support |
@@ -287,7 +297,7 @@ LocalAI supports a comprehensive range of AI backends with multiple acceleration
|-------------------|-------------------|------------------|
| **NVIDIA CUDA 11** | llama.cpp, whisper, stablediffusion, diffusers, rerankers, bark, chatterbox | Nvidia hardware |
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark | AMD Graphics |
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark, neutts | AMD Graphics |
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark | Intel Arc, Intel iGPUs |
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, bark-cpp | Apple M1/M2/M3+ |
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
@@ -304,6 +314,12 @@ WebUIs:
- https://github.com/go-skynet/LocalAI-frontend
- QA-Pilot (an interactive chat project that leverages LocalAI LLMs for rapid understanding and navigation of GitHub code repositories): https://github.com/reid41/QA-Pilot

Agentic Libraries:
- https://github.com/mudler/cogito

MCPs:
- https://github.com/mudler/MCPs

Model galleries
- https://github.com/go-skynet/model-gallery

@@ -197,7 +197,7 @@ EOT

# Copy libraries using a script to handle architecture differences
RUN make -C /LocalAI/backend/cpp/llama-cpp package
RUN make -BC /LocalAI/backend/cpp/llama-cpp package

FROM scratch

@@ -28,7 +28,7 @@ RUN apt-get update && \
    curl python3-pip \
    python-is-python3 \
    python3-dev llvm \
    python3-venv make && \
    python3-venv make cmake && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/* && \
    pip install --upgrade pip
||||
@@ -1,5 +1,5 @@

LLAMA_VERSION?=f432d8d83e7407073634c5e4fd81a3d23a10827f
LLAMA_VERSION?=5a4ff43e7dd049e35942bc3d12361dab2f155544
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp

CMAKE_ARGS?=

@@ -14,7 +14,7 @@ CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=OFF

CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
ifeq ($(NATIVE),false)
	CMAKE_ARGS+=-DGGML_NATIVE=OFF
	CMAKE_ARGS+=-DGGML_NATIVE=OFF -DLLAMA_OPENSSL=OFF
endif
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
ifeq ($(BUILD_TYPE),cublas)
||||
@@ -92,7 +92,7 @@ static void start_llama_server(server_context& ctx_server) {
    ctx_server.queue_tasks.start_loop();
}

json parse_options(bool streaming, const backend::PredictOptions* predict)
json parse_options(bool streaming, const backend::PredictOptions* predict, const server_context& ctx_server)
{

    // Create now a json data from the prediction options instead
@@ -147,6 +147,28 @@ json parse_options(bool streaming, const backend::PredictOptions* predict)
    // data["n_probs"] = predict->nprobs();
    //TODO: images,

    // Serialize grammar triggers from server context to JSON array
    if (!ctx_server.params_base.sampling.grammar_triggers.empty()) {
        json grammar_triggers = json::array();
        for (const auto& trigger : ctx_server.params_base.sampling.grammar_triggers) {
            json trigger_json;
            trigger_json["value"] = trigger.value;
            // Always serialize as WORD type since upstream converts WORD to TOKEN internally
            trigger_json["type"] = static_cast<int>(COMMON_GRAMMAR_TRIGGER_TYPE_WORD);
            grammar_triggers.push_back(trigger_json);
        }
        data["grammar_triggers"] = grammar_triggers;
    }

    // Serialize preserved tokens from server context to JSON array
    if (!ctx_server.params_base.sampling.preserved_tokens.empty()) {
        json preserved_tokens = json::array();
        for (const auto& token : ctx_server.params_base.sampling.preserved_tokens) {
            preserved_tokens.push_back(common_token_to_piece(ctx_server.ctx, token));
        }
        data["preserved_tokens"] = preserved_tokens;
    }

    return data;
}
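
For reference, the extra fields this serialization attaches to each request have roughly the shape sketched below. This is a hypothetical Python construction for illustration only: the trigger word and the numeric `type` value are invented, not taken from the upstream enum.

```python
# Hypothetical shape of the payload built by the C++ serialization above.
# The word "<tool_call>" and the numeric type are invented for illustration.
GRAMMAR_TRIGGER_TYPE_WORD = 1  # assumed enum value, for illustration only

data = {
    "grammar_triggers": [
        {"value": "<tool_call>", "type": GRAMMAR_TRIGGER_TYPE_WORD},
    ],
    # preserved tokens are sent as text pieces, one per token
    "preserved_tokens": ["<tool_call>"],
}
```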

@@ -207,7 +229,7 @@ static void add_rpc_devices(std::string servers) {
        }
    }

static void params_parse(const backend::ModelOptions* request,
static void params_parse(server_context& ctx_server, const backend::ModelOptions* request,
                         common_params & params) {

    // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
@@ -231,6 +253,7 @@ static void params_parse(const backend::ModelOptions* request,
    params.cpuparams.n_threads = request->threads();
    params.n_gpu_layers = request->ngpulayers();
    params.n_batch = request->nbatch();
    params.n_ubatch = request->nbatch(); // fixes issue with reranking models being limited to 512 tokens (the default n_ubatch size); allows for setting the maximum input amount of tokens, thereby avoiding this error: "input is too large to process. increase the physical batch size"
    // Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1
    //params.n_parallel = 1;
    const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
@@ -268,6 +291,11 @@ static void params_parse(const backend::ModelOptions* request,
        }
    }

    if (!params.kv_overrides.empty()) {
        params.kv_overrides.emplace_back();
        params.kv_overrides.back().key[0] = 0;
    }

    // TODO: Add yarn

    if (!request->tensorsplit().empty()) {
@@ -346,14 +374,14 @@ static void params_parse(const backend::ModelOptions* request,
    }

    if (request->grammartriggers_size() > 0) {
        params.sampling.grammar_lazy = true;
        //params.sampling.grammar_lazy = true;
        // Store grammar trigger words for processing after model is loaded
        for (int i = 0; i < request->grammartriggers_size(); i++) {
            const auto & word = request->grammartriggers(i).word();
            common_grammar_trigger trigger;
            trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
            trigger.value = request->grammartriggers(i).word();
            // trigger.at_start = request->grammartriggers(i).at_start();
            params.sampling.grammar_triggers.push_back(trigger);

            trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
            trigger.value = word;
            params.sampling.grammar_triggers.push_back(std::move(trigger));
        }
    }
}
@@ -376,7 +404,7 @@ public:
    grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
        // Implement LoadModel RPC
        common_params params;
        params_parse(request, params);
        params_parse(ctx_server, request, params);

        common_init();

@@ -395,6 +423,39 @@ public:
            return Status::CANCELLED;
        }

        // Process grammar triggers now that vocab is available
        if (!params.sampling.grammar_triggers.empty()) {
            std::vector<common_grammar_trigger> processed_triggers;
            for (const auto& trigger : params.sampling.grammar_triggers) {
                if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
                    auto ids = common_tokenize(ctx_server.vocab, trigger.value, /* add_special= */ false, /* parse_special= */ true);
                    if (ids.size() == 1) {
                        auto token = ids[0];
                        // Add the token to preserved_tokens if not already present
                        if (params.sampling.preserved_tokens.find(token) == params.sampling.preserved_tokens.end()) {
                            params.sampling.preserved_tokens.insert(token);
                            LOG_INF("Added grammar trigger token to preserved tokens: %d (`%s`)\n", token, trigger.value.c_str());
                        }
                        LOG_INF("Grammar trigger token: %d (`%s`)\n", token, trigger.value.c_str());
                        common_grammar_trigger processed_trigger;
                        processed_trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
                        processed_trigger.value = trigger.value;
                        processed_trigger.token = token;
                        processed_triggers.push_back(std::move(processed_trigger));
                    } else {
                        LOG_INF("Grammar trigger word: `%s`\n", trigger.value.c_str());
                        processed_triggers.push_back(trigger);
                    }
                } else {
                    processed_triggers.push_back(trigger);
                }
            }
            // Update the grammar triggers in params_base
            ctx_server.params_base.sampling.grammar_triggers = std::move(processed_triggers);
            // Also update preserved_tokens in params_base
            ctx_server.params_base.sampling.preserved_tokens = params.sampling.preserved_tokens;
        }

        //ctx_server.init();
        result->set_message("Loading succeeded");
        result->set_success(true);
@@ -405,7 +466,7 @@ public:
    }

    grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
        json data = parse_options(true, request);
        json data = parse_options(true, request, ctx_server);

        //Raise error if embeddings is set to true
@@ -468,12 +529,12 @@ public:
            task.id = ctx_server.queue_tasks.get_new_id();
            task.index = i;

            task.prompt_tokens = std::move(inputs[i]);
            task.tokens = std::move(inputs[i]);
            task.params = server_task::params_from_json_cmpl(
                ctx_server.ctx,
                ctx_server.params_base,
                data);
            task.id_selected_slot = json_value(data, "id_slot", -1);
            task.id_slot = json_value(data, "id_slot", -1);

            // OAI-compat
            task.params.oaicompat = OAICOMPAT_TYPE_NONE;
@@ -555,7 +616,7 @@ public:
    }

    grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
        json data = parse_options(true, request);
        json data = parse_options(true, request, ctx_server);

        data["stream"] = false;
        //Raise error if embeddings is set to true
@@ -623,12 +684,12 @@ public:
            task.id = ctx_server.queue_tasks.get_new_id();
            task.index = i;

            task.prompt_tokens = std::move(inputs[i]);
            task.tokens = std::move(inputs[i]);
            task.params = server_task::params_from_json_cmpl(
                ctx_server.ctx,
                ctx_server.params_base,
                data);
            task.id_selected_slot = json_value(data, "id_slot", -1);
            task.id_slot = json_value(data, "id_slot", -1);

            // OAI-compat
            task.params.oaicompat = OAICOMPAT_TYPE_NONE;
@@ -690,7 +751,7 @@ public:

    grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) {

        json body = parse_options(false, request);
        json body = parse_options(false, request, ctx_server);

        body["stream"] = false;

@@ -724,7 +785,7 @@ public:

            task.id = ctx_server.queue_tasks.get_new_id();
            task.index = i;
            task.prompt_tokens = std::move(tokenized_prompts[i]);
            task.tokens = std::move(tokenized_prompts[i]);

            task.params.oaicompat = OAICOMPAT_TYPE_NONE;
            task.params.embd_normalize = embd_normalize;
@@ -801,11 +862,6 @@ public:
            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
        }

        // Tokenize the query
        auto tokenized_query = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, request->query(), /* add_special */ false, true);
        if (tokenized_query.size() != 1) {
            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must contain only a single prompt");
        }
        // Create and queue the task
        json responses = json::array();
        bool error = false;
@@ -817,14 +873,13 @@ public:
            documents.push_back(request->documents(i));
        }

        auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
        tasks.reserve(tokenized_docs.size());
        for (size_t i = 0; i < tokenized_docs.size(); i++) {
            auto tmp = format_rerank(ctx_server.vocab, tokenized_query[0], tokenized_docs[i]);
        tasks.reserve(documents.size());
        for (size_t i = 0; i < documents.size(); i++) {
            auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, request->query(), documents[i]);
            server_task task = server_task(SERVER_TASK_TYPE_RERANK);
            task.id = ctx_server.queue_tasks.get_new_id();
            task.index = i;
            task.prompt_tokens = std::move(tmp);
            task.tokens = std::move(tmp);
            tasks.push_back(std::move(task));
        }

@@ -877,7 +932,7 @@ public:
    }

    grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
        json body = parse_options(false, request);
        json body = parse_options(false, request, ctx_server);
        body["stream"] = false;

        json tokens_response = json::array();

@@ -8,7 +8,8 @@ JOBS?=$(shell nproc --ignore=1)

# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=44fa2f647cf2a6953493b21ab83b50d5f5dbc483
WHISPER_CPP_VERSION?=f16c12f3f55f5bd3d6ac8cf2f31ab90a42c884d5
SO_TARGET?=libgowhisper.so

CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

@@ -57,15 +58,18 @@ sources/whisper.cpp:
	git checkout $(WHISPER_CPP_VERSION) && \
	git submodule update --init --recursive --depth 1 --single-branch

libgowhisper.so: sources/whisper.cpp CMakeLists.txt gowhisper.cpp gowhisper.h
	mkdir -p build && \
	cd build && \
	cmake .. $(CMAKE_ARGS) && \
	cmake --build . --config Release -j$(JOBS) && \
	cd .. && \
	mv build/libgowhisper.so ./
# Detect OS
UNAME_S := $(shell uname -s)

whisper: main.go gowhisper.go libgowhisper.so
# Only build CPU variants on Linux
ifeq ($(UNAME_S),Linux)
VARIANT_TARGETS = libgowhisper-avx.so libgowhisper-avx2.so libgowhisper-avx512.so libgowhisper-fallback.so
else
# On non-Linux (e.g., Darwin), build only fallback variant
VARIANT_TARGETS = libgowhisper-fallback.so
endif

whisper: main.go gowhisper.go $(VARIANT_TARGETS)
	CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o whisper ./

package: whisper
@@ -73,5 +77,46 @@ package: whisper

build: package

clean:
	rm -rf libgowhisper.o build whisper
clean: purge
	rm -rf libgowhisper*.so sources/whisper.cpp whisper

purge:
	rm -rf build*

# Build all variants (Linux only)
ifeq ($(UNAME_S),Linux)
libgowhisper-avx.so: sources/whisper.cpp
	$(MAKE) purge
	$(info ${GREEN}I whisper build info:avx${RESET})
	SO_TARGET=libgowhisper-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) libgowhisper-custom
	rm -rfv build*

libgowhisper-avx2.so: sources/whisper.cpp
	$(MAKE) purge
	$(info ${GREEN}I whisper build info:avx2${RESET})
	SO_TARGET=libgowhisper-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) libgowhisper-custom
	rm -rfv build*

libgowhisper-avx512.so: sources/whisper.cpp
	$(MAKE) purge
	$(info ${GREEN}I whisper build info:avx512${RESET})
	SO_TARGET=libgowhisper-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) libgowhisper-custom
	rm -rfv build*
endif

# Build fallback variant (all platforms)
libgowhisper-fallback.so: sources/whisper.cpp
	$(MAKE) purge
	$(info ${GREEN}I whisper build info:fallback${RESET})
	SO_TARGET=libgowhisper-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) libgowhisper-custom
	rm -rfv build*

libgowhisper-custom: CMakeLists.txt gowhisper.cpp gowhisper.h
	mkdir -p build-$(SO_TARGET) && \
	cd build-$(SO_TARGET) && \
	cmake .. $(CMAKE_ARGS) && \
	cmake --build . --config Release -j$(JOBS) && \
	cd .. && \
	mv build-$(SO_TARGET)/libgowhisper.so ./$(SO_TARGET)

all: whisper package
@@ -3,6 +3,7 @@ package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
	"flag"
	"os"

	"github.com/ebitengine/purego"
	grpc "github.com/mudler/LocalAI/pkg/grpc"
@@ -18,7 +19,13 @@ type LibFuncs struct {
}

func main() {
	gosd, err := purego.Dlopen("./libgowhisper.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
	// Get library name from environment variable, default to fallback
	libName := os.Getenv("WHISPER_LIBRARY")
	if libName == "" {
		libName = "./libgowhisper-fallback.so"
	}

	gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
	if err != nil {
		panic(err)
	}
@@ -10,7 +10,8 @@ CURDIR=$(dirname "$(realpath $0)")
# Create lib directory
mkdir -p $CURDIR/package/lib

cp -avf $CURDIR/whisper $CURDIR/libgowhisper.so $CURDIR/package/
cp -avf $CURDIR/whisper $CURDIR/package/
cp -fv $CURDIR/libgowhisper-*.so $CURDIR/package/
cp -fv $CURDIR/run.sh $CURDIR/package/

# Detect architecture and copy appropriate libraries
@@ -1,14 +1,52 @@
#!/bin/bash
set -ex

# Get the absolute current dir where the script is located
CURDIR=$(dirname "$(realpath $0)")

cd /

echo "CPU info:"
if [ "$(uname)" != "Darwin" ]; then
    grep -e "model\sname" /proc/cpuinfo | head -1
    grep -e "flags" /proc/cpuinfo | head -1
fi

LIBRARY="$CURDIR/libgowhisper-fallback.so"

if [ "$(uname)" != "Darwin" ]; then
    if grep -q -e "\savx\s" /proc/cpuinfo ; then
        echo "CPU: AVX found OK"
        if [ -e $CURDIR/libgowhisper-avx.so ]; then
            LIBRARY="$CURDIR/libgowhisper-avx.so"
        fi
    fi

    if grep -q -e "\savx2\s" /proc/cpuinfo ; then
        echo "CPU: AVX2 found OK"
        if [ -e $CURDIR/libgowhisper-avx2.so ]; then
            LIBRARY="$CURDIR/libgowhisper-avx2.so"
        fi
    fi

    # Check avx512
    if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
        echo "CPU: AVX512F found OK"
        if [ -e $CURDIR/libgowhisper-avx512.so ]; then
            LIBRARY="$CURDIR/libgowhisper-avx512.so"
        fi
    fi
fi

export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
export WHISPER_LIBRARY=$LIBRARY

# If there is a lib/ld.so, use it
if [ -f $CURDIR/lib/ld.so ]; then
    echo "Using lib/ld.so"
    echo "Using library: $LIBRARY"
    exec $CURDIR/lib/ld.so $CURDIR/whisper "$@"
fi

echo "Using library: $LIBRARY"
exec $CURDIR/whisper "$@"
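
The script above probes CPU flags and prefers the most capable variant that was actually shipped, falling back otherwise. A hypothetical Python rendition of the same selection logic, assuming a Linux-style `/proc/cpuinfo` and the library names from the Makefile targets:

```python
# Hypothetical Python rendition of the CPU-flag probe in run.sh above.
# Library names match the Makefile targets; everything else is illustrative.
def pick_whisper_library(cpuinfo: str, available: set[str]) -> str:
    flags = next((l for l in cpuinfo.splitlines() if l.startswith("flags")), "")
    lib = "libgowhisper-fallback.so"
    # Later entries win, mirroring the script's avx -> avx2 -> avx512f order
    for flag, candidate in (("avx", "libgowhisper-avx.so"),
                            ("avx2", "libgowhisper-avx2.so"),
                            ("avx512f", "libgowhisper-avx512.so")):
        if f" {flag} " in f" {flags} " and candidate in available:
            lib = candidate
    return lib

# Example: a CPU with AVX and AVX2 but no AVX512F picks the AVX2 build
assert pick_whisper_library(
    "flags : fpu avx avx2 sse2",
    {"libgowhisper-avx.so", "libgowhisper-avx2.so"},
) == "libgowhisper-avx2.so"
```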

@@ -270,6 +270,7 @@
    nvidia: "cuda12-kokoro"
    intel: "intel-kokoro"
    amd: "rocm-kokoro"
    nvidia-l4t: "nvidia-l4t-kokoro"
- &coqui
  urls:
    - https://github.com/idiap/coqui-ai-TTS
@@ -352,6 +353,7 @@
    nvidia: "cuda12-chatterbox"
    metal: "metal-chatterbox"
    default: "cpu-chatterbox"
    nvidia-l4t: "nvidia-l4t-arm64-chatterbox"
- &piper
  name: "piper"
  uri: "quay.io/go-skynet/local-ai-backends:latest-piper"
@@ -425,6 +427,68 @@
    - text-to-speech
    - TTS
  license: apache-2.0
- &neutts
  name: "neutts"
  urls:
    - https://github.com/neuphonic/neutts-air
  description: |
    NeuTTS Air is the world’s first super-realistic, on-device, TTS speech language model with instant voice cloning. Built off a 0.5B LLM backbone, NeuTTS Air brings natural-sounding speech, real-time performance, built-in security and speaker cloning to your local device - unlocking a new category of embedded voice agents, assistants, toys, and compliance-safe apps.
  tags:
    - text-to-speech
    - TTS
  license: apache-2.0
  capabilities:
    default: "cpu-neutts"
    nvidia: "cuda12-neutts"
    amd: "rocm-neutts"
    nvidia-l4t: "nvidia-l4t-neutts"
- !!merge <<: *neutts
  name: "neutts-development"
  capabilities:
    default: "cpu-neutts-development"
    nvidia: "cuda12-neutts-development"
    amd: "rocm-neutts-development"
    nvidia-l4t: "nvidia-l4t-neutts-development"
- !!merge <<: *neutts
  name: "cpu-neutts"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-neutts"
  mirrors:
    - localai/localai-backends:latest-cpu-neutts
- !!merge <<: *neutts
  name: "cuda12-neutts"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-neutts"
  mirrors:
    - localai/localai-backends:latest-gpu-nvidia-cuda-12-neutts
- !!merge <<: *neutts
  name: "rocm-neutts"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-neutts"
  mirrors:
    - localai/localai-backends:latest-gpu-rocm-hipblas-neutts
- !!merge <<: *neutts
  name: "nvidia-l4t-neutts"
  uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-neutts"
  mirrors:
    - localai/localai-backends:latest-nvidia-l4t-arm64-neutts
- !!merge <<: *neutts
  name: "cpu-neutts-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-neutts"
  mirrors:
    - localai/localai-backends:master-cpu-neutts
- !!merge <<: *neutts
  name: "cuda12-neutts-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-neutts"
  mirrors:
    - localai/localai-backends:master-gpu-nvidia-cuda-12-neutts
- !!merge <<: *neutts
  name: "rocm-neutts-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-neutts"
  mirrors:
    - localai/localai-backends:master-gpu-rocm-hipblas-neutts
- !!merge <<: *neutts
  name: "nvidia-l4t-neutts-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-neutts"
  mirrors:
    - localai/localai-backends:master-nvidia-l4t-arm64-neutts
- !!merge <<: *mlx
  name: "mlx-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx"
@@ -1049,6 +1113,7 @@
    nvidia: "cuda12-kokoro-development"
    intel: "intel-kokoro-development"
    amd: "rocm-kokoro-development"
    nvidia-l4t: "nvidia-l4t-kokoro-development"
- !!merge <<: *kokoro
  name: "cuda11-kokoro-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-11-kokoro"
@@ -1074,6 +1139,16 @@
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-kokoro"
  mirrors:
    - localai/localai-backends:master-gpu-intel-kokoro
- !!merge <<: *kokoro
  name: "nvidia-l4t-kokoro"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-l4t-kokoro"
  mirrors:
    - localai/localai-backends:latest-gpu-nvidia-l4t-kokoro
- !!merge <<: *kokoro
  name: "nvidia-l4t-kokoro-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-l4t-kokoro"
  mirrors:
    - localai/localai-backends:master-gpu-nvidia-l4t-kokoro
- !!merge <<: *kokoro
  name: "cuda11-kokoro"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-11-kokoro"
@@ -1227,6 +1302,7 @@
    nvidia: "cuda12-chatterbox-development"
    metal: "metal-chatterbox-development"
    default: "cpu-chatterbox-development"
    nvidia-l4t: "nvidia-l4t-arm64-chatterbox"
- !!merge <<: *chatterbox
  name: "cpu-chatterbox"
  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-chatterbox"
@@ -1237,6 +1313,16 @@
  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-chatterbox"
  mirrors:
    - localai/localai-backends:master-cpu-chatterbox
- !!merge <<: *chatterbox
  name: "nvidia-l4t-arm64-chatterbox"
  uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-l4t-arm64-chatterbox"
  mirrors:
    - localai/localai-backends:latest-gpu-nvidia-l4t-arm64-chatterbox
- !!merge <<: *chatterbox
  name: "nvidia-l4t-arm64-chatterbox-development"
  uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-l4t-arm64-chatterbox"
  mirrors:
    - localai/localai-backends:master-gpu-nvidia-l4t-arm64-chatterbox
- !!merge <<: *chatterbox
  name: "metal-chatterbox"
  uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-chatterbox"

@@ -1,4 +1,4 @@
bark==0.1.5
grpcio==1.74.0
grpcio==1.76.0
protobuf
certifi

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for Bark TTS
This is an extra gRPC server of LocalAI for Chatterbox TTS
"""
from concurrent import futures
import time
@@ -14,15 +14,98 @@ import backend_pb2_grpc
import torch
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS

from chatterbox.mtl_tts import ChatterboxMultilingualTTS
import grpc
import tempfile

def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False

def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False

def split_text_at_word_boundary(text, max_length=250):
    """
    Split text at word boundaries without truncating words.
    Returns a list of text chunks.
    """
    if not text or len(text) <= max_length:
        return [text]

    chunks = []
    words = text.split()
    current_chunk = ""

    for word in words:
        # Check if adding this word would exceed the limit
        if len(current_chunk) + len(word) + 1 <= max_length:
            if current_chunk:
                current_chunk += " " + word
            else:
                current_chunk = word
        else:
            # If current chunk is not empty, add it to chunks
            if current_chunk:
                chunks.append(current_chunk)
                current_chunk = word
            else:
                # If a single word is longer than max_length, we have to include it anyway
                chunks.append(word)
                current_chunk = ""

    # Add the last chunk if it's not empty
    if current_chunk:
        chunks.append(current_chunk)

    return chunks
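
For example, with a deliberately small limit (illustrative values; the backend uses 250), words stay whole and no chunk exceeds the limit:

```python
# Hypothetical call to split_text_at_word_boundary with a small limit:
chunks = split_text_at_word_boundary("the quick brown fox jumps", max_length=20)
assert chunks == ["the quick brown fox", "jumps"]
```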

def merge_audio_files(audio_files, output_path, sample_rate):
    """
    Merge multiple audio files into a single audio file.
    """
    if not audio_files:
        return

    if len(audio_files) == 1:
        # If only one file, just copy it
        import shutil
        shutil.copy2(audio_files[0], output_path)
        return

    # Load all audio files
    waveforms = []
    for audio_file in audio_files:
        waveform, sr = ta.load(audio_file)
        if sr != sample_rate:
            # Resample if necessary
            resampler = ta.transforms.Resample(sr, sample_rate)
            waveform = resampler(waveform)
        waveforms.append(waveform)

    # Concatenate all waveforms
    merged_waveform = torch.cat(waveforms, dim=1)

    # Save the merged audio
    ta.save(output_path, merged_waveform, sample_rate)

    # Clean up temporary files
    for audio_file in audio_files:
        if os.path.exists(audio_file):
            os.remove(audio_file)

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
COQUI_LANGUAGE = os.environ.get('COQUI_LANGUAGE', None)

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
@@ -47,6 +130,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        options = request.Options

        # empty dict
        self.options = {}

        # The options are a list of strings in this form optname:optvalue
        # We are storing all the options in a dict so we can use it later when
        # generating the audio
        for opt in options:
            if ":" not in opt:
                continue
            key, value = opt.split(":")
            # if value is a number, convert it to the appropriate type
            if is_float(value):
                value = float(value)
            elif is_int(value):
                value = int(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value
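        # For illustration (hypothetical values): an Options list like
        #   ["multilingual:true", "exaggeration:0.7", "language:fr"]
        # parses to {"multilingual": True, "exaggeration": 0.7, "language": "fr"}.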

        self.AudioPath = None

        if os.path.isabs(request.AudioPath):
@@ -56,10 +161,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
            modelFileBase = os.path.dirname(request.ModelFile)
            # modify LoraAdapter to be relative to modelFileBase
            self.AudioPath = os.path.join(modelFileBase, request.AudioPath)

        try:
            print("Preparing models, please wait", file=sys.stderr)
            self.model = ChatterboxTTS.from_pretrained(device=device)
            if "multilingual" in self.options:
                # remove key from options
                del self.options["multilingual"]
                self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
            else:
                self.model = ChatterboxTTS.from_pretrained(device=device)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        # Implement your logic here for the LoadModel service
@@ -68,14 +177,43 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):

    def TTS(self, request, context):
        try:
            # Generate audio using ChatterboxTTS
            kwargs = {}

            if "language" in self.options:
                kwargs["language_id"] = self.options["language"]
            if self.AudioPath is not None:
                wav = self.model.generate(request.text, audio_prompt_path=self.AudioPath)
                kwargs["audio_prompt_path"] = self.AudioPath

            # add options to kwargs
            kwargs.update(self.options)

            # Check if text exceeds 250 characters
            # (chatterbox does not support long text)
            # https://github.com/resemble-ai/chatterbox/issues/60
            # https://github.com/resemble-ai/chatterbox/issues/110
            if len(request.text) > 250:
                # Split text at word boundaries
                text_chunks = split_text_at_word_boundary(request.text, max_length=250)
                print(f"Splitting text into chunks of 250 characters: {len(text_chunks)}", file=sys.stderr)
                # Generate audio for each chunk
                temp_audio_files = []
                for i, chunk in enumerate(text_chunks):
                    # Generate audio for this chunk
                    wav = self.model.generate(chunk, **kwargs)

                    # Create temporary file for this chunk
                    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
                    temp_file.close()
                    ta.save(temp_file.name, wav, self.model.sr)
                    temp_audio_files.append(temp_file.name)

                # Merge all audio files
                merge_audio_files(temp_audio_files, request.dst, self.model.sr)
            else:
                wav = self.model.generate(request.text)

                # Save the generated audio
                ta.save(request.dst, wav, self.model.sr)
                # Generate audio using ChatterboxTTS for short text
                wav = self.model.generate(request.text, **kwargs)
                # Save the generated audio
                ta.save(request.dst, wav, self.model.sr)

        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
@@ -15,5 +15,6 @@ fi
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
    EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi
EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation"

installRequirements

@@ -1,6 +1,8 @@
--extra-index-url https://download.pytorch.org/whl/cpu
accelerate
torch==2.6.0
torchaudio==2.6.0
transformers==4.46.3
chatterbox-tts==0.1.2
torch
torchaudio
transformers
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
#chatterbox-tts==0.1.4
@@ -2,5 +2,6 @@
torch==2.6.0+cu118
torchaudio==2.6.0+cu118
transformers==4.46.3
chatterbox-tts==0.1.2
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate
@@ -1,5 +1,6 @@
torch==2.6.0
torchaudio==2.6.0
transformers==4.46.3
chatterbox-tts==0.1.2
torch
torchaudio
transformers
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate

@@ -1,6 +1,7 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.0
torch==2.6.0+rocm6.1
torchaudio==2.6.0+rocm6.1
transformers==4.46.3
chatterbox-tts==0.1.2
transformers
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate

@@ -2,8 +2,9 @@
intel-extension-for-pytorch==2.3.110+xpu
torch==2.3.1+cxx11.abi
torchaudio==2.3.1+cxx11.abi
transformers==4.46.3
chatterbox-tts==0.1.2
transformers
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate
oneccl_bind_pt==2.3.100+xpu
optimum[openvino]

6 backend/python/chatterbox/requirements-l4t.txt Normal file
@@ -0,0 +1,6 @@
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
torch
torchaudio
transformers
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate
@@ -2,4 +2,5 @@ grpcio==1.71.0
protobuf
certifi
packaging
setuptools
setuptools
poetry
@@ -1,3 +1,3 @@
grpcio==1.74.0
grpcio==1.76.0
protobuf
grpcio-tools
@@ -1,4 +1,4 @@
grpcio==1.74.0
grpcio==1.76.0
protobuf
certifi
packaging==24.1
@@ -1,5 +1,5 @@
setuptools
grpcio==1.74.0
grpcio==1.76.0
pillow
protobuf
certifi
@@ -31,7 +31,7 @@ class TestBackendServicer(unittest.TestCase):
        """
        This method tests if the server starts up successfully
        """
        time.sleep(10)
        time.sleep(20)
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
@@ -48,7 +48,7 @@ class TestBackendServicer(unittest.TestCase):
        """
        This method tests if the model is loaded successfully
        """
        time.sleep(10)
        time.sleep(20)
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
@@ -66,7 +66,7 @@ class TestBackendServicer(unittest.TestCase):
        """
        This method tests if the backend can generate images
        """
        time.sleep(10)
        time.sleep(20)
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:

@@ -1,4 +1,4 @@
grpcio==1.74.0
grpcio==1.76.0
protobuf
certifi
wheel
@@ -64,15 +64,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
            # Generate audio using Kokoro pipeline
            generator = self.pipeline(request.text, voice=voice)

            # Get the first (and typically only) audio segment
            speechs = []
            # Get all the audio segments
            for i, (gs, ps, audio) in enumerate(generator):
                # Save audio to the destination file
                sf.write(request.dst, audio, 24000)
                speechs.append(audio)
                print(f"Generated audio segment {i}: gs={gs}, ps={ps}", file=sys.stderr)
                # For now, we only process the first segment
                # If you need to handle multiple segments, you might want to modify this
                break

            # Merge the audio segments and write them to the destination
            speech = torch.cat(speechs, dim=0)
            sf.write(request.dst, speech, 24000)

        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
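
The fix above replaces the early `break` (which kept only the first generated segment) with collecting every segment and concatenating them before writing the file once. A minimal sketch of why `torch.cat` along `dim=0` is the right merge for 1-D audio tensors (illustrative values only):

```python
import torch

a = torch.zeros(24000)            # 1.0 s of audio at 24 kHz
b = torch.ones(12000)             # 0.5 s more
merged = torch.cat([a, b], dim=0)
assert merged.shape[0] == 36000   # 1.5 s total, segments back to back
```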

7 backend/python/kokoro/requirements-l4t.txt Normal file
@@ -0,0 +1,7 @@
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
torch
torchaudio
transformers
accelerate
kokoro
soundfile
23 backend/python/neutts/Makefile Normal file
@@ -0,0 +1,23 @@
.PHONY: neutts
neutts:
	bash install.sh

.PHONY: run
run: neutts
	@echo "Running neutts..."
	bash run.sh
	@echo "neutts run."

.PHONY: test
test: neutts
	@echo "Testing neutts..."
	bash test.sh
	@echo "neutts tested."

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
162 backend/python/neutts/backend.py Normal file
@@ -0,0 +1,162 @@
#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for NeuTTSAir
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import backend_pb2
import backend_pb2_grpc
import torch
from neuttsair.neutts import NeuTTSAir
import soundfile as sf

import grpc

def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False

def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service
    """
    def Health(self, request, context):
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        # Get device
        # device = "cuda" if request.CUDA else "cpu"
        if torch.cuda.is_available():
            print("CUDA is available", file=sys.stderr)
            device = "cuda"
        else:
            print("CUDA is not available", file=sys.stderr)
            device = "cpu"
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if mps_available:
            device = "mps"
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        options = request.Options

        # empty dict
        self.options = {}
        self.ref_text = None

        # The options are a list of strings in this form optname:optvalue
        # We are storing all the options in a dict so we can use it later when
        # generating the audio
        for opt in options:
            if ":" not in opt:
                continue
            key, value = opt.split(":")
            # if value is a number, convert it to the appropriate type
            if is_float(value):
                value = float(value)
            elif is_int(value):
                value = int(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value

        codec_repo = "neuphonic/neucodec"
        if "codec_repo" in self.options:
            codec_repo = self.options["codec_repo"]
            del self.options["codec_repo"]
        if "ref_text" in self.options:
            self.ref_text = self.options["ref_text"]
            del self.options["ref_text"]

        self.AudioPath = None

        if os.path.isabs(request.AudioPath):
            self.AudioPath = request.AudioPath
        elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
            # get base path of modelFile
            modelFileBase = os.path.dirname(request.ModelFile)
            # modify LoraAdapter to be relative to modelFileBase
            self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
        try:
            print("Preparing models, please wait", file=sys.stderr)
            self.model = NeuTTSAir(backbone_repo=request.Model, backbone_device=device, codec_repo=codec_repo, codec_device=device)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        # Implement your logic here for the LoadModel service
        # Replace this with your desired response
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        try:
            kwargs = {}

            # add options to kwargs
            kwargs.update(self.options)

            ref_codes = self.model.encode_reference(self.AudioPath)

            wav = self.model.infer(request.text, ref_codes, self.ref_text)

            sf.write(request.dst, wav, 24000)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(success=True)

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                         options=[
                             ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
                             ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
                         ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
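
Voice cloning in this backend needs the reference audio (`AudioPath`) plus the `ref_text` and `codec_repo` options consumed in `LoadModel` above. A hypothetical `Options` list in the `optname:optvalue` form the parser expects (values are illustrative, not a documented configuration):

```python
# Hypothetical Options payload as LoadModel would receive it:
request_options = [
    "ref_text:So I'm live on radio.",   # transcript of the reference clip
    "codec_repo:neuphonic/neucodec",    # default shown in LoadModel above
]
```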

33 backend/python/neutts/install.sh Executable file
@@ -0,0 +1,33 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
    EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi

if [ "x${BUILD_TYPE}" == "xcublas" ] || [ "x${BUILD_TYPE}" == "xl4t" ]; then
    export CMAKE_ARGS="-DGGML_CUDA=on"
fi

if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
    export CMAKE_ARGS="-DGGML_HIPBLAS=on"
fi

EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation"

git clone https://github.com/neuphonic/neutts-air neutts-air

cp -rfv neutts-air/neuttsair ./

installRequirements
2 backend/python/neutts/requirements-after.txt Normal file
@@ -0,0 +1,2 @@
datasets==4.1.1
torchtune==0.6.1

10 backend/python/neutts/requirements-cpu.txt Normal file
@@ -0,0 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/cpu
accelerate
torch==2.8.0
transformers==4.56.1
librosa==0.11.0
neucodec>=0.0.4
phonemizer==3.3.0
soundfile==0.13.1
resemble-perth==1.0.1
llama-cpp-python

8 backend/python/neutts/requirements-cublas12.txt Normal file
@@ -0,0 +1,8 @@
librosa==0.11.0
neucodec>=0.0.4
phonemizer==3.3.0
soundfile==0.13.1
torch==2.8.0
transformers==4.56.1
resemble-perth==1.0.1
accelerate

10 backend/python/neutts/requirements-hipblas.txt Normal file
@@ -0,0 +1,10 @@
--extra-index-url https://download.pytorch.org/whl/rocm6.3
torch==2.8.0+rocm6.3
transformers==4.56.1
accelerate
librosa==0.11.0
neucodec>=0.0.4
phonemizer==3.3.0
soundfile==0.13.1
resemble-perth==1.0.1
llama-cpp-python

10 backend/python/neutts/requirements-l4t.txt Normal file
@@ -0,0 +1,10 @@
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
torch
transformers
accelerate
librosa==0.11.0
neucodec>=0.0.4
phonemizer==3.3.0
soundfile==0.13.1
resemble-perth==1.0.1
llama-cpp-python

7 backend/python/neutts/requirements.txt Normal file
@@ -0,0 +1,7 @@
grpcio==1.71.0
protobuf
certifi
packaging
setuptools
numpy==2.2.6
scikit_build_core
10 backend/python/neutts/run.sh Executable file
@@ -0,0 +1,10 @@
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

startBackend $@
82 backend/python/neutts/test.py Normal file
@@ -0,0 +1,82 @@
"""
A test script to test the gRPC service
"""
import unittest
import subprocess
import time
import backend_pb2
import backend_pb2_grpc

import grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        time.sleep(30)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions())
                print(response)
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_tts(self):
        """
        This method tests if the TTS audio is generated successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions())
                self.assertTrue(response.success)
                tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
                tts_response = stub.TTS(tts_request)
                self.assertIsNotNone(tts_response)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            self.tearDown()
11 backend/python/neutts/test.sh Executable file
@@ -0,0 +1,11 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runUnittests

@@ -1,3 +1,3 @@
grpcio==1.74.0
grpcio==1.76.0
protobuf
certifi
@@ -1,4 +1,4 @@
grpcio==1.74.0
grpcio==1.76.0
protobuf==6.32.0
certifi
setuptools

@@ -1,4 +1,4 @@
grpcio==1.74.0
grpcio==1.76.0
protobuf
certifi
setuptools
@@ -2,14 +2,12 @@ package main

import (
	"log"
	"os"
	"os/signal"
	"syscall"

	"fyne.io/fyne/v2"
	"fyne.io/fyne/v2/app"
	"fyne.io/fyne/v2/driver/desktop"
	coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal"
	"github.com/mudler/LocalAI/pkg/signals"
)

func main() {
@@ -42,7 +40,12 @@ func main() {
	}

	// Setup signal handling for graceful shutdown
	setupSignalHandling(launcher)
	signals.RegisterGracefulTerminationHandler(func() {
		// Perform cleanup
		if err := launcher.Shutdown(); err != nil {
			log.Printf("Error during shutdown: %v", err)
		}
	})

	// Initialize the launcher state
	go func() {
@@ -67,26 +70,3 @@ func main() {
	// Run the application in background (window only shown when "Settings" is clicked)
	myApp.Run()
}

// setupSignalHandling sets up signal handlers for graceful shutdown
func setupSignalHandling(launcher *coreLauncher.Launcher) {
	// Create a channel to receive OS signals
	sigChan := make(chan os.Signal, 1)

	// Register for interrupt and terminate signals
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

	// Handle signals in a separate goroutine
	go func() {
		sig := <-sigChan
		log.Printf("Received signal %v, shutting down gracefully...", sig)

		// Perform cleanup
		if err := launcher.Shutdown(); err != nil {
			log.Printf("Error during shutdown: %v", err)
		}

		// Exit the application
		os.Exit(0)
	}()
}
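
This refactor (repeated in the CLI commands below) replaces each command's hand-rolled signal goroutine with a shared registration API in `pkg/signals`. A rough sketch of the pattern, written in Python for brevity; the names are illustrative, not the actual Go API:

```python
# Rough sketch of the RegisterGracefulTerminationHandler pattern.
import signal
import sys

_handlers = []

def register_graceful_termination_handler(fn):
    """Queue a cleanup callback to run once on SIGINT/SIGTERM."""
    _handlers.append(fn)

def _on_signal(signum, frame):
    for fn in _handlers:
        fn()          # run every registered cleanup, in registration order
    sys.exit(0)

for sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(sig, _on_signal)
```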

@@ -129,7 +129,6 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
		triggers = append(triggers, &pb.GrammarTrigger{
			Word: t.Word,
		})

	}

	return &pb.ModelOptions{
@@ -60,7 +60,7 @@ func SoundGeneration(

	// return RPC error if any
	if !res.Success {
		return "", nil, fmt.Errorf(res.Message)
		return "", nil, fmt.Errorf("error during sound generation: %s", res.Message)
	}

	return filePath, res, err

@@ -70,7 +70,7 @@ func ModelTTS(

	// return RPC error if any
	if !res.Success {
		return "", nil, fmt.Errorf(res.Message)
		return "", nil, fmt.Errorf("error during TTS: %s", res.Message)
	}

	return filePath, res, err
@@ -5,9 +5,10 @@ import (
	"time"

	cliContext "github.com/mudler/LocalAI/core/cli/context"
	"github.com/mudler/LocalAI/core/cli/signals"
	"github.com/mudler/LocalAI/core/explorer"
	"github.com/mudler/LocalAI/core/http"
	"github.com/mudler/LocalAI/pkg/signals"
	"github.com/rs/zerolog/log"
)

type ExplorerCMD struct {
@@ -46,7 +47,11 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {

	appHTTP := http.Explorer(db)

	signals.Handler(nil)
	signals.RegisterGracefulTerminationHandler(func() {
		if err := appHTTP.Shutdown(); err != nil {
			log.Error().Err(err).Msg("error during shutdown")
		}
	})

	return appHTTP.Listen(e.Address)
}
@@ -4,8 +4,8 @@ import (
	"context"

	cliContext "github.com/mudler/LocalAI/core/cli/context"
	"github.com/mudler/LocalAI/core/cli/signals"
	"github.com/mudler/LocalAI/core/p2p"
	"github.com/mudler/LocalAI/pkg/signals"
)

type FederatedCLI struct {
@@ -20,7 +20,11 @@ func (f *FederatedCLI) Run(ctx *cliContext.Context) error {

	fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker)

	signals.Handler(nil)
	c, cancel := context.WithCancel(context.Background())

	return fs.Start(context.Background())
	signals.RegisterGracefulTerminationHandler(func() {
		cancel()
	})

	return fs.Start(c)
}
@@ -10,11 +10,11 @@ import (
  "github.com/mudler/LocalAI/core/application"
  cli_api "github.com/mudler/LocalAI/core/cli/api"
  cliContext "github.com/mudler/LocalAI/core/cli/context"
  "github.com/mudler/LocalAI/core/cli/signals"
  "github.com/mudler/LocalAI/core/config"
  "github.com/mudler/LocalAI/core/http"
  "github.com/mudler/LocalAI/core/p2p"
  "github.com/mudler/LocalAI/internal"
  "github.com/mudler/LocalAI/pkg/signals"
  "github.com/mudler/LocalAI/pkg/system"
  "github.com/rs/zerolog"
  "github.com/rs/zerolog/log"
@@ -127,6 +127,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
    config.WithP2PNetworkID(r.Peer2PeerNetworkID),
    config.WithLoadToMemory(r.LoadToMemory),
    config.WithMachineTag(r.MachineTag),
    config.WithAPIAddress(r.Address),
  }

  if r.DisableMetricsEndpoint {
@@ -225,8 +226,11 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
    return err
  }

  // Catch signals from the OS requesting us to exit, and stop all backends
  signals.Handler(app.ModelLoader())
  signals.RegisterGracefulTerminationHandler(func() {
    if err := app.ModelLoader().StopAllGRPC(); err != nil {
      log.Error().Err(err).Msg("error while stopping all grpc backends")
    }
  })

  return appHTTP.Listen(r.Address)
}
@@ -1,25 +0,0 @@
package signals

import (
  "os"
  "os/signal"
  "syscall"

  "github.com/mudler/LocalAI/pkg/model"
  "github.com/rs/zerolog/log"
)

func Handler(m *model.ModelLoader) {
  // Catch signals from the OS requesting us to exit, and stop all backends
  go func(m *model.ModelLoader) {
    c := make(chan os.Signal, 1) // we need a buffer size of 1, so the notifier is not blocked
    signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
    <-c
    if m != nil {
      if err := m.StopAllGRPC(); err != nil {
        log.Error().Err(err).Msg("error while stopping all grpc backends")
      }
    }
    os.Exit(1)
  }(m)
}
@@ -11,7 +11,6 @@ import (

  cliContext "github.com/mudler/LocalAI/core/cli/context"
  "github.com/mudler/LocalAI/core/config"
  "github.com/mudler/LocalAI/core/cli/signals"
  "github.com/mudler/LocalAI/core/gallery"
  "github.com/mudler/LocalAI/pkg/model"
  "github.com/mudler/LocalAI/pkg/system"
@@ -85,8 +84,6 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {

  args = append([]string{grpcProcess}, args...)

  signals.Handler(nil)

  return syscall.Exec(
    grpcProcess,
    args,
@@ -9,8 +9,8 @@ import (
  "time"

  cliContext "github.com/mudler/LocalAI/core/cli/context"
  "github.com/mudler/LocalAI/core/cli/signals"
  "github.com/mudler/LocalAI/core/p2p"
  "github.com/mudler/LocalAI/pkg/signals"
  "github.com/mudler/LocalAI/pkg/system"
  "github.com/phayes/freeport"
  "github.com/rs/zerolog/log"
@@ -48,6 +48,9 @@ func (r *P2P) Run(ctx *cliContext.Context) error {

  address := "127.0.0.1"

  c, cancel := context.WithCancel(context.Background())
  defer cancel()

  if r.NoRunner {
    // Let the user override which port and address to bind
    // if they configure the llama-cpp service on their own
@@ -59,7 +62,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
      p = r.RunnerPort
    }

    _, err = p2p.ExposeService(context.Background(), address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
    _, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
    if err != nil {
      return err
    }
@@ -101,13 +104,15 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
      }
    }()

    _, err = p2p.ExposeService(context.Background(), address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
    _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
    if err != nil {
      return err
    }
  }

  signals.Handler(nil)
  signals.RegisterGracefulTerminationHandler(func() {
    cancel()
  })

  for {
    time.Sleep(1 * time.Second)
@@ -63,6 +63,8 @@ type ApplicationConfig struct {
  WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration

  MachineTag string

  APIAddress string
}

type AppOption func(*ApplicationConfig)
@@ -343,6 +345,12 @@ func WithDisableApiKeyRequirementForHttpGet(required bool) AppOption {
  }
}

func WithAPIAddress(address string) AppOption {
  return func(o *ApplicationConfig) {
    o.APIAddress = address
  }
}

var DisableMetricsEndpoint AppOption = func(o *ApplicationConfig) {
  o.DisableMetrics = true
}
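WithAPIAddress follows the functional-options pattern already used by the surrounding AppOption helpers. For readers unfamiliar with it, a minimal sketch of how such options compose; the constructor name here is hypothetical, the real wiring lives elsewhere in core/config:

// NewApplicationConfig is a hypothetical constructor illustrating how
// AppOption values like WithAPIAddress are applied.
func NewApplicationConfig(opts ...AppOption) *ApplicationConfig {
  cfg := &ApplicationConfig{}
  for _, opt := range opts {
    opt(cfg) // each option mutates the config in place
  }
  return cfg
}

// usage: cfg := NewApplicationConfig(WithAPIAddress("127.0.0.1:8080"))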
@@ -26,6 +26,7 @@ type TTSConfig struct {

// ModelConfig represents a model configuration
type ModelConfig struct {
  modelConfigFile string `yaml:"-" json:"-"`
  schema.PredictionOptions `yaml:"parameters" json:"parameters"`
  Name string `yaml:"name" json:"name"`

@@ -73,6 +74,55 @@ type ModelConfig struct {

  Options   []string `yaml:"options" json:"options"`
  Overrides []string `yaml:"overrides" json:"overrides"`

  MCP   MCPConfig   `yaml:"mcp" json:"mcp"`
  Agent AgentConfig `yaml:"agent" json:"agent"`
}

type MCPConfig struct {
  Servers string `yaml:"remote" json:"remote"`
  Stdio   string `yaml:"stdio" json:"stdio"`
}

type AgentConfig struct {
  MaxAttempts           int  `yaml:"max_attempts" json:"max_attempts"`
  MaxIterations         int  `yaml:"max_iterations" json:"max_iterations"`
  EnableReasoning       bool `yaml:"enable_reasoning" json:"enable_reasoning"`
  EnablePlanning        bool `yaml:"enable_planning" json:"enable_planning"`
  EnableMCPPrompts      bool `yaml:"enable_mcp_prompts" json:"enable_mcp_prompts"`
  EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator" json:"enable_plan_re_evaluator"`
}

func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers]) {
  var remote MCPGenericConfig[MCPRemoteServers]
  var stdio MCPGenericConfig[MCPSTDIOServers]

  if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
    return remote, stdio
  }

  if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
    return remote, stdio
  }

  return remote, stdio
}

type MCPGenericConfig[T any] struct {
  Servers T `yaml:"mcpServers" json:"mcpServers"`
}
type MCPRemoteServers map[string]MCPRemoteServer
type MCPSTDIOServers map[string]MCPSTDIOServer

type MCPRemoteServer struct {
  URL   string `json:"url"`
  Token string `json:"token"`
}

type MCPSTDIOServer struct {
  Args    []string          `json:"args"`
  Env     map[string]string `json:"env"`
  Command string            `json:"command"`
}

// Pipeline defines other models to use for audio-to-audio
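MCPConfig deliberately stores the remote and stdio sections as raw YAML strings and re-parses them on demand with MCPConfigFromYAML, which keeps the model YAML free-form. A usage sketch; the server entries below are invented for illustration:

cfg := MCPConfig{
  Servers: `
mcpServers:
  search:
    url: "https://example.com/mcp"
    token: "secret"
`,
  Stdio: `
mcpServers:
  fs:
    command: "my-mcp-server"
    args: ["--root", "/data"]
    env:
      DEBUG: "1"
`,
}

remote, stdio := cfg.MCPConfigFromYAML()
// remote.Servers["search"].URL == "https://example.com/mcp"
// stdio.Servers["fs"].Command == "my-mcp-server"

Note that parse failures are swallowed: MCPConfigFromYAML returns the zero-valued structs on error, so a malformed section simply yields no servers.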
@@ -445,6 +495,10 @@ func (c *ModelConfig) HasTemplate() bool {
  return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
}

func (c *ModelConfig) GetModelConfigFile() string {
  return c.modelConfigFile
}

type ModelConfigUsecases int

const (
@@ -88,6 +88,7 @@ func readMultipleModelConfigsFromFile(file string, opts ...ConfigLoaderOption) (
  }

  for _, cc := range *c {
    cc.modelConfigFile = file
    cc.SetDefaults(opts...)
  }

@@ -108,6 +109,8 @@ func readModelConfigFromFile(file string, opts ...ConfigLoaderOption) (*ModelCon
  }

  c.SetDefaults(opts...)

  c.modelConfigFile = file
  return c, nil
}
@@ -1,6 +1,8 @@
package gallery

import (
  "fmt"

  "github.com/mudler/LocalAI/core/config"
  "github.com/mudler/LocalAI/pkg/system"
  "github.com/rs/zerolog/log"
@@ -72,3 +74,7 @@ func (m *GalleryBackend) GetDescription() string {
func (m *GalleryBackend) GetTags() []string {
  return m.Tags
}

func (m GalleryBackend) ID() string {
  return fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name)
}
core/gallery/btree.go (new file, 91 lines)
@@ -0,0 +1,91 @@
package gallery

import (
  "fmt"
  "io"
  "os"
  "path/filepath"
  "strings"

  "github.com/google/btree"
  "gopkg.in/yaml.v3"
)

type Gallery[T GalleryElement] struct {
  *btree.BTreeG[T]
}

func less[T GalleryElement](a, b T) bool {
  return a.GetName() < b.GetName()
}

func (g *Gallery[T]) Search(term string) GalleryElements[T] {
  var filteredModels GalleryElements[T]
  g.Ascend(func(item T) bool {
    if strings.Contains(item.GetName(), term) ||
      strings.Contains(item.GetDescription(), term) ||
      strings.Contains(item.GetGallery().Name, term) ||
      strings.Contains(strings.Join(item.GetTags(), ","), term) {
      filteredModels = append(filteredModels, item)
    }
    return true
  })

  return filteredModels
}

// processYAMLFile takes a single file path and adds its models to the existing tree.
func processYAMLFile[T GalleryElement](filePath string, tree *Gallery[T]) error {
  file, err := os.Open(filePath)
  if err != nil {
    return fmt.Errorf("could not open %s: %w", filePath, err)
  }
  defer file.Close()

  decoder := yaml.NewDecoder(file)

  // Stream documents from the file
  for {
    var model T
    err := decoder.Decode(&model)
    if err == io.EOF {
      break // End of file
    }
    if err != nil {
      return fmt.Errorf("error decoding %s: %w", filePath, err)
    }

    tree.ReplaceOrInsert(model)
  }
  return nil
}

// LoadGalleryFromDirectory scans a directory and processes all .yaml/.yml files.
func LoadGalleryFromDirectory[T GalleryElement](dirPath string) (*Gallery[T], error) {
  tree := &Gallery[T]{btree.NewG[T](2, less[T])}

  entries, err := os.ReadDir(dirPath)
  if err != nil {
    return nil, fmt.Errorf("could not read directory: %w", err)
  }

  for _, entry := range entries {
    // Skip subdirectories and non-YAML files.
    if entry.IsDir() {
      continue
    }
    fileName := entry.Name()
    if !strings.HasSuffix(fileName, ".yaml") && !strings.HasSuffix(fileName, ".yml") {
      continue
    }

    fullPath := filepath.Join(dirPath, fileName)

    err := processYAMLFile(fullPath, tree)
    if err != nil {
      return nil, err
    }
  }

  return tree, nil
}
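Because the tree is ordered by GetName and populated with ReplaceOrInsert, later YAML documents with the same name override earlier ones, and Ascend yields search results already sorted by name. A short usage sketch, assuming a GalleryElement implementation such as *GalleryModel and an invented directory path:

tree, err := LoadGalleryFromDirectory[*GalleryModel]("/tmp/galleries")
if err != nil {
  log.Fatal().Err(err).Msg("loading gallery")
}
for _, m := range tree.Search("llama") {
  fmt.Println(m.GetName())
}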
@@ -6,10 +6,12 @@ import (
  "path/filepath"
  "strings"

  "github.com/lithammer/fuzzysearch/fuzzy"
  "github.com/mudler/LocalAI/core/config"
  "github.com/mudler/LocalAI/pkg/downloader"
  "github.com/mudler/LocalAI/pkg/system"
  "github.com/rs/zerolog/log"

  "gopkg.in/yaml.v2"
)

@@ -56,12 +58,12 @@ type GalleryElements[T GalleryElement] []T

func (gm GalleryElements[T]) Search(term string) GalleryElements[T] {
  var filteredModels GalleryElements[T]

  term = strings.ToLower(term)
  for _, m := range gm {
    if strings.Contains(m.GetName(), term) ||
      strings.Contains(m.GetDescription(), term) ||
      strings.Contains(m.GetGallery().Name, term) ||
      strings.Contains(strings.Join(m.GetTags(), ","), term) {
    if fuzzy.Match(term, strings.ToLower(m.GetName())) ||
      fuzzy.Match(term, strings.ToLower(m.GetDescription())) ||
      fuzzy.Match(term, strings.ToLower(m.GetGallery().Name)) ||
      strings.Contains(strings.ToLower(strings.Join(m.GetTags(), ",")), term) {
      filteredModels = append(filteredModels, m)
    }
  }
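fuzzy.Match from lithammer/fuzzysearch is a subsequence matcher rather than a substring matcher, so the new search tolerates abbreviations that strings.Contains would reject. Roughly:

fuzzy.Match("lcl", "localai")        // true: l, c, l appear in order
fuzzy.Match("cart", "tracking")      // false: the letters are out of order
strings.Contains("localai", "lcl")   // false: the old, stricter behavior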
@@ -203,6 +203,9 @@ func API(application *application.Application) (*fiber.App, error) {
  routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
  routes.RegisterOpenAIRoutes(router, requestExtractor, application)
  if !application.ApplicationConfig().DisableWebUI {
    // Create opcache for tracking UI operations
    opcache := services.NewOpCache(application.GalleryService())
    routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
    routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
  }
  routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
@@ -799,7 +799,7 @@ var _ = Describe("API test", func() {
  It("returns errors", func() {
    _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
    Expect(err).To(HaveOccurred())
    Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error:"))
    Expect(err.Error()).To(ContainSubstring("error, status code: 500, status: 500 Internal Server Error, message: could not load model - all backends returned error:"))
  })

  It("shows the external backend", func() {
@@ -1,115 +0,0 @@
package elements

import (
  "strings"

  "github.com/chasefleming/elem-go"
  "github.com/chasefleming/elem-go/attrs"
  "github.com/mudler/LocalAI/core/gallery"
)

func installButton(galleryName string) elem.Node {
  return elem.Button(
    attrs.Props{
      "data-twe-ripple-init": "",
      "data-twe-ripple-color": "light",
      "class": "float-right inline-flex items-center rounded-lg bg-blue-600 hover:bg-blue-700 px-4 py-2 text-sm font-medium text-white transition duration-300 ease-in-out shadow hover:shadow-lg",
      "hx-swap": "outerHTML",
      // post the Model ID as param
      "hx-post": "browse/install/model/" + galleryName,
    },
    elem.I(
      attrs.Props{
        "class": "fa-solid fa-download pr-2",
      },
    ),
    elem.Text("Install"),
  )
}

func reInstallButton(galleryName string) elem.Node {
  return elem.Button(
    attrs.Props{
      "data-twe-ripple-init": "",
      "data-twe-ripple-color": "light",
      "class": "float-right inline-block rounded bg-primary ml-2 px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
      "hx-target": "#action-div-" + dropBadChars(galleryName),
      "hx-swap": "outerHTML",
      // post the Model ID as param
      "hx-post": "browse/install/model/" + galleryName,
    },
    elem.I(
      attrs.Props{
        "class": "fa-solid fa-arrow-rotate-right pr-2",
      },
    ),
    elem.Text("Reinstall"),
  )
}

func infoButton(m *gallery.GalleryModel) elem.Node {
  return elem.Button(
    attrs.Props{
      "data-twe-ripple-init": "",
      "data-twe-ripple-color": "light",
      "class": "inline-flex items-center rounded-lg bg-gray-700 hover:bg-gray-600 px-4 py-2 text-sm font-medium text-white transition duration-300 ease-in-out",
      "data-modal-target": modalName(m),
      "data-modal-toggle": modalName(m),
    },
    elem.P(
      attrs.Props{
        "class": "flex items-center",
      },
      elem.I(
        attrs.Props{
          "class": "fas fa-info-circle pr-2",
        },
      ),
      elem.Text("Info"),
    ),
  )
}

func getConfigButton(galleryName string) elem.Node {
  return elem.Button(
    attrs.Props{
      "data-twe-ripple-init": "",
      "data-twe-ripple-color": "light",
      "class": "float-right ml-2 inline-flex items-center rounded-lg bg-gray-700 hover:bg-gray-600 px-4 py-2 text-sm font-medium text-white transition duration-300 ease-in-out",
      "hx-swap": "outerHTML",
      "hx-post": "browse/config/model/" + galleryName,
    },
    elem.I(
      attrs.Props{
        "class": "fa-solid fa-download pr-2",
      },
    ),
    elem.Text("Get Config"),
  )
}

func deleteButton(galleryID string) elem.Node {
  return elem.Button(
    attrs.Props{
      "data-twe-ripple-init": "",
      "data-twe-ripple-color": "light",
      "hx-confirm": "Are you sure you wish to delete the model?",
      "class": "float-right inline-block rounded bg-red-800 px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-red-accent-300 hover:shadow-red-2 focus:bg-red-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-red-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
      "hx-target": "#action-div-" + dropBadChars(galleryID),
      "hx-swap": "outerHTML",
      // post the Model ID as param
      "hx-post": "browse/delete/model/" + galleryID,
    },
    elem.I(
      attrs.Props{
        "class": "fa-solid fa-cancel pr-2",
      },
    ),
    elem.Text("Delete"),
  )
}

// Javascript/HTMX doesn't like weird IDs
func dropBadChars(s string) string {
  return strings.ReplaceAll(s, "@", "__")
}
@@ -1,757 +0,0 @@
package elements

import (
  "fmt"

  "github.com/chasefleming/elem-go"
  "github.com/chasefleming/elem-go/attrs"
  "github.com/microcosm-cc/bluemonday"
  "github.com/mudler/LocalAI/core/gallery"
  "github.com/mudler/LocalAI/core/services"
)

const (
  noImage = "https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg"
)

func cardSpan(text, icon string) elem.Node {
  return elem.Span(
    attrs.Props{
      "class": "inline-flex items-center px-3 py-1 rounded-lg text-xs font-medium bg-gray-700/70 text-gray-300 border border-gray-600/50 mr-2 mb-2",
    },
    elem.I(attrs.Props{
      "class": icon + " pr-2",
    }),

    elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
  )
}

func searchableElement(text, icon string) elem.Node {
  return elem.Form(
    attrs.Props{},
    elem.Input(
      attrs.Props{
        "type": "hidden",
        "name": "search",
        "value": text,
      },
    ),
    elem.Span(
      attrs.Props{
        "class": "inline-flex items-center text-xs px-3 py-1 rounded-full bg-gray-700/60 text-gray-300 border border-gray-600/50 hover:bg-gray-600 hover:text-gray-100 transition duration-200 ease-in-out",
      },
      elem.A(
        attrs.Props{
          // "name": "search",
          // "value": text,
          //"class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2",
          //"href": "#!",
          "href": "browse?term=" + text,
          //"hx-post": "browse/search/models",
          //"hx-target": "#search-results",
          // TODO: this doesn't work
          // "hx-vals": `{ \"search\": \"` + text + `\" }`,
          //"hx-indicator": ".htmx-indicator",
        },
        elem.I(attrs.Props{
          "class": icon + " pr-2",
        }),
        elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
      ),
    ),
  )
}

/*
func buttonLink(text, url string) elem.Node {
  return elem.A(
    attrs.Props{
      "class": "inline-block bg-gray-200 rounded-full px-3 py-1 text-sm font-semibold text-gray-700 mr-2 mb-2 hover:bg-gray-300 hover:shadow-gray-2",
      "href": url,
      "target": "_blank",
    },
    elem.I(attrs.Props{
      "class": "fas fa-link pr-2",
    }),
    elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
  )
}
*/

func link(text, url string) elem.Node {
  return elem.A(
    attrs.Props{
      "class": "text-base leading-relaxed text-gray-500 dark:text-gray-400",
      "href": url,
      "target": "_blank",
    },
    elem.I(attrs.Props{
      "class": "fas fa-link pr-2",
    }),
    elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
  )
}

type ProcessTracker interface {
  Exists(string) bool
  Get(string) string
}

func modalName(m *gallery.GalleryModel) string {
  return m.Name + "-modal"
}

func modelModal(m *gallery.GalleryModel) elem.Node {
  urls := []elem.Node{}
  for _, url := range m.URLs {
    urls = append(urls,
      elem.Li(attrs.Props{}, link(url, url)),
    )
  }

  tagsNodes := []elem.Node{}
  for _, tag := range m.Tags {
    tagsNodes = append(tagsNodes,
      searchableElement(tag, "fas fa-tag"),
    )
  }

  return elem.Div(
    attrs.Props{
      "id": modalName(m),
      "tabindex": "-1",
      "aria-hidden": "true",
      "class": "hidden fixed top-0 right-0 left-0 z-50 justify-center items-center w-full md:inset-0 h-full max-h-full bg-gray-900/50",
    },
    elem.Div(
      attrs.Props{
        "class": "relative p-4 w-full max-w-2xl h-[90vh] mx-auto mt-[5vh]",
      },
      elem.Div(
        attrs.Props{
          "class": "relative bg-white rounded-lg shadow dark:bg-gray-700 h-full flex flex-col",
        },
        // header
        elem.Div(
          attrs.Props{
            "class": "flex items-center justify-between p-4 md:p-5 border-b rounded-t dark:border-gray-600",
          },
          elem.H3(
            attrs.Props{
              "class": "text-xl font-semibold text-gray-900 dark:text-white",
            },
            elem.Text(bluemonday.StrictPolicy().Sanitize(m.Name)),
          ),
          elem.Button( // close button
            attrs.Props{
              "class": "text-gray-400 bg-transparent hover:bg-gray-200 hover:text-gray-900 rounded-lg text-sm w-8 h-8 ms-auto inline-flex justify-center items-center dark:hover:bg-gray-600 dark:hover:text-white",
              "data-modal-hide": modalName(m),
            },
            elem.Raw(
              `<svg class="w-3 h-3" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 14 14">
                <path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m1 1 6 6m0 0 6 6M7 7l6-6M7 7l-6 6"/>
              </svg>`,
            ),
            elem.Span(
              attrs.Props{
                "class": "sr-only",
              },
              elem.Text("Close modal"),
            ),
          ),
        ),
        // body
        elem.Div(
          attrs.Props{
            "class": "p-4 md:p-5 space-y-4 overflow-y-auto flex-1 min-h-0",
          },
          elem.Div(
            attrs.Props{
              "class": "flex justify-center items-center",
            },
            elem.Img(attrs.Props{
              "class": "lazy rounded-t-lg max-h-48 max-w-96 object-cover mt-3 entered loaded",
              "src": m.Icon,
              "loading": "lazy",
            }),
          ),
          elem.P(
            attrs.Props{
              "class": "text-base leading-relaxed text-gray-500 dark:text-gray-400",
            },
            elem.Text(bluemonday.StrictPolicy().Sanitize(m.Description)),
          ),
          elem.Hr(
            attrs.Props{},
          ),
          elem.P(
            attrs.Props{
              "class": "text-sm font-semibold text-gray-900 dark:text-white",
            },
            elem.Text("Links"),
          ),
          elem.Ul(
            attrs.Props{},
            urls...,
          ),
          elem.If(
            len(m.Tags) > 0,
            elem.Div(
              attrs.Props{},
              elem.P(
                attrs.Props{
                  "class": "text-sm mb-5 font-semibold text-gray-900 dark:text-white",
                },
                elem.Text("Tags"),
              ),
              elem.Div(
                attrs.Props{
                  "class": "flex flex-row flex-wrap content-center",
                },
                tagsNodes...,
              ),
            ),
            elem.Div(attrs.Props{}),
          ),
        ),
        // Footer
        elem.Div(
          attrs.Props{
            "class": "flex items-center p-4 md:p-5 border-t border-gray-200 rounded-b dark:border-gray-600",
          },
          elem.Button(
            attrs.Props{
              "data-modal-hide": modalName(m),
              "class": "py-2.5 px-5 ms-3 text-sm font-medium text-gray-900 focus:outline-none bg-white rounded-lg border border-gray-200 hover:bg-gray-100 hover:text-blue-700 focus:z-10 focus:ring-4 focus:ring-gray-100 dark:focus:ring-gray-700 dark:bg-gray-800 dark:text-gray-400 dark:border-gray-600 dark:hover:text-white dark:hover:bg-gray-700",
            },
            elem.Text("Close"),
          ),
        ),
      ),
    ),
  )
}

func modelDescription(m *gallery.GalleryModel) elem.Node {
  return elem.Div(
    attrs.Props{
      "class": "p-6 text-surface dark:text-white",
    },
    elem.H5(
      attrs.Props{
        "class": "mb-2 text-xl font-bold leading-tight",
      },
      elem.Text(bluemonday.StrictPolicy().Sanitize(m.Name)),
    ),
    elem.Div( // small description
      attrs.Props{
        "class": "mb-4 text-sm truncate text-base",
      },
      elem.Text(bluemonday.StrictPolicy().Sanitize(m.Description)),
    ),
  )
}

func modelActionItems(m *gallery.GalleryModel, processTracker ProcessTracker, galleryService *services.GalleryService) elem.Node {
  galleryID := fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name)
  currentlyProcessing := processTracker.Exists(galleryID)
  jobID := ""
  isDeletionOp := false
  if currentlyProcessing {
    status := galleryService.GetStatus(galleryID)
    if status != nil && status.Deletion {
      isDeletionOp = true
    }
    jobID = processTracker.Get(galleryID)
    // TODO:
    // case not handled, if status == nil : "Waiting"
  }

  nodes := []elem.Node{
    cardSpan("Repository: "+m.Gallery.Name, "fa-brands fa-git-alt"),
  }

  if m.License != "" {
    nodes = append(nodes,
      cardSpan("License: "+m.License, "fas fa-book"),
    )
  }
  /*
  tagsNodes := []elem.Node{}

  for _, tag := range m.Tags {
    tagsNodes = append(tagsNodes,
      searchableElement(tag, "fas fa-tag"),
    )
  }


  nodes = append(nodes,
    elem.Div(
      attrs.Props{
        "class": "flex flex-row flex-wrap content-center",
      },
      tagsNodes...,
    ),
  )

  for i, url := range m.URLs {
    nodes = append(nodes,
      buttonLink("Link #"+fmt.Sprintf("%d", i+1), url),
    )
  }
  */

  progressMessage := "Installation"
  if isDeletionOp {
    progressMessage = "Deletion"
  }

  return elem.Div(
    attrs.Props{
      "class": "px-6 pt-4 pb-2",
    },
    elem.P(
      attrs.Props{
        "class": "mb-4 text-base",
      },
      nodes...,
    ),
    elem.Div(
      attrs.Props{
        "id": "action-div-" + dropBadChars(galleryID),
        "class": "flow-root", // To order buttons left and right
      },
      infoButton(m),
      elem.Div(
        attrs.Props{
          "class": "float-right",
        },
        elem.If(
          currentlyProcessing,
          elem.Node( // If currently installing, show progress bar
            elem.Raw(StartModelProgressBar(jobID, "0", progressMessage)),
          ), // Otherwise, show install button (if not installed) or display "Installed"
          elem.If(m.Installed,
            elem.Node(elem.Div(
              attrs.Props{},
              reInstallButton(m.ID()),
              deleteButton(m.ID()),
            )),
            // otherwise, show the install button, and the get config button
            elem.Node(elem.Div(
              attrs.Props{},
              getConfigButton(m.ID()),
              installButton(m.ID()),
            )),
          ),
        ),
      ),
    ),
  )
}

func ListModels(models []*gallery.GalleryModel, processTracker ProcessTracker, galleryService *services.GalleryService) string {
  modelsElements := []elem.Node{}

  for _, m := range models {
    elems := []elem.Node{}

    if m.Icon == "" {
      m.Icon = noImage
    }

    divProperties := attrs.Props{
      "class": "flex justify-center items-center",
    }

    elems = append(elems,
      elem.Div(divProperties,
        elem.A(attrs.Props{
          "href": "#!",
          // "class": "justify-center items-center",
        },
          elem.Img(attrs.Props{
            // "class": "rounded-t-lg object-fit object-center h-96",
            "class": "rounded-t-lg max-h-48 max-w-96 object-cover mt-3",
            "src": m.Icon,
            "loading": "lazy",
          }),
        ),
      ),
    )

    // Special/corner case: if a model sets Trust Remote Code as required, show a warning
    // TODO: handle this more generically later
    _, trustRemoteCodeExists := m.Overrides["trust_remote_code"]
    if trustRemoteCodeExists {
      elems = append(elems, elem.Div(
        attrs.Props{
          "class": "flex justify-center items-center bg-red-500 text-white p-2 rounded-lg mt-2",
        },
        elem.I(attrs.Props{
          "class": "fa-solid fa-circle-exclamation pr-2",
        }),
        elem.Text("Attention: Trust Remote Code is required for this model"),
      ))
    }

    elems = append(elems,
      modelDescription(m),
      modelActionItems(m, processTracker, galleryService),
    )
    modelsElements = append(modelsElements,
      elem.Div(
        attrs.Props{
          "class": " me-4 mb-2 block rounded-lg bg-white shadow-secondary-1 dark:bg-gray-800 dark:bg-surface-dark dark:text-white text-surface pb-2 bg-gray-800/90 border border-gray-700/50 rounded-xl overflow-hidden transition-all duration-300 hover:shadow-lg hover:shadow-blue-900/20 hover:-translate-y-1 hover:border-blue-700/50",
        },
        elem.Div(
          attrs.Props{
            // "class": "p-6",
          },
          elems...,
        ),
      ),
      modelModal(m),
    )
  }

  wrapper := elem.Div(attrs.Props{
    "class": "dark grid grid-cols-1 grid-rows-1 md:grid-cols-3 block rounded-lg shadow-secondary-1 dark:bg-surface-dark",
  }, modelsElements...)

  return wrapper.Render()
}

func ListBackends(backends []*gallery.GalleryBackend, processTracker ProcessTracker, galleryService *services.GalleryService) string {
  backendsElements := []elem.Node{}

  for _, b := range backends {
    elems := []elem.Node{}

    if b.Icon == "" {
      b.Icon = noImage
    }

    divProperties := attrs.Props{
      "class": "flex justify-center items-center",
    }

    elems = append(elems,
      elem.Div(divProperties,
        elem.A(attrs.Props{
          "href": "#!",
        },
          elem.Img(attrs.Props{
            "class": "rounded-t-lg max-h-48 max-w-96 object-cover mt-3",
            "src": b.Icon,
            "loading": "lazy",
          }),
        ),
      ),
    )

    elems = append(elems,
      backendDescription(b),
      backendActionItems(b, processTracker, galleryService),
    )
    backendsElements = append(backendsElements,
      elem.Div(
        attrs.Props{
          "class": "me-4 mb-2 block rounded-lg bg-white shadow-secondary-1 dark:bg-gray-800 dark:bg-surface-dark dark:text-white text-surface pb-2 bg-gray-800/90 border border-gray-700/50 rounded-xl overflow-hidden transition-all duration-300 hover:shadow-lg hover:shadow-blue-900/20 hover:-translate-y-1 hover:border-blue-700/50",
        },
        elem.Div(
          attrs.Props{},
          elems...,
        ),
      ),
      backendModal(b),
    )
  }

  wrapper := elem.Div(attrs.Props{
    "class": "dark grid grid-cols-1 grid-rows-1 md:grid-cols-3 block rounded-lg shadow-secondary-1 dark:bg-surface-dark",
  }, backendsElements...)

  return wrapper.Render()
}

func backendDescription(b *gallery.GalleryBackend) elem.Node {
  return elem.Div(
    attrs.Props{
      "class": "p-6 text-surface dark:text-white",
    },
    elem.H5(
      attrs.Props{
        "class": "mb-2 text-xl font-bold leading-tight",
      },
      elem.Text(bluemonday.StrictPolicy().Sanitize(b.Name)),
    ),
    elem.Div(
      attrs.Props{
        "class": "mb-4 text-sm truncate text-base",
      },
      elem.Text(bluemonday.StrictPolicy().Sanitize(b.Description)),
    ),
  )
}

func backendActionItems(b *gallery.GalleryBackend, processTracker ProcessTracker, galleryService *services.GalleryService) elem.Node {
  galleryID := fmt.Sprintf("%s@%s", b.Gallery.Name, b.Name)
  currentlyProcessing := processTracker.Exists(galleryID)
  jobID := ""
  isDeletionOp := false
  if currentlyProcessing {
    status := galleryService.GetStatus(galleryID)
    if status != nil && status.Deletion {
      isDeletionOp = true
    }
    jobID = processTracker.Get(galleryID)
  }

  nodes := []elem.Node{
    cardSpan("Repository: "+b.Gallery.Name, "fa-brands fa-git-alt"),
  }

  if b.License != "" {
    nodes = append(nodes,
      cardSpan("License: "+b.License, "fas fa-book"),
    )
  }

  progressMessage := "Installation"
  if isDeletionOp {
    progressMessage = "Deletion"
  }

  return elem.Div(
    attrs.Props{
      "class": "px-6 pt-4 pb-2",
    },
    elem.P(
      attrs.Props{
        "class": "mb-4 text-base",
      },
      nodes...,
    ),
    elem.Div(
      attrs.Props{
        "id": "action-div-" + dropBadChars(galleryID),
        "class": "flow-root",
      },
      backendInfoButton(b),
      elem.Div(
        attrs.Props{
          "class": "float-right",
        },
        elem.If(
          currentlyProcessing,
          elem.Node(
            elem.Raw(StartModelProgressBar(jobID, "0", progressMessage)),
          ),
          elem.If(b.Installed,
            elem.Node(elem.Div(
              attrs.Props{},
              backendReInstallButton(galleryID),
              backendDeleteButton(galleryID),
            )),
            backendInstallButton(galleryID),
          ),
        ),
      ),
    ),
  )
}

func backendModal(b *gallery.GalleryBackend) elem.Node {
  urls := []elem.Node{}
  for _, url := range b.URLs {
    urls = append(urls,
      elem.Li(attrs.Props{}, link(url, url)),
    )
  }

  tagsNodes := []elem.Node{}
  for _, tag := range b.Tags {
    tagsNodes = append(tagsNodes,
      searchableElement(tag, "fas fa-tag"),
    )
  }

  modalID := fmt.Sprintf("modal-%s", dropBadChars(fmt.Sprintf("%s@%s", b.Gallery.Name, b.Name)))

  return elem.Div(
    attrs.Props{
      "id": modalID,
      "tabindex": "-1",
      "aria-hidden": "true",
      "class": "hidden fixed top-0 right-0 left-0 z-50 justify-center items-center w-full md:inset-0 h-full max-h-full bg-gray-900/50",
    },
    elem.Div(
      attrs.Props{
        "class": "relative p-4 w-full max-w-2xl h-[90vh] mx-auto mt-[5vh]",
      },
      elem.Div(
        attrs.Props{
          "class": "relative bg-white rounded-lg shadow dark:bg-gray-700 h-full flex flex-col",
        },
        elem.Div(
          attrs.Props{
            "class": "flex items-center justify-between p-4 md:p-5 border-b rounded-t dark:border-gray-600",
          },
          elem.H3(
            attrs.Props{
              "class": "text-xl font-semibold text-gray-900 dark:text-white",
            },
            elem.Text(bluemonday.StrictPolicy().Sanitize(b.Name)),
          ),
          elem.Button(
            attrs.Props{
              "class": "text-gray-400 bg-transparent hover:bg-gray-200 hover:text-gray-900 rounded-lg text-sm w-8 h-8 ms-auto inline-flex justify-center items-center dark:hover:bg-gray-600 dark:hover:text-white",
              "data-modal-hide": modalID,
            },
            elem.Raw(
              `<svg class="w-3 h-3" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 14 14">
                <path stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="m1 1 6 6m0 0 6 6M7 7l6-6M7 7l-6 6"/>
              </svg>`,
            ),
            elem.Span(
              attrs.Props{
                "class": "sr-only",
              },
              elem.Text("Close modal"),
            ),
          ),
        ),
        elem.Div(
          attrs.Props{
            "class": "p-4 md:p-5 space-y-4 overflow-y-auto flex-grow",
          },
          elem.Div(
            attrs.Props{
              "class": "flex justify-center items-center",
            },
            elem.Img(attrs.Props{
              "src": b.Icon,
              "class": "rounded-t-lg max-h-48 max-w-96 object-cover mt-3",
              "loading": "lazy",
            }),
          ),
          elem.P(
            attrs.Props{
              "class": "text-base leading-relaxed text-gray-500 dark:text-gray-400",
            },
            elem.Text(bluemonday.StrictPolicy().Sanitize(b.Description)),
          ),
          elem.Div(
            attrs.Props{
              "class": "flex flex-wrap gap-2",
            },
            tagsNodes...,
          ),
          elem.Div(
            attrs.Props{
              "class": "text-base leading-relaxed text-gray-500 dark:text-gray-400",
            },
            elem.Ul(attrs.Props{}, urls...),
          ),
        ),
        elem.Div(
          attrs.Props{
            "class": "flex items-center p-4 md:p-5 border-t border-gray-200 rounded-b dark:border-gray-600",
          },
          elem.Button(
            attrs.Props{
              "data-modal-hide": modalID,
              "type": "button",
              "class": "text-white bg-blue-700 hover:bg-blue-800 focus:ring-4 focus:outline-none focus:ring-blue-300 font-medium rounded-lg text-sm px-5 py-2.5 text-center dark:bg-blue-600 dark:hover:bg-blue-700 dark:focus:ring-blue-800",
            },
            elem.Text("Close"),
          ),
        ),
      ),
    ),
  )
}

func backendInfoButton(b *gallery.GalleryBackend) elem.Node {
  modalID := fmt.Sprintf("modal-%s", dropBadChars(fmt.Sprintf("%s@%s", b.Gallery.Name, b.Name)))
  return elem.Button(
    attrs.Props{
      "data-twe-ripple-init": "",
      "data-twe-ripple-color": "light",
      "class": "inline-flex items-center rounded-lg bg-gray-700 hover:bg-gray-600 px-4 py-2 text-sm font-medium text-white transition duration-300 ease-in-out",
      "data-modal-target": modalID,
      "data-modal-toggle": modalID,
    },
    elem.P(
      attrs.Props{
        "class": "flex items-center",
      },
      elem.I(
        attrs.Props{
          "class": "fas fa-info-circle pr-2",
        },
      ),
      elem.Text("Info"),
    ),
  )
}

func backendInstallButton(galleryID string) elem.Node {
  return elem.Button(
    attrs.Props{
      "data-twe-ripple-init": "",
      "data-twe-ripple-color": "light",
      "class": "float-right inline-flex items-center rounded-lg bg-blue-600 hover:bg-blue-700 px-4 py-2 text-sm font-medium text-white transition duration-300 ease-in-out shadow hover:shadow-lg",
      "hx-swap": "outerHTML",
      "hx-post": "browse/install/backend/" + galleryID,
    },
    elem.I(
      attrs.Props{
        "class": "fa-solid fa-download pr-2",
      },
    ),
    elem.Text("Install"),
  )
}

func backendReInstallButton(galleryID string) elem.Node {
  return elem.Button(
    attrs.Props{
      "data-twe-ripple-init": "",
      "data-twe-ripple-color": "light",
      "class": "float-right inline-block rounded bg-primary ml-2 px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-primary-accent-300 hover:shadow-primary-2 focus:bg-primary-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-primary-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
      "hx-target": "#action-div-" + dropBadChars(galleryID),
      "hx-swap": "outerHTML",
      "hx-post": "browse/install/backend/" + galleryID,
    },
    elem.I(
      attrs.Props{
        "class": "fa-solid fa-arrow-rotate-right pr-2",
      },
    ),
    elem.Text("Reinstall"),
  )
}

func backendDeleteButton(galleryID string) elem.Node {
  return elem.Button(
    attrs.Props{
      "data-twe-ripple-init": "",
      "data-twe-ripple-color": "light",
      "hx-confirm": "Are you sure you wish to delete the backend?",
      "class": "float-right inline-block rounded bg-red-800 px-6 pb-2.5 mb-3 pt-2.5 text-xs font-medium uppercase leading-normal text-white shadow-primary-3 transition duration-150 ease-in-out hover:bg-red-accent-300 hover:shadow-red-2 focus:bg-red-accent-300 focus:shadow-primary-2 focus:outline-none focus:ring-0 active:bg-red-600 active:shadow-primary-2 dark:shadow-black/30 dark:hover:shadow-dark-strong dark:focus:shadow-dark-strong dark:active:shadow-dark-strong",
      "hx-target": "#action-div-" + dropBadChars(galleryID),
      "hx-swap": "outerHTML",
      "hx-post": "browse/delete/backend/" + galleryID,
    },
    elem.I(
      attrs.Props{
        "class": "fa-solid fa-cancel pr-2",
      },
    ),
    elem.Text("Delete"),
  )
}
@@ -1,156 +0,0 @@
package elements

import (
  "fmt"
  "time"

  "github.com/chasefleming/elem-go"
  "github.com/chasefleming/elem-go/attrs"
  "github.com/microcosm-cc/bluemonday"
  "github.com/mudler/LocalAI/core/schema"
)

func renderElements(n []elem.Node) string {
  render := ""
  for _, r := range n {
    render += r.Render()
  }
  return render
}

func P2PNodeStats(nodes []schema.NodeData) string {
  online := 0
  for _, n := range nodes {
    if n.IsOnline() {
      online++
    }
  }

  class := "text-green-400"
  if online == 0 {
    class = "text-red-400"
  } else if online < len(nodes) {
    class = "text-yellow-400"
  }

  nodesElements := []elem.Node{
    elem.Span(
      attrs.Props{
        "class": class + " font-bold text-2xl",
      },
      elem.Text(fmt.Sprintf("%d", online)),
    ),
    elem.Span(
      attrs.Props{
        "class": "text-gray-300 text-xl",
      },
      elem.Text(fmt.Sprintf("/%d", len(nodes))),
    ),
  }

  return renderElements(nodesElements)
}

func P2PNodeBoxes(nodes []schema.NodeData) string {
  if len(nodes) == 0 {
    return `<div class="col-span-full flex flex-col items-center justify-center py-12 text-center bg-gray-800/50 border border-gray-700/50 rounded-xl">
      <i class="fas fa-server text-gray-500 text-4xl mb-4"></i>
      <p class="text-gray-400 text-lg font-medium">No nodes available</p>
      <p class="text-gray-500 text-sm mt-2">Start some workers to see them here</p>
    </div>`
  }

  render := ""
  for _, n := range nodes {
    nodeID := bluemonday.StrictPolicy().Sanitize(n.ID)

    // Define status-specific classes
    statusIconClass := "text-green-400"
    statusText := "Online"
    statusTextClass := "text-green-400"
    cardHoverClass := "hover:shadow-green-500/20 hover:border-green-400/50"

    if !n.IsOnline() {
      statusIconClass = "text-red-400"
      statusText = "Offline"
      statusTextClass = "text-red-400"
      cardHoverClass = "hover:shadow-red-500/20 hover:border-red-400/50"
    }

    nodeCard := elem.Div(
      attrs.Props{
        "class": "bg-gradient-to-br from-gray-800/90 to-gray-900/80 border border-gray-700/50 rounded-xl p-5 shadow-xl transition-all duration-300 " + cardHoverClass + " backdrop-blur-sm",
      },
      // Header with node icon and status
      elem.Div(
        attrs.Props{
          "class": "flex items-center justify-between mb-4",
        },
        // Node info
        elem.Div(
          attrs.Props{
            "class": "flex items-center",
          },
          elem.Div(
            attrs.Props{
              "class": "w-10 h-10 bg-blue-500/20 rounded-lg flex items-center justify-center mr-3",
            },
            elem.I(
              attrs.Props{
                "class": "fas fa-server text-blue-400 text-lg",
              },
            ),
          ),
          elem.Div(
            attrs.Props{},
            elem.H4(
              attrs.Props{
                "class": "text-white font-semibold text-sm",
              },
              elem.Text("Node"),
            ),
            elem.P(
              attrs.Props{
                "class": "text-gray-400 text-xs font-mono break-all",
              },
              elem.Text(nodeID),
            ),
          ),
        ),
        // Status badge
        elem.Div(
          attrs.Props{
            "class": "flex items-center bg-gray-900/50 rounded-full px-3 py-1.5 border border-gray-700/50",
          },
          elem.I(
            attrs.Props{
              "class": "fas fa-circle animate-pulse " + statusIconClass + " mr-2 text-xs",
            },
          ),
          elem.Span(
            attrs.Props{
              "class": statusTextClass + " text-xs font-medium",
            },
            elem.Text(statusText),
          ),
        ),
      ),
      // Footer with timestamp
      elem.Div(
        attrs.Props{
          "class": "text-xs text-gray-500 pt-3 border-t border-gray-700/30 flex items-center",
        },
        elem.I(
          attrs.Props{
            "class": "fas fa-clock mr-2",
          },
        ),
        elem.Text("Updated: "+time.Now().UTC().Format("15:04:05")),
      ),
    )

    render += nodeCard.Render()
  }

  return render
}
@@ -1,115 +0,0 @@
package elements

import (
  "github.com/chasefleming/elem-go"
  "github.com/chasefleming/elem-go/attrs"
  "github.com/microcosm-cc/bluemonday"
)

func DoneModelProgress(galleryID, text string, showDelete bool) string {
  return elem.Div(
    attrs.Props{
      "id": "action-div-" + dropBadChars(galleryID),
    },
    elem.H3(
      attrs.Props{
        "role": "status",
        "id": "pblabel",
        "tabindex": "-1",
        "autofocus": "",
      },
      elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
    ),
    elem.If(showDelete, deleteButton(galleryID), reInstallButton(galleryID)),
  ).Render()
}

func DoneBackendProgress(galleryID, text string, showDelete bool) string {
  return elem.Div(
    attrs.Props{
      "id": "action-div-" + dropBadChars(galleryID),
    },
    elem.H3(
      attrs.Props{
        "role": "status",
        "id": "pblabel",
        "tabindex": "-1",
        "autofocus": "",
      },
      elem.Text(bluemonday.StrictPolicy().Sanitize(text)),
    ),
    elem.If(showDelete, backendDeleteButton(galleryID), reInstallButton(galleryID)),
  ).Render()
}

func ErrorProgress(err, galleryName string) string {
  return elem.Div(
    attrs.Props{},
    elem.H3(
      attrs.Props{
        "role": "status",
        "id": "pblabel",
        "tabindex": "-1",
        "autofocus": "",
      },
      elem.Text("Error "+bluemonday.StrictPolicy().Sanitize(err)),
    ),
    installButton(galleryName),
  ).Render()
}

func ProgressBar(progress string) string {
  return elem.Div(attrs.Props{
    "class": "progress",
    "role": "progressbar",
    "aria-valuemin": "0",
    "aria-valuemax": "100",
    "aria-valuenow": "0",
    "aria-labelledby": "pblabel",
  },
    elem.Div(attrs.Props{
      "id": "pb",
      "class": "progress-bar",
      "style": "width:" + progress + "%",
    }),
  ).Render()
}

func StartModelProgressBar(uid, progress, text string) string {
  return progressBar(uid, "browse/job/", progress, text)
}

func StartBackendProgressBar(uid, progress, text string) string {
  return progressBar(uid, "browse/backend/job/", progress, text)
}

func progressBar(uid, url, progress, text string) string {
  if progress == "" {
    progress = "0"
  }
  return elem.Div(
    attrs.Props{
      "hx-trigger": "done",
      "hx-get": url + uid,
      "hx-swap": "outerHTML",
      "hx-target": "this",
    },
    elem.H3(
      attrs.Props{
        "role": "status",
        "id": "pblabel",
        "tabindex": "-1",
        "autofocus": "",
      },
      elem.Text(bluemonday.StrictPolicy().Sanitize(text)), //Perhaps overly defensive
      elem.Div(attrs.Props{
        "hx-get": url + "progress/" + uid,
        "hx-trigger": "every 600ms",
        "hx-target": "this",
        "hx-swap": "innerHTML",
      },
        elem.Raw(ProgressBar(progress)),
      ),
    ),
  ).Render()
}
@@ -1,11 +1,8 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
@@ -37,21 +34,19 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
return c.Status(404).JSON(response)
|
||||
}
|
||||
|
||||
configData, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
modelConfigFile := modelConfig.GetModelConfigFile()
|
||||
if modelConfigFile == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to marshal configuration: " + err.Error(),
|
||||
Error: "Model configuration file not found",
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.Status(404).JSON(response)
|
||||
}
|
||||
|
||||
// Marshal the config to JSON for the template
|
||||
configJSON, err := json.Marshal(modelConfig)
|
||||
configData, err := os.ReadFile(modelConfigFile)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to marshal configuration: " + err.Error(),
|
||||
Error: "Failed to read configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
}
|
||||
@@ -69,7 +64,6 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Title: "LocalAI - Edit Model " + modelName,
|
||||
ModelName: modelName,
|
||||
Config: &modelConfig,
|
||||
ConfigJSON: string(configJSON),
|
||||
ConfigYAML: string(configData),
|
||||
BaseURL: httpUtils.BaseURL(c),
|
||||
Version: internal.PrintableVersion(),
|
||||
@@ -91,6 +85,15 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Existing model configuration not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
}
|
||||
|
||||
// Get the raw body
|
||||
body := c.Body()
|
||||
if len(body) == 0 {
|
||||
@@ -101,50 +104,16 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
|
||||
// Check content type to determine how to parse
|
||||
contentType := string(c.Context().Request.Header.ContentType())
|
||||
// Check content to see if it's a valid model config
|
||||
var req config.ModelConfig
|
||||
var err error
|
||||
|
||||
if strings.Contains(contentType, "application/json") {
|
||||
// Parse JSON
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
|
||||
// Parse YAML
|
||||
if err := yaml.Unmarshal(body, &req); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
} else {
|
||||
// Try to auto-detect format
|
||||
if strings.TrimSpace(string(body))[0] == '{' {
|
||||
// Looks like JSON
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse JSON: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
} else {
|
||||
// Assume YAML
|
||||
if err := yaml.Unmarshal(body, &req); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
||||
// Parse YAML
|
||||
if err := yaml.Unmarshal(body, &req); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
}
|
		// Validate required fields
@@ -156,19 +125,6 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
 			return c.Status(400).JSON(response)
 		}

-		// Load the existing configuration
-		configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelName+".yaml")
-		if err := utils.VerifyPath(modelName+".yaml", appConfig.SystemState.Model.ModelsPath); err != nil {
-			response := ModelResponse{
-				Success: false,
-				Error:   "Model configuration not trusted: " + err.Error(),
-			}
-			return c.Status(404).JSON(response)
-		}
-
 		// Set defaults
 		req.SetDefaults()

 		// Validate the configuration
 		if !req.Validate() {
 			response := ModelResponse{
@@ -179,18 +135,18 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
 			return c.Status(400).JSON(response)
 		}

-		// Create the YAML file
-		yamlData, err := yaml.Marshal(req)
-		if err != nil {
+		// Load the existing configuration
+		configPath := modelConfig.GetModelConfigFile()
+		if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil {
 			response := ModelResponse{
 				Success: false,
-				Error:   "Failed to marshal configuration: " + err.Error(),
+				Error:   "Model configuration not trusted: " + err.Error(),
 			}
-			return c.Status(500).JSON(response)
+			return c.Status(404).JSON(response)
 		}

-		// Write to file
-		if err := os.WriteFile(configPath, yamlData, 0644); err != nil {
+		// Write new content to file
+		if err := os.WriteFile(configPath, body, 0644); err != nil {
 			response := ModelResponse{
 				Success: false,
 				Error:   "Failed to write configuration file: " + err.Error(),
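The utils.VerifyPath guard above is what keeps a crafted model name from escaping the models directory before the file is written. A generic containment check of that kind might look like this sketch (illustrative, not LocalAI's actual utils.VerifyPath):

package main

import (
	"fmt"
	"path/filepath"
	"strings"
)

// verifyPath reports an error when joining name onto base would
// escape base (e.g. when name contains "../").
func verifyPath(name, base string) error {
	abs, err := filepath.Abs(filepath.Join(base, name))
	if err != nil {
		return err
	}
	absBase, err := filepath.Abs(base)
	if err != nil {
		return err
	}
	if abs != absBase && !strings.HasPrefix(abs, absBase+string(filepath.Separator)) {
		return fmt.Errorf("path %q escapes base directory %q", name, base)
	}
	return nil
}

func main() {
	fmt.Println(verifyPath("gpt-4.yaml", "/models"))    // <nil>
	fmt.Println(verifyPath("../etc/passwd", "/models")) // error
}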
core/http/endpoints/mcp/tools.go (new file, 120 lines)
@@ -0,0 +1,120 @@
package mcp

import (
	"context"
	"net/http"
	"os"
	"os/exec"
	"sync"
	"time"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/signals"

	"github.com/modelcontextprotocol/go-sdk/mcp"
	"github.com/rs/zerolog/log"
)

type sessionCache struct {
	mu    sync.Mutex
	cache map[string][]*mcp.ClientSession
}

var (
	cache = sessionCache{
		cache: make(map[string][]*mcp.ClientSession),
	}

	client = mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil)
)

func SessionsFromMCPConfig(
	name string,
	remote config.MCPGenericConfig[config.MCPRemoteServers],
	stdio config.MCPGenericConfig[config.MCPSTDIOServers],
) ([]*mcp.ClientSession, error) {
	cache.mu.Lock()
	defer cache.mu.Unlock()

	sessions, exists := cache.cache[name]
	if exists {
		return sessions, nil
	}

	allSessions := []*mcp.ClientSession{}

	ctx, cancel := context.WithCancel(context.Background())

	// Get the list of all the tools that the Agent will be exposed to
	for _, server := range remote.Servers {
		log.Debug().Msgf("[MCP remote server] Configuration : %+v", server)
		// Create an HTTP client with a custom roundtripper for bearer token injection
		httpClient := &http.Client{
			Timeout:   360 * time.Second,
			Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport),
		}

		transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient}
		mcpSession, err := client.Connect(ctx, transport, nil)
		if err != nil {
			log.Error().Err(err).Msgf("Failed to connect to MCP server %s", server.URL)
			continue
		}
		log.Debug().Msgf("[MCP remote server] Connected to MCP server %s", server.URL)
		cache.cache[name] = append(cache.cache[name], mcpSession)
		allSessions = append(allSessions, mcpSession)
	}

	for _, server := range stdio.Servers {
		log.Debug().Msgf("[MCP stdio server] Configuration : %+v", server)
		command := exec.Command(server.Command, server.Args...)
		command.Env = os.Environ()
		for key, value := range server.Env {
			command.Env = append(command.Env, key+"="+value)
		}
		transport := &mcp.CommandTransport{Command: command}
		mcpSession, err := client.Connect(ctx, transport, nil)
		if err != nil {
			log.Error().Err(err).Msgf("Failed to start MCP server %s", command)
			continue
		}
		log.Debug().Msgf("[MCP stdio server] Connected to MCP server %s", command)
		cache.cache[name] = append(cache.cache[name], mcpSession)
		allSessions = append(allSessions, mcpSession)
	}

	signals.RegisterGracefulTerminationHandler(func() {
		for _, session := range allSessions {
			session.Close()
		}
		cancel()
	})

	return allSessions, nil
}

// bearerTokenRoundTripper is a custom roundtripper that injects a bearer token
// into HTTP requests
type bearerTokenRoundTripper struct {
	token string
	base  http.RoundTripper
}

// RoundTrip implements the http.RoundTripper interface
func (rt *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if rt.token != "" {
		req.Header.Set("Authorization", "Bearer "+rt.token)
	}
	return rt.base.RoundTrip(req)
}

// newBearerTokenRoundTripper creates a new roundtripper that injects the given token
func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.RoundTripper {
	if base == nil {
		base = http.DefaultTransport
	}
	return &bearerTokenRoundTripper{
		token: token,
		base:  base,
	}
}
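A quick way to convince yourself the roundtripper injects the header is an httptest-based check; a sketch of a test that could sit next to the file above:

package mcp

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

func TestBearerTokenRoundTripper(t *testing.T) {
	// Record the Authorization header the server receives.
	var got string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		got = r.Header.Get("Authorization")
	}))
	defer srv.Close()

	client := &http.Client{Transport: newBearerTokenRoundTripper("secret", nil)}
	resp, err := client.Get(srv.URL)
	if err != nil {
		t.Fatal(err)
	}
	resp.Body.Close()

	if got != "Bearer secret" {
		t.Fatalf("expected injected bearer token, got %q", got)
	}
}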
@@ -247,6 +247,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 				g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...)
 				if err == nil {
 					input.Grammar = g
+				} else {
+					log.Error().Err(err).Msg("Failed generating grammar")
 				}
 			}
 		}
@@ -272,7 +274,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		}

 		// Append the no action function
-		if !config.FunctionsConfig.DisableNoAction {
+		if !config.FunctionsConfig.DisableNoAction && !strictMode {
 			funcs = append(funcs, noActionGrammar)
 		}

@@ -286,11 +288,15 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 			g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
 			if err == nil {
 				config.Grammar = g
+			} else {
+				log.Error().Err(err).Msg("Failed generating grammar")
 			}
 		case input.JSONFunctionGrammarObject != nil:
 			g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarOptions()...)
 			if err == nil {
 				config.Grammar = g
+			} else {
+				log.Error().Err(err).Msg("Failed generating grammar")
 			}
 		default:
 			// Force picking one of the functions by the request
@@ -28,10 +28,10 @@ import (
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/completions [post]
 func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-	created := int(time.Now().Unix())
-
 	process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
+			created := int(time.Now().Unix())
+
 			usage := schema.OpenAIUsage{
 				PromptTokens:     tokenUsage.Prompt,
 				CompletionTokens: tokenUsage.Completion,
@@ -65,6 +65,9 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 	}

 	return func(c *fiber.Ctx) error {
+
+		created := int(time.Now().Unix())
+
 		// Handle Correlation
 		id := c.Get("X-Correlation-ID", uuid.New().String())
 		extraUsage := c.Get("Extra-Usage", "") != ""
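The two hunks above move the created timestamp from endpoint construction time (computed once, at startup) to per-request and per-chunk time, so streamed responses no longer all carry the server's start time. The difference is easy to see in isolation (an illustrative sketch, not the endpoint code itself):

package main

import (
	"fmt"
	"time"
)

func main() {
	// Old behaviour: timestamp captured once, when the handler is built.
	createdAtStartup := int(time.Now().Unix())
	handler := func() int { return createdAtStartup }

	// New behaviour: timestamp captured on every invocation.
	perRequest := func() int { return int(time.Now().Unix()) }

	fmt.Println(handler())
	time.Sleep(1100 * time.Millisecond)
	fmt.Println(handler())    // unchanged: still the startup timestamp
	fmt.Println(perRequest()) // fresh timestamp
}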
core/http/endpoints/openai/mcp.go (new file, 144 lines)
@@ -0,0 +1,144 @@
package openai

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/mudler/LocalAI/core/config"
	mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
	"github.com/mudler/LocalAI/core/http/middleware"

	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/templates"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/cogito"
	"github.com/rs/zerolog/log"
)

// MCPCompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
// @Summary Generate completions for a given prompt and model.
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /mcp/v1/completions [post]
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
	// We do not support streaming mode (yet?)
	return func(c *fiber.Ctx) error {
		created := int(time.Now().Unix())

		ctx := c.Context()

		// Handle Correlation
		id := c.Get("X-Correlation-ID", uuid.New().String())

		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
		if !ok || input.Model == "" {
			return fiber.ErrBadRequest
		}

		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || config == nil {
			return fiber.ErrBadRequest
		}

		if config.MCP.Servers == "" && config.MCP.Stdio == "" {
			return fmt.Errorf("no MCP servers configured")
		}

		// Get the MCP config from the model config
		remote, stdio := config.MCP.MCPConfigFromYAML()

		// Check if we have sessions in the cache, otherwise establish an initial connection
		sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
		if err != nil {
			return err
		}

		if len(sessions) == 0 {
			return fmt.Errorf("no working MCP servers found")
		}

		fragment := cogito.NewEmptyFragment()

		for _, message := range input.Messages {
			fragment = fragment.AddMessage(message.Role, message.StringContent)
		}

		port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:]
		apiKey := ""
		if len(appConfig.ApiKeys) > 0 {
			apiKey = appConfig.ApiKeys[0]
		}
		// TODO: instead of connecting to the API, we should just wire this internally
		// and act like completion.go.
		// We can do this as cogito expects an interface and we can create one that
		// we satisfy to just call ComputeChoices internally
		defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)

		cogitoOpts := []cogito.Option{
			cogito.WithStatusCallback(func(s string) {
				log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
			}),
			cogito.WithContext(ctx),
			cogito.WithMCPs(sessions...),
			cogito.WithIterations(3),  // default to 3 iterations
			cogito.WithMaxAttempts(3), // default to 3 attempts
			cogito.WithForceReasoning(),
		}

		if config.Agent.EnableReasoning {
			cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
		}

		if config.Agent.EnablePlanning {
			cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
		}

		if config.Agent.EnableMCPPrompts {
			cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
		}

		if config.Agent.EnablePlanReEvaluator {
			cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
		}

		if config.Agent.MaxIterations != 0 {
			cogitoOpts = append(cogitoOpts, cogito.WithIterations(config.Agent.MaxIterations))
		}

		if config.Agent.MaxAttempts != 0 {
			cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(config.Agent.MaxAttempts))
		}

		f, err := cogito.ExecuteTools(
			defaultLLM, fragment,
			cogitoOpts...,
		)
		if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
			return err
		}

		f, err = defaultLLM.Ask(ctx, f)
		if err != nil {
			return err
		}

		resp := &schema.OpenAIResponse{
			ID:      id,
			Created: created,
			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Choices: []schema.Choice{{Message: &schema.Message{Role: "assistant", Content: &f.LastMessage().Content}}},
			Object:  "text_completion",
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}
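Exercising the endpoint above from a client is an ordinary OpenAI-style chat request posted to /mcp/v1/completions; a sketch, where the host, port, API key, and model name are placeholders:

package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	payload := []byte(`{
		"model": "my-mcp-model",
		"messages": [{"role": "user", "content": "What is the weather in Rome?"}]
	}`)

	req, err := http.NewRequest("POST", "http://localhost:8080/mcp/v1/completions", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer my-api-key") // only needed when API keys are configured

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Println(string(body))
}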
@@ -256,21 +256,13 @@ func registerRealtime(application *application.Application) func(c *websocket.Co
 	sessions[sessionID] = session
 	sessionLock.Unlock()

 	// Send session.created and conversation.created events to the client
-	sendEvent(c, types.SessionCreatedEvent{
+	sendEvent(c, types.TranscriptionSessionCreatedEvent{
 		ServerEventBase: types.ServerEventBase{
 			EventID: "event_TODO",
-			Type:    types.ServerEventTypeSessionCreated,
+			Type:    types.ServerEventTypeTranscriptionSessionCreated,
 		},
 		Session: session.ToServer(),
 	})
-	sendEvent(c, types.ConversationCreatedEvent{
-		ServerEventBase: types.ServerEventBase{
-			EventID: "event_TODO",
-			Type:    types.ServerEventTypeConversationCreated,
-		},
-		Conversation: conversation.ToServer(),
-	})

 	var (
 		// mt int
@@ -704,6 +704,7 @@ const (
 	ServerEventTypeError                       ServerEventType = "error"
 	ServerEventTypeSessionCreated              ServerEventType = "session.created"
 	ServerEventTypeSessionUpdated              ServerEventType = "session.updated"
+	ServerEventTypeTranscriptionSessionCreated ServerEventType = "transcription_session.created"
 	ServerEventTypeTranscriptionSessionUpdated ServerEventType = "transcription_session.updated"
 	ServerEventTypeConversationCreated         ServerEventType = "conversation.created"
 	ServerEventTypeInputAudioBufferCommitted   ServerEventType = "input_audio_buffer.committed"
@@ -767,6 +768,15 @@ type SessionCreatedEvent struct {
 	Session ServerSession `json:"session"`
 }

+// TranscriptionSessionCreatedEvent is the event for session created.
+// Returned when a transcription session is created.
+// See https://platform.openai.com/docs/api-reference/realtime-server-events/session/created
+type TranscriptionSessionCreatedEvent struct {
+	ServerEventBase
+	// The transcription session resource.
+	Session ServerSession `json:"session"`
+}
+
 // SessionUpdatedEvent is the event for session updated.
 // Returned when a session is updated.
 // See https://platform.openai.com/docs/api-reference/realtime-server-events/session/updated
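Given the struct above and the transcription_session.created constant, the wire format a realtime client receives is easy to reproduce with stand-in types; the JSON tags event_id and type are assumptions here, not LocalAI's actual ServerEventBase definition:

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-ins for the realtime types; tags are assumptions.
type serverEventBase struct {
	EventID string `json:"event_id"`
	Type    string `json:"type"`
}

type transcriptionSessionCreatedEvent struct {
	serverEventBase
	Session map[string]any `json:"session"`
}

func main() {
	ev := transcriptionSessionCreatedEvent{
		serverEventBase: serverEventBase{
			EventID: "event_TODO",
			Type:    "transcription_session.created",
		},
		Session: map[string]any{"id": "sess_123"},
	}
	b, _ := json.MarshalIndent(ev, "", "  ")
	fmt.Println(string(b))
	// Prints an object with "event_id", "type", and "session" fields.
}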
core/http/endpoints/openai/video.go (new file, 136 lines)
@@ -0,0 +1,136 @@
package openai

import (
	"encoding/json"
	"fmt"
	"strconv"
	"strings"

	"github.com/gofiber/fiber/v2"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/endpoints/localai"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	model "github.com/mudler/LocalAI/pkg/model"
)

func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
		if !ok || input == nil {
			return fiber.ErrBadRequest
		}
		var raw map[string]interface{}
		if body := c.Body(); len(body) > 0 {
			_ = json.Unmarshal(body, &raw)
		}
		// Build VideoRequest using shared mapper
		vr := MapOpenAIToVideo(input, raw)
		// Place VideoRequest into locals so localai.VideoEndpoint can consume it
		c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
		// Delegate to existing localai handler
		return localai.VideoEndpoint(cl, ml, appConfig)(c)
	}
}

// VideoEndpoint godoc
// @Summary Generate a video from an OpenAI-compatible request
// @Description Accepts an OpenAI-style request and delegates to the LocalAI video generator
// @Tags openai
// @Accept json
// @Produce json
// @Param request body schema.OpenAIRequest true "OpenAI-style request"
// @Success 200 {object} map[string]interface{}
// @Failure 400 {object} map[string]interface{}
// @Router /v1/videos [post]

func MapOpenAIToVideo(input *schema.OpenAIRequest, raw map[string]interface{}) *schema.VideoRequest {
	vr := &schema.VideoRequest{}
	if input == nil {
		return vr
	}

	if input.Model != "" {
		vr.Model = input.Model
	}

	// Prompt mapping
	switch p := input.Prompt.(type) {
	case string:
		vr.Prompt = p
	case []interface{}:
		if len(p) > 0 {
			if s, ok := p[0].(string); ok {
				vr.Prompt = s
			}
		}
	}

	// Size
	size := input.Size
	if size == "" && raw != nil {
		if v, ok := raw["size"].(string); ok {
			size = v
		}
	}
	if size != "" {
		parts := strings.SplitN(size, "x", 2)
		if len(parts) == 2 {
			if wi, err := strconv.Atoi(parts[0]); err == nil {
				vr.Width = int32(wi)
			}
			if hi, err := strconv.Atoi(parts[1]); err == nil {
				vr.Height = int32(hi)
			}
		}
	}

	// seconds -> num frames
	secondsStr := ""
	if raw != nil {
		if v, ok := raw["seconds"].(string); ok {
			secondsStr = v
		} else if v, ok := raw["seconds"].(float64); ok {
			secondsStr = fmt.Sprintf("%v", int(v))
		}
	}
	fps := int32(30)
	if raw != nil {
		if rawFPS, ok := raw["fps"]; ok {
			switch rf := rawFPS.(type) {
			case float64:
				fps = int32(rf)
			case string:
				if fi, err := strconv.Atoi(rf); err == nil {
					fps = int32(fi)
				}
			}
		}
	}
	if secondsStr != "" {
		if secF, err := strconv.Atoi(secondsStr); err == nil {
			vr.FPS = fps
			vr.NumFrames = int32(secF) * fps
		}
	}

	// input_reference
	if raw != nil {
		if v, ok := raw["input_reference"].(string); ok {
			vr.StartImage = v
		}
	}

	// response format
	if input.ResponseFormat != nil {
		if rf, ok := input.ResponseFormat.(string); ok {
			vr.ResponseFormat = rf
		}
	}

	if input.Step != 0 {
		vr.Step = int32(input.Step)
	}

	return vr
}
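A quick worked example of the mapping above: a 256x128 request with seconds set to "2" and no fps key keeps the 30 fps default, so NumFrames = 2 * 30 = 60 (a usage sketch; the field layout follows the tests further below):

package main

import (
	"fmt"

	openai "github.com/mudler/LocalAI/core/http/endpoints/openai"
	"github.com/mudler/LocalAI/core/schema"
)

func main() {
	input := &schema.OpenAIRequest{Size: "256x128"}
	raw := map[string]interface{}{"seconds": "2"}

	vr := openai.MapOpenAIToVideo(input, raw)
	fmt.Println(vr.Width, vr.Height, vr.FPS, vr.NumFrames) // 256 128 30 60
}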
@@ -3,12 +3,14 @@ package middleware
 import (
 	"crypto/subtle"
 	"errors"
+	"strings"

 	"github.com/dave-gray101/v2keyauth"
 	"github.com/gofiber/fiber/v2"
 	"github.com/gofiber/fiber/v2/middleware/keyauth"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/utils"
+	"github.com/mudler/LocalAI/core/schema"
 )

 // This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
@@ -40,6 +42,19 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
 	if applicationConfig.OpaqueErrors {
 		return ctx.SendStatus(401)
 	}

+	// Check if the request content type is JSON
+	contentType := string(ctx.Context().Request.Header.ContentType())
+	if strings.Contains(contentType, "application/json") {
+		return ctx.Status(401).JSON(schema.ErrorResponse{
+			Error: &schema.APIError{
+				Message: "An authentication key is required",
+				Code:    401,
+				Type:    "invalid_request_error",
+			},
+		})
+	}

 	return ctx.Status(401).Render("views/login", fiber.Map{
 		"BaseURL": utils.BaseURL(ctx),
 	})
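With the change above, an unauthenticated request that declares a JSON content type gets a structured error instead of the HTML login page. A sketch of probing both behaviours against a running instance (host and port are placeholders):

package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func probe(contentType string) {
	req, _ := http.NewRequest("POST", "http://localhost:8080/v1/chat/completions", strings.NewReader("{}"))
	req.Header.Set("Content-Type", contentType)
	// No Authorization header: the keyauth middleware rejects the request.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Printf("%s -> %d: %.80s\n", contentType, resp.StatusCode, body)
}

func main() {
	probe("application/json") // 401 with {"error":{"code":401,"message":"An authentication key is required",...}}
	probe("text/html")        // 401 rendering the login page instead
}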
core/http/openai_mapping_test.go (new file, 75 lines)
@@ -0,0 +1,75 @@
package http_test

import (
	"encoding/json"

	openai "github.com/mudler/LocalAI/core/http/endpoints/openai"
	"github.com/mudler/LocalAI/core/schema"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("MapOpenAIToVideo", func() {
	It("maps size and seconds correctly", func() {
		cases := []struct {
			name     string
			input    *schema.OpenAIRequest
			raw      map[string]interface{}
			expectsW int32
			expectsH int32
			expectsF int32
			expectsN int32
		}{
			{
				name: "size in input",
				input: &schema.OpenAIRequest{
					PredictionOptions: schema.PredictionOptions{
						BasicModelRequest: schema.BasicModelRequest{Model: "m"},
					},
					Size: "256x128",
				},
				expectsW: 256,
				expectsH: 128,
			},
			{
				name:     "size in raw and seconds as string",
				input:    &schema.OpenAIRequest{PredictionOptions: schema.PredictionOptions{BasicModelRequest: schema.BasicModelRequest{Model: "m"}}},
				raw:      map[string]interface{}{"size": "720x480", "seconds": "2"},
				expectsW: 720,
				expectsH: 480,
				expectsF: 30,
				expectsN: 60,
			},
			{
				name:     "seconds as number and fps override",
				input:    &schema.OpenAIRequest{PredictionOptions: schema.PredictionOptions{BasicModelRequest: schema.BasicModelRequest{Model: "m"}}},
				raw:      map[string]interface{}{"seconds": 3.0, "fps": 24.0},
				expectsF: 24,
				expectsN: 72,
			},
		}

		for _, c := range cases {
			By(c.name)
			vr := openai.MapOpenAIToVideo(c.input, c.raw)
			if c.expectsW != 0 {
				Expect(vr.Width).To(Equal(c.expectsW))
			}
			if c.expectsH != 0 {
				Expect(vr.Height).To(Equal(c.expectsH))
			}
			if c.expectsF != 0 {
				Expect(vr.FPS).To(Equal(c.expectsF))
			}
			if c.expectsN != 0 {
				Expect(vr.NumFrames).To(Equal(c.expectsN))
			}

			b, err := json.Marshal(vr)
			Expect(err).ToNot(HaveOccurred())
			_ = b
		}
	})
})
core/http/openai_videos_test.go (new file, 162 lines)
@@ -0,0 +1,162 @@
package http_test

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"time"

	"github.com/mudler/LocalAI/core/application"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/system"
	"github.com/mudler/LocalAI/pkg/grpc"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"fmt"
	. "github.com/mudler/LocalAI/core/http"
	"github.com/gofiber/fiber/v2"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

const testAPIKey = "joshua"

type fakeAI struct{}

func (f *fakeAI) Busy() bool    { return false }
func (f *fakeAI) Lock()         {}
func (f *fakeAI) Unlock()       {}
func (f *fakeAI) Locking() bool { return false }
func (f *fakeAI) Predict(*pb.PredictOptions) (string, error) { return "", nil }
func (f *fakeAI) PredictStream(*pb.PredictOptions, chan string) error {
	return nil
}
func (f *fakeAI) Load(*pb.ModelOptions) error                       { return nil }
func (f *fakeAI) Embeddings(*pb.PredictOptions) ([]float32, error)  { return nil, nil }
func (f *fakeAI) GenerateImage(*pb.GenerateImageRequest) error      { return nil }
func (f *fakeAI) GenerateVideo(*pb.GenerateVideoRequest) error      { return nil }
func (f *fakeAI) Detect(*pb.DetectOptions) (pb.DetectResponse, error) { return pb.DetectResponse{}, nil }
func (f *fakeAI) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
	return pb.TranscriptResult{}, nil
}
func (f *fakeAI) TTS(*pb.TTSRequest) error                      { return nil }
func (f *fakeAI) SoundGeneration(*pb.SoundGenerationRequest) error { return nil }
func (f *fakeAI) TokenizeString(*pb.PredictOptions) (pb.TokenizationResponse, error) {
	return pb.TokenizationResponse{}, nil
}
func (f *fakeAI) Status() (pb.StatusResponse, error)     { return pb.StatusResponse{}, nil }
func (f *fakeAI) StoresSet(*pb.StoresSetOptions) error   { return nil }
func (f *fakeAI) StoresDelete(*pb.StoresDeleteOptions) error { return nil }
func (f *fakeAI) StoresGet(*pb.StoresGetOptions) (pb.StoresGetResult, error) {
	return pb.StoresGetResult{}, nil
}
func (f *fakeAI) StoresFind(*pb.StoresFindOptions) (pb.StoresFindResult, error) {
	return pb.StoresFindResult{}, nil
}
func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResponse{}, nil }

var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
	var tmpdir string
	var appServer *application.Application
	var app *fiber.App
	var ctx context.Context
	var cancel context.CancelFunc

	BeforeEach(func() {
		var err error
		tmpdir, err = os.MkdirTemp("", "")
		Expect(err).ToNot(HaveOccurred())

		modelDir := filepath.Join(tmpdir, "models")
		err = os.Mkdir(modelDir, 0750)
		Expect(err).ToNot(HaveOccurred())

		ctx, cancel = context.WithCancel(context.Background())

		systemState, err := system.GetSystemState(
			system.WithModelPath(modelDir),
		)
		Expect(err).ToNot(HaveOccurred())

		grpc.Provide("embedded://fake", &fakeAI{})

		appServer, err = application.New(
			config.WithContext(ctx),
			config.WithSystemState(systemState),
			config.WithApiKeys([]string{testAPIKey}),
			config.WithGeneratedContentDir(tmpdir),
			config.WithExternalBackend("fake", "embedded://fake"),
		)
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		cancel()
		if app != nil {
			_ = app.Shutdown()
		}
		_ = os.RemoveAll(tmpdir)
	})

	It("accepts OpenAI-style video create and delegates to backend", func() {
		var err error
		app, err = API(appServer)
		Expect(err).ToNot(HaveOccurred())
		go app.Listen("127.0.0.1:9091")

		// wait for server
		client := &http.Client{Timeout: 5 * time.Second}
		Eventually(func() error {
			req, _ := http.NewRequest("GET", "http://127.0.0.1:9091/v1/models", nil)
			req.Header.Set("Authorization", "Bearer "+testAPIKey)
			resp, err := client.Do(req)
			if err != nil {
				return err
			}
			defer resp.Body.Close()
			if resp.StatusCode >= 400 {
				return fmt.Errorf("bad status: %d", resp.StatusCode)
			}
			return nil
		}, "30s", "500ms").Should(Succeed())

		body := map[string]interface{}{
			"model":   "fake-model",
			"backend": "fake",
			"prompt":  "a test video",
			"size":    "256x256",
			"seconds": "1",
		}
		payload, err := json.Marshal(body)
		Expect(err).ToNot(HaveOccurred())

		req, err := http.NewRequest("POST", "http://127.0.0.1:9091/v1/videos", bytes.NewBuffer(payload))
		Expect(err).ToNot(HaveOccurred())
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Authorization", "Bearer "+testAPIKey)

		resp, err := client.Do(req)
		Expect(err).ToNot(HaveOccurred())
		defer resp.Body.Close()
		Expect(resp.StatusCode).To(Equal(200))

		dat, err := io.ReadAll(resp.Body)
		Expect(err).ToNot(HaveOccurred())

		var out map[string]interface{}
		err = json.Unmarshal(dat, &out)
		Expect(err).ToNot(HaveOccurred())
		data, ok := out["data"].([]interface{})
		Expect(ok).To(BeTrue())
		Expect(len(data)).To(BeNumerically(">", 0))
		first := data[0].(map[string]interface{})
		url, ok := first["url"].(string)
		Expect(ok).To(BeTrue())
		Expect(url).To(ContainSubstring("/generated-videos/"))
		Expect(url).To(ContainSubstring(".mp4"))
	})
})
Some files were not shown because too many files have changed in this diff.