mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-24 08:38:02 -04:00
Compare commits
318 Commits
v3.5.0
...
feat/stats
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
eebda7204e | ||
|
|
ba1b8e7757 | ||
|
|
79b68fdc25 | ||
|
|
a946cb08b5 | ||
|
|
d95d4992fe | ||
|
|
e13cb8346d | ||
|
|
615c56503e | ||
|
|
79a8edd8b9 | ||
|
|
8d138dd68f | ||
|
|
2b33844562 | ||
|
|
63e6721c2f | ||
|
|
4859d809aa | ||
|
|
be027b1ccd | ||
|
|
3ecadeeb93 | ||
|
|
4af3348f91 | ||
|
|
dde08845bf | ||
|
|
76d1ba168d | ||
|
|
80605e4f66 | ||
|
|
5b99584a31 | ||
|
|
fc134b18fe | ||
|
|
c2006273c5 | ||
|
|
5343889098 | ||
|
|
c42afc56d9 | ||
|
|
53f44dac89 | ||
|
|
0468456fad | ||
|
|
df899ee26a | ||
|
|
93fe25468f | ||
|
|
238aad666e | ||
|
|
4408ed4f88 | ||
|
|
5df1f59a3c | ||
|
|
8225697139 | ||
|
|
0c0186d866 | ||
|
|
ce2f8828f9 | ||
|
|
7a8565a45e | ||
|
|
192589a17f | ||
|
|
28ab73d4a1 | ||
|
|
ed4ac0b61e | ||
|
|
e41d8b65ce | ||
|
|
c28e5b39d6 | ||
|
|
b66bd2706f | ||
|
|
fa7a9d96f8 | ||
|
|
61d972a2ef | ||
|
|
fffdbc31c6 | ||
|
|
32c0ab3a7f | ||
|
|
24ce79a67c | ||
|
|
bfa8530088 | ||
|
|
4278144dd5 | ||
|
|
79fa4d691e | ||
|
|
7a3d9ee5c1 | ||
|
|
22923d3b23 | ||
|
|
d32a459209 | ||
|
|
47b2a502dd | ||
|
|
b85f339eb4 | ||
|
|
8821865eac | ||
|
|
4b30846d57 | ||
|
|
7a35986407 | ||
|
|
ee34aa7bd5 | ||
|
|
40cf798dfe | ||
|
|
18810038f5 | ||
|
|
8fb79bc6f6 | ||
|
|
4b5ad1405f | ||
|
|
4493078cdd | ||
|
|
7f68c89cbe | ||
|
|
69adc46936 | ||
|
|
d22439918f | ||
|
|
103d4e87e5 | ||
|
|
8c5ba9e0d7 | ||
|
|
f1b713df08 | ||
|
|
f94b89c1b5 | ||
|
|
a1b056737a | ||
|
|
a22f6a499d | ||
|
|
e5bf2a9a11 | ||
|
|
05aba5a311 | ||
|
|
354bf5debb | ||
|
|
7f88abb3b1 | ||
|
|
36b3a538f8 | ||
|
|
e293b65ad9 | ||
|
|
cce185b345 | ||
|
|
03ed4382c7 | ||
|
|
1c73e10676 | ||
|
|
4ade65f959 | ||
|
|
c54f5cdf12 | ||
|
|
33c48164d7 | ||
|
|
7aed3b3bac | ||
|
|
9e349c715e | ||
|
|
639ecb59b3 | ||
|
|
bfb0794f87 | ||
|
|
05f1e9e757 | ||
|
|
1ca6f6dada | ||
|
|
bc5397bcfc | ||
|
|
f452a027a2 | ||
|
|
7bac49fb87 | ||
|
|
02300cfbd1 | ||
|
|
17c5c732c7 | ||
|
|
10a66938f9 | ||
|
|
f0245fa36c | ||
|
|
83534f8e00 | ||
|
|
75eaf8c853 | ||
|
|
03096154d4 | ||
|
|
22c9e8c09e | ||
|
|
da16727ad6 | ||
|
|
ad44df6d83 | ||
|
|
276c552583 | ||
|
|
9109e5c149 | ||
|
|
71a84b91e3 | ||
|
|
209d40be71 | ||
|
|
bfd76805e8 | ||
|
|
561aa5e443 | ||
|
|
b0eb1ab2a1 | ||
|
|
1208fb6fa1 | ||
|
|
f98fe85c42 | ||
|
|
167c183c84 | ||
|
|
244e47e1e0 | ||
|
|
9680a0b0fe | ||
|
|
acbd10a661 | ||
|
|
c6b989be13 | ||
|
|
670103705c | ||
|
|
cb90bd226e | ||
|
|
df9b2abf84 | ||
|
|
582114bda9 | ||
|
|
91ffe5ac38 | ||
|
|
8a58d76254 | ||
|
|
c3442fe574 | ||
|
|
1087bd217e | ||
|
|
7ed3666d2e | ||
|
|
2e2e89e499 | ||
|
|
13c9c20f42 | ||
|
|
b3d3988d85 | ||
|
|
0529c7d0a0 | ||
|
|
af31a77061 | ||
|
|
2d8956167f | ||
|
|
509f85f82c | ||
|
|
bb2b377b18 | ||
|
|
48917889ce | ||
|
|
ef754259b0 | ||
|
|
7e26f28113 | ||
|
|
d7c8129549 | ||
|
|
3a8fbb698e | ||
|
|
b1ef34ef9f | ||
|
|
b7822250fe | ||
|
|
05055f7e95 | ||
|
|
c856d7dc73 | ||
|
|
69d565e55d | ||
|
|
fa6bbd9fa2 | ||
|
|
3f767121d2 | ||
|
|
e963e16bc5 | ||
|
|
1e9b115251 | ||
|
|
cd1e1124ea | ||
|
|
81b31b4283 | ||
|
|
d763bce46d | ||
|
|
4aac0ef42e | ||
|
|
7a36e8d967 | ||
|
|
dc2be93412 | ||
|
|
69a2b91495 | ||
|
|
791bc769c1 | ||
|
|
a15a1f07e3 | ||
|
|
c6f0b44228 | ||
|
|
cb0ed55d89 | ||
|
|
2fe97110fd | ||
|
|
fa8037b21d | ||
|
|
99a72a4b11 | ||
|
|
1a52ce1bd4 | ||
|
|
925d752f8d | ||
|
|
c0b9d00f35 | ||
|
|
fcf8d41a00 | ||
|
|
27c4161401 | ||
|
|
459b6ab86d | ||
|
|
336257cc3c | ||
|
|
df46a438b8 | ||
|
|
5e1d809904 | ||
|
|
a9c7ce7275 | ||
|
|
8c47c8c8ed | ||
|
|
8e8d427549 | ||
|
|
ee251115f4 | ||
|
|
661e66090c | ||
|
|
c38564e22c | ||
|
|
20f1e842b3 | ||
|
|
aa8965b634 | ||
|
|
35c676188b | ||
|
|
183559bb98 | ||
|
|
1123a5c49c | ||
|
|
6f17c260a7 | ||
|
|
da6278aae9 | ||
|
|
2e51871ad5 | ||
|
|
8067d25710 | ||
|
|
cb2df6c5bf | ||
|
|
07e1519b3f | ||
|
|
8fc41673fa | ||
|
|
fff0e5911b | ||
|
|
09346bdc06 | ||
|
|
d4d42740c8 | ||
|
|
5de7a43319 | ||
|
|
85e27ec74c | ||
|
|
698205a2f3 | ||
|
|
3ed582b091 | ||
|
|
752e33f676 | ||
|
|
930553ef60 | ||
|
|
fc8d5c9198 | ||
|
|
60b6472fa0 | ||
|
|
6b2c8277c2 | ||
|
|
6d5d3ebcf6 | ||
|
|
530c174fd3 | ||
|
|
8fb95686af | ||
|
|
4132085c01 | ||
|
|
c14f1ffcfd | ||
|
|
07cca4b69a | ||
|
|
dd927c36f6 | ||
|
|
052f42e926 | ||
|
|
30d43588ab | ||
|
|
d21ec22f74 | ||
|
|
04fecd634a | ||
|
|
33c14198db | ||
|
|
967c2727e3 | ||
|
|
f41f30ad92 | ||
|
|
e77340e8a5 | ||
|
|
d51a3090f7 | ||
|
|
1bf3bc932c | ||
|
|
564a47da4e | ||
|
|
c37ee93ff2 | ||
|
|
f4b65db4e7 | ||
|
|
f5fa8e6649 | ||
|
|
570e39bdcf | ||
|
|
2ebe37b671 | ||
|
|
dca685f784 | ||
|
|
84ebf2a2c9 | ||
|
|
ce5662ba90 | ||
|
|
9878f27813 | ||
|
|
f2b9452ec4 | ||
|
|
585da99c52 | ||
|
|
fd4f432079 | ||
|
|
238c68c57b | ||
|
|
04fbf5cb82 | ||
|
|
c85d559919 | ||
|
|
b5efc4f89e | ||
|
|
3f9c09a4c5 | ||
|
|
4a84660475 | ||
|
|
737248256e | ||
|
|
0ae334fc62 | ||
|
|
36c373b7c9 | ||
|
|
6afcb932b7 | ||
|
|
357bf571a3 | ||
|
|
e74ade9ebb | ||
|
|
f7f26b8efa | ||
|
|
75eb98f8bd | ||
|
|
c337e7baf7 | ||
|
|
660bd45be8 | ||
|
|
c27da0a0f6 | ||
|
|
ac043ed9ba | ||
|
|
2e0d66a1c8 | ||
|
|
41a0f361eb | ||
|
|
d3c5c02837 | ||
|
|
ae3d8fb0c4 | ||
|
|
902e47f0b0 | ||
|
|
50bb78fd24 | ||
|
|
542f07ab2d | ||
|
|
77c5acb9db | ||
|
|
44bbf4d778 | ||
|
|
633c12f93d | ||
|
|
6f24135f1d | ||
|
|
b72aa7b4fa | ||
|
|
e94e725479 | ||
|
|
e4ac7b14a3 | ||
|
|
ddb39c73f2 | ||
|
|
264b09fb1e | ||
|
|
36dd45df51 | ||
|
|
e5599f87b8 | ||
|
|
e89b5cc0e3 | ||
|
|
10bf1084cc | ||
|
|
b08ae559b3 | ||
|
|
aa7cb7e18c | ||
|
|
eadd3d4e46 | ||
|
|
2a18206033 | ||
|
|
39798d734e | ||
|
|
d0e99562af | ||
|
|
6410c99bf2 | ||
|
|
55766d269b | ||
|
|
ffa0ad1eac | ||
|
|
623789a29e | ||
|
|
2b9a3d32c9 | ||
|
|
f8b71dc5d0 | ||
|
|
1d3331b5cb | ||
|
|
2c0b9c6349 | ||
|
|
3c6c976755 | ||
|
|
ebbcba342a | ||
|
|
0de75519dc | ||
|
|
37f5e4f5c1 | ||
|
|
ffa934b959 | ||
|
|
59311d8b1e | ||
|
|
d9e25af7b5 | ||
|
|
e4f8b63b40 | ||
|
|
1364ae9be6 | ||
|
|
cfd6a9150d | ||
|
|
cd352d0c5f | ||
|
|
8d47309695 | ||
|
|
5f6fc02a55 | ||
|
|
0b528458d8 | ||
|
|
caab380c5d | ||
|
|
8a3a362504 | ||
|
|
07238eb743 | ||
|
|
e905e90dd7 | ||
|
|
08432d49e5 | ||
|
|
e51e2aacb9 | ||
|
|
9c3d85fc28 | ||
|
|
007ca647a7 | ||
|
|
59af928379 | ||
|
|
dbc2bb561b | ||
|
|
c72c85dcac | ||
|
|
ef984901e6 | ||
|
|
9911ec84a3 | ||
|
|
1956681d4c | ||
|
|
326f6e5ccb | ||
|
|
302958efd6 | ||
|
|
3dc86b247d | ||
|
|
5ec724af06 | ||
|
|
1f1e156bf0 | ||
|
|
df625e366a | ||
|
|
9e6685ac9c | ||
|
|
90c818aa71 |
288
.github/gallery-agent/agent.go
vendored
Normal file
288
.github/gallery-agent/agent.go
vendored
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
|
||||||
|
"github.com/mudler/cogito"
|
||||||
|
|
||||||
|
"github.com/mudler/cogito/structures"
|
||||||
|
"github.com/sashabaranov/go-openai/jsonschema"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
openAIModel = os.Getenv("OPENAI_MODEL")
|
||||||
|
openAIKey = os.Getenv("OPENAI_KEY")
|
||||||
|
openAIBaseURL = os.Getenv("OPENAI_BASE_URL")
|
||||||
|
galleryIndexPath = os.Getenv("GALLERY_INDEX_PATH")
|
||||||
|
//defaultclient
|
||||||
|
llm = cogito.NewOpenAILLM(openAIModel, openAIKey, openAIBaseURL)
|
||||||
|
)
|
||||||
|
|
||||||
|
// cleanTextContent removes trailing spaces, tabs, and normalizes line endings
|
||||||
|
// to prevent YAML linting issues like trailing spaces and multiple empty lines
|
||||||
|
func cleanTextContent(text string) string {
|
||||||
|
lines := strings.Split(text, "\n")
|
||||||
|
var cleanedLines []string
|
||||||
|
var prevEmpty bool
|
||||||
|
for _, line := range lines {
|
||||||
|
// Remove all trailing whitespace (spaces, tabs, etc.)
|
||||||
|
trimmed := strings.TrimRight(line, " \t\r")
|
||||||
|
// Avoid multiple consecutive empty lines
|
||||||
|
if trimmed == "" {
|
||||||
|
if !prevEmpty {
|
||||||
|
cleanedLines = append(cleanedLines, "")
|
||||||
|
}
|
||||||
|
prevEmpty = true
|
||||||
|
} else {
|
||||||
|
cleanedLines = append(cleanedLines, trimmed)
|
||||||
|
prevEmpty = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Remove trailing empty lines from the result
|
||||||
|
result := strings.Join(cleanedLines, "\n")
|
||||||
|
return strings.TrimRight(result, "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isModelExisting checks if a specific model ID exists in the gallery using text search
|
||||||
|
func isModelExisting(modelID string) (bool, error) {
|
||||||
|
indexPath := getGalleryIndexPath()
|
||||||
|
content, err := os.ReadFile(indexPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to read %s: %w", indexPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
contentStr := string(content)
|
||||||
|
// Simple text search - if the model ID appears anywhere in the file, it exists
|
||||||
|
return strings.Contains(contentStr, modelID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterExistingModels removes models that already exist in the gallery
|
||||||
|
func filterExistingModels(models []ProcessedModel) ([]ProcessedModel, error) {
|
||||||
|
var filteredModels []ProcessedModel
|
||||||
|
for _, model := range models {
|
||||||
|
exists, err := isModelExisting(model.ModelID)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error checking if model %s exists: %v, skipping\n", model.ModelID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
filteredModels = append(filteredModels, model)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Skipping existing model: %s\n", model.ModelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Filtered out %d existing models, %d new models remaining\n",
|
||||||
|
len(models)-len(filteredModels), len(filteredModels))
|
||||||
|
|
||||||
|
return filteredModels, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getGalleryIndexPath returns the gallery index file path, with a default fallback
|
||||||
|
func getGalleryIndexPath() string {
|
||||||
|
if galleryIndexPath != "" {
|
||||||
|
return galleryIndexPath
|
||||||
|
}
|
||||||
|
return "gallery/index.yaml"
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRealReadme(ctx context.Context, repository string) (string, error) {
|
||||||
|
// Create a conversation fragment
|
||||||
|
fragment := cogito.NewEmptyFragment().
|
||||||
|
AddMessage("user",
|
||||||
|
`Your task is to get a clear description of a large language model from huggingface by using the provided tool. I will share with you a repository that might be quantized, and as such probably not by the original model author. We need to get the real description of the model, and not the one that might be quantized. You will have to call the tool to get the readme more than once by figuring out from the quantized readme which is the base model readme. This is the repository: `+repository)
|
||||||
|
|
||||||
|
// Execute with tools
|
||||||
|
result, err := cogito.ExecuteTools(llm, fragment,
|
||||||
|
cogito.WithIterations(3),
|
||||||
|
cogito.WithMaxAttempts(3),
|
||||||
|
cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()}))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
result = result.AddMessage("user", "Describe the model in a clear and concise way that can be shared in a model gallery.")
|
||||||
|
|
||||||
|
// Get a response
|
||||||
|
newFragment, err := llm.Ask(ctx, result)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
content := newFragment.LastMessage().Content
|
||||||
|
return cleanTextContent(content), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectMostInterestingModels(ctx context.Context, searchResult *SearchResult) ([]ProcessedModel, error) {
|
||||||
|
// Create a conversation fragment
|
||||||
|
fragment := cogito.NewEmptyFragment().
|
||||||
|
AddMessage("user",
|
||||||
|
`Your task is to analyze a list of AI models and select the most interesting ones for a model gallery. You will be given detailed information about multiple models including their metadata, file information, and README content.
|
||||||
|
|
||||||
|
Consider the following criteria when selecting models:
|
||||||
|
1. Model popularity (download count)
|
||||||
|
2. Model recency (last modified date)
|
||||||
|
3. Model completeness (has preferred model file, README, etc.)
|
||||||
|
4. Model uniqueness (not duplicates or very similar models)
|
||||||
|
5. Model quality (based on README content and description)
|
||||||
|
6. Model utility (practical applications)
|
||||||
|
|
||||||
|
You should select models that would be most valuable for users browsing a model gallery. Prioritize models that are:
|
||||||
|
- Well-documented with clear READMEs
|
||||||
|
- Recently updated
|
||||||
|
- Popular (high download count)
|
||||||
|
- Have the preferred quantization format available
|
||||||
|
- Offer unique capabilities or are from reputable authors
|
||||||
|
|
||||||
|
Return your analysis and selection reasoning.`)
|
||||||
|
|
||||||
|
// Add the search results as context
|
||||||
|
modelsInfo := fmt.Sprintf("Found %d models matching '%s' with quantization preference '%s':\n\n",
|
||||||
|
searchResult.TotalModelsFound, searchResult.SearchTerm, searchResult.Quantization)
|
||||||
|
|
||||||
|
for i, model := range searchResult.Models {
|
||||||
|
modelsInfo += fmt.Sprintf("Model %d:\n", i+1)
|
||||||
|
modelsInfo += fmt.Sprintf(" ID: %s\n", model.ModelID)
|
||||||
|
modelsInfo += fmt.Sprintf(" Author: %s\n", model.Author)
|
||||||
|
modelsInfo += fmt.Sprintf(" Downloads: %d\n", model.Downloads)
|
||||||
|
modelsInfo += fmt.Sprintf(" Last Modified: %s\n", model.LastModified)
|
||||||
|
modelsInfo += fmt.Sprintf(" Files: %d files\n", len(model.Files))
|
||||||
|
|
||||||
|
if model.PreferredModelFile != nil {
|
||||||
|
modelsInfo += fmt.Sprintf(" Preferred Model File: %s (%d bytes)\n",
|
||||||
|
model.PreferredModelFile.Path, model.PreferredModelFile.Size)
|
||||||
|
} else {
|
||||||
|
modelsInfo += " No preferred model file found\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
if model.ReadmeContent != "" {
|
||||||
|
modelsInfo += fmt.Sprintf(" README: %s\n", model.ReadmeContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
if model.ProcessingError != "" {
|
||||||
|
modelsInfo += fmt.Sprintf(" Processing Error: %s\n", model.ProcessingError)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelsInfo += "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment = fragment.AddMessage("user", modelsInfo)
|
||||||
|
|
||||||
|
fragment = fragment.AddMessage("user", "Based on your analysis, select the top 5 most interesting models and provide a brief explanation for each selection. Also, create a filtered SearchResult with only the selected models. Return just a list of repositories IDs, you will later be asked to output it as a JSON array with the json tool.")
|
||||||
|
|
||||||
|
// Get a response
|
||||||
|
newFragment, err := llm.Ask(ctx, fragment)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(newFragment.LastMessage().Content)
|
||||||
|
repositories := struct {
|
||||||
|
Repositories []string `json:"repositories"`
|
||||||
|
}{}
|
||||||
|
|
||||||
|
s := structures.Structure{
|
||||||
|
Schema: jsonschema.Definition{
|
||||||
|
Type: jsonschema.Object,
|
||||||
|
AdditionalProperties: false,
|
||||||
|
Properties: map[string]jsonschema.Definition{
|
||||||
|
"repositories": {
|
||||||
|
Type: jsonschema.Array,
|
||||||
|
Items: &jsonschema.Definition{Type: jsonschema.String},
|
||||||
|
Description: "The trending repositories IDs",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"repositories"},
|
||||||
|
},
|
||||||
|
Object: &repositories,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = newFragment.ExtractStructure(ctx, llm, s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredModels := []ProcessedModel{}
|
||||||
|
for _, m := range searchResult.Models {
|
||||||
|
if slices.Contains(repositories.Repositories, m.ModelID) {
|
||||||
|
filteredModels = append(filteredModels, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredModels, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelFamily represents a YAML anchor/family
|
||||||
|
type ModelFamily struct {
|
||||||
|
Anchor string `json:"anchor"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectModelFamily selects the appropriate model family/anchor for a given model
|
||||||
|
func selectModelFamily(ctx context.Context, model ProcessedModel, availableFamilies []ModelFamily) (string, error) {
|
||||||
|
// Create a conversation fragment
|
||||||
|
fragment := cogito.NewEmptyFragment().
|
||||||
|
AddMessage("user",
|
||||||
|
`Your task is to select the most appropriate model family/anchor for a given AI model. You will be provided with:
|
||||||
|
1. Information about the model (name, description, etc.)
|
||||||
|
2. A list of available model families/anchors
|
||||||
|
|
||||||
|
You need to select the family that best matches the model's architecture, capabilities, or characteristics. Consider:
|
||||||
|
- Model architecture (e.g., Llama, Qwen, Mistral, etc.)
|
||||||
|
- Model capabilities (e.g., vision, coding, chat, etc.)
|
||||||
|
- Model size/type (e.g., small, medium, large)
|
||||||
|
- Model purpose (e.g., general purpose, specialized, etc.)
|
||||||
|
|
||||||
|
Return the anchor name that best fits the model.`)
|
||||||
|
|
||||||
|
// Add model information
|
||||||
|
modelInfo := "Model Information:\n"
|
||||||
|
modelInfo += fmt.Sprintf(" ID: %s\n", model.ModelID)
|
||||||
|
modelInfo += fmt.Sprintf(" Author: %s\n", model.Author)
|
||||||
|
modelInfo += fmt.Sprintf(" Downloads: %d\n", model.Downloads)
|
||||||
|
modelInfo += fmt.Sprintf(" Description: %s\n", model.ReadmeContentPreview)
|
||||||
|
|
||||||
|
fragment = fragment.AddMessage("user", modelInfo)
|
||||||
|
|
||||||
|
// Add available families
|
||||||
|
familiesInfo := "Available Model Families:\n"
|
||||||
|
for _, family := range availableFamilies {
|
||||||
|
familiesInfo += fmt.Sprintf(" - %s (%s)\n", family.Anchor, family.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
fragment = fragment.AddMessage("user", familiesInfo)
|
||||||
|
fragment = fragment.AddMessage("user", "Select the most appropriate family anchor for this model. Return just the anchor name.")
|
||||||
|
|
||||||
|
// Get a response
|
||||||
|
newFragment, err := llm.Ask(ctx, fragment)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the selected family
|
||||||
|
selectedFamily := strings.TrimSpace(newFragment.LastMessage().Content)
|
||||||
|
|
||||||
|
// Validate that the selected family exists in our list
|
||||||
|
for _, family := range availableFamilies {
|
||||||
|
if family.Anchor == selectedFamily {
|
||||||
|
return selectedFamily, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no exact match, try to find a close match
|
||||||
|
for _, family := range availableFamilies {
|
||||||
|
if strings.Contains(strings.ToLower(family.Anchor), strings.ToLower(selectedFamily)) ||
|
||||||
|
strings.Contains(strings.ToLower(selectedFamily), strings.ToLower(family.Anchor)) {
|
||||||
|
return family.Anchor, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default fallback
|
||||||
|
return "llama3", nil
|
||||||
|
}
|
||||||
203
.github/gallery-agent/gallery.go
vendored
Normal file
203
.github/gallery-agent/gallery.go
vendored
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// generateYAMLEntry generates a YAML entry for a model using the specified anchor
|
||||||
|
func generateYAMLEntry(model ProcessedModel, familyAnchor string) string {
|
||||||
|
// Extract model name from ModelID
|
||||||
|
parts := strings.Split(model.ModelID, "/")
|
||||||
|
modelName := model.ModelID
|
||||||
|
if len(parts) > 0 {
|
||||||
|
modelName = strings.ToLower(parts[len(parts)-1])
|
||||||
|
}
|
||||||
|
// Remove common suffixes
|
||||||
|
modelName = strings.ReplaceAll(modelName, "-gguf", "")
|
||||||
|
modelName = strings.ReplaceAll(modelName, "-q4_k_m", "")
|
||||||
|
modelName = strings.ReplaceAll(modelName, "-q4_k_s", "")
|
||||||
|
modelName = strings.ReplaceAll(modelName, "-q3_k_m", "")
|
||||||
|
modelName = strings.ReplaceAll(modelName, "-q2_k", "")
|
||||||
|
|
||||||
|
fileName := ""
|
||||||
|
checksum := ""
|
||||||
|
if model.PreferredModelFile != nil {
|
||||||
|
fileParts := strings.Split(model.PreferredModelFile.Path, "/")
|
||||||
|
if len(fileParts) > 0 {
|
||||||
|
fileName = fileParts[len(fileParts)-1]
|
||||||
|
}
|
||||||
|
checksum = model.PreferredModelFile.SHA256
|
||||||
|
} else {
|
||||||
|
fileName = model.ModelID
|
||||||
|
}
|
||||||
|
|
||||||
|
description := model.ReadmeContent
|
||||||
|
if description == "" {
|
||||||
|
description = fmt.Sprintf("AI model: %s", modelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up description to prevent YAML linting issues
|
||||||
|
description = cleanTextContent(description)
|
||||||
|
|
||||||
|
// Format description for YAML (indent each line and ensure no trailing spaces)
|
||||||
|
lines := strings.Split(description, "\n")
|
||||||
|
var formattedLines []string
|
||||||
|
for _, line := range lines {
|
||||||
|
if strings.TrimSpace(line) == "" {
|
||||||
|
// Keep empty lines as empty (no indentation)
|
||||||
|
formattedLines = append(formattedLines, "")
|
||||||
|
} else {
|
||||||
|
// Add indentation to non-empty lines
|
||||||
|
formattedLines = append(formattedLines, " "+line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
formattedDescription := strings.Join(formattedLines, "\n")
|
||||||
|
// Remove any trailing spaces from the formatted description
|
||||||
|
formattedDescription = strings.TrimRight(formattedDescription, " \t")
|
||||||
|
yamlTemplate := ""
|
||||||
|
if checksum != "" {
|
||||||
|
yamlTemplate = `- !!merge <<: *%s
|
||||||
|
name: "%s"
|
||||||
|
urls:
|
||||||
|
- https://huggingface.co/%s
|
||||||
|
description: |
|
||||||
|
%s
|
||||||
|
overrides:
|
||||||
|
parameters:
|
||||||
|
model: %s
|
||||||
|
files:
|
||||||
|
- filename: %s
|
||||||
|
sha256: %s
|
||||||
|
uri: huggingface://%s/%s`
|
||||||
|
return fmt.Sprintf(yamlTemplate,
|
||||||
|
familyAnchor,
|
||||||
|
modelName,
|
||||||
|
model.ModelID,
|
||||||
|
formattedDescription,
|
||||||
|
fileName,
|
||||||
|
fileName,
|
||||||
|
checksum,
|
||||||
|
model.ModelID,
|
||||||
|
fileName,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
yamlTemplate = `- !!merge <<: *%s
|
||||||
|
name: "%s"
|
||||||
|
urls:
|
||||||
|
- https://huggingface.co/%s
|
||||||
|
description: |
|
||||||
|
%s
|
||||||
|
overrides:
|
||||||
|
parameters:
|
||||||
|
model: %s`
|
||||||
|
return fmt.Sprintf(yamlTemplate,
|
||||||
|
familyAnchor,
|
||||||
|
modelName,
|
||||||
|
model.ModelID,
|
||||||
|
formattedDescription,
|
||||||
|
fileName,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractModelFamilies extracts all YAML anchors from the gallery index.yaml file
|
||||||
|
func extractModelFamilies() ([]ModelFamily, error) {
|
||||||
|
// Read the index.yaml file
|
||||||
|
indexPath := getGalleryIndexPath()
|
||||||
|
content, err := os.ReadFile(indexPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read %s: %w", indexPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := strings.Split(string(content), "\n")
|
||||||
|
var families []ModelFamily
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
// Look for YAML anchors (lines starting with "- &")
|
||||||
|
if strings.HasPrefix(line, "- &") {
|
||||||
|
// Extract the anchor name (everything after "- &")
|
||||||
|
anchor := strings.TrimPrefix(line, "- &")
|
||||||
|
// Remove any trailing colon or other characters
|
||||||
|
anchor = strings.Split(anchor, ":")[0]
|
||||||
|
anchor = strings.Split(anchor, " ")[0]
|
||||||
|
|
||||||
|
if anchor != "" {
|
||||||
|
families = append(families, ModelFamily{
|
||||||
|
Anchor: anchor,
|
||||||
|
Name: anchor, // Use anchor as name for now
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return families, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateYAMLForModels generates YAML entries for selected models and appends to index.yaml
|
||||||
|
func generateYAMLForModels(ctx context.Context, models []ProcessedModel) error {
|
||||||
|
// Extract available model families
|
||||||
|
families, err := extractModelFamilies()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to extract model families: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Found %d model families: %v\n", len(families),
|
||||||
|
func() []string {
|
||||||
|
var names []string
|
||||||
|
for _, f := range families {
|
||||||
|
names = append(names, f.Anchor)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}())
|
||||||
|
|
||||||
|
// Generate YAML entries for each model
|
||||||
|
var yamlEntries []string
|
||||||
|
for _, model := range models {
|
||||||
|
fmt.Printf("Selecting family for model: %s\n", model.ModelID)
|
||||||
|
|
||||||
|
// Select appropriate family for this model
|
||||||
|
familyAnchor, err := selectModelFamily(ctx, model, families)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error selecting family for %s: %v, using default\n", model.ModelID, err)
|
||||||
|
familyAnchor = "llama3" // Default fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Selected family '%s' for model %s\n", familyAnchor, model.ModelID)
|
||||||
|
|
||||||
|
// Generate YAML entry
|
||||||
|
yamlEntry := generateYAMLEntry(model, familyAnchor)
|
||||||
|
yamlEntries = append(yamlEntries, yamlEntry)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append to index.yaml
|
||||||
|
if len(yamlEntries) > 0 {
|
||||||
|
indexPath := getGalleryIndexPath()
|
||||||
|
fmt.Printf("Appending YAML entries to %s...\n", indexPath)
|
||||||
|
|
||||||
|
// Read current content
|
||||||
|
content, err := os.ReadFile(indexPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read %s: %w", indexPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append new entries
|
||||||
|
// Remove trailing whitespace from existing content and join entries without extra newlines
|
||||||
|
existingContent := strings.TrimRight(string(content), " \t\n\r")
|
||||||
|
yamlBlock := strings.Join(yamlEntries, "\n")
|
||||||
|
newContent := existingContent + "\n" + yamlBlock + "\n"
|
||||||
|
|
||||||
|
// Write back to file
|
||||||
|
err = os.WriteFile(indexPath, []byte(newContent), 0644)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to write %s: %w", indexPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Successfully added %d models to %s\n", len(yamlEntries), indexPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
39
.github/gallery-agent/go.mod
vendored
Normal file
39
.github/gallery-agent/go.mod
vendored
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
module github.com/go-skynet/LocalAI/.github/gallery-agent
|
||||||
|
|
||||||
|
go 1.24.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/mudler/cogito v0.3.0
|
||||||
|
github.com/onsi/ginkgo/v2 v2.25.3
|
||||||
|
github.com/onsi/gomega v1.38.2
|
||||||
|
github.com/sashabaranov/go-openai v1.41.2
|
||||||
|
github.com/tmc/langchaingo v0.1.13
|
||||||
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
dario.cat/mergo v1.0.1 // indirect
|
||||||
|
github.com/Masterminds/goutils v1.1.1 // indirect
|
||||||
|
github.com/Masterminds/semver/v3 v3.4.0 // indirect
|
||||||
|
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
|
||||||
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
|
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
|
||||||
|
github.com/google/go-cmp v0.7.0 // indirect
|
||||||
|
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||||
|
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
|
||||||
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/huandu/xstrings v1.5.0 // indirect
|
||||||
|
github.com/mitchellh/copystructure v1.2.0 // indirect
|
||||||
|
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.0.0 // indirect
|
||||||
|
github.com/shopspring/decimal v1.4.0 // indirect
|
||||||
|
github.com/spf13/cast v1.7.0 // indirect
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||||
|
go.uber.org/automaxprocs v1.6.0 // indirect
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||||
|
golang.org/x/crypto v0.41.0 // indirect
|
||||||
|
golang.org/x/net v0.43.0 // indirect
|
||||||
|
golang.org/x/sys v0.35.0 // indirect
|
||||||
|
golang.org/x/text v0.28.0 // indirect
|
||||||
|
golang.org/x/tools v0.36.0 // indirect
|
||||||
|
)
|
||||||
168
.github/gallery-agent/go.sum
vendored
Normal file
168
.github/gallery-agent/go.sum
vendored
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
|
||||||
|
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
|
||||||
|
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
|
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
|
||||||
|
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
|
||||||
|
github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
|
||||||
|
github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
|
||||||
|
github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs=
|
||||||
|
github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
|
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
|
||||||
|
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
|
||||||
|
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||||
|
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
|
||||||
|
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
|
||||||
|
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
|
||||||
|
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||||
|
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||||
|
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||||
|
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
|
||||||
|
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
|
||||||
|
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
|
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||||
|
github.com/docker/docker v28.2.2+incompatible h1:CjwRSksz8Yo4+RmQ339Dp/D2tGO5JxwYeqtMOEe0LDw=
|
||||||
|
github.com/docker/docker v28.2.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||||
|
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||||
|
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||||
|
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||||
|
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||||
|
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
|
||||||
|
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||||
|
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||||
|
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||||
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
|
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||||
|
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
|
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
|
||||||
|
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
|
||||||
|
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
|
||||||
|
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||||
|
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8=
|
||||||
|
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
|
||||||
|
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
|
||||||
|
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||||
|
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
|
||||||
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
|
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||||
|
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||||
|
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
|
||||||
|
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
|
||||||
|
github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ=
|
||||||
|
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
|
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
||||||
|
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
|
||||||
|
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
|
||||||
|
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
|
||||||
|
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
|
||||||
|
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
|
||||||
|
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
|
||||||
|
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
|
||||||
|
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
|
||||||
|
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
|
||||||
|
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||||
|
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.0.0 h1:Z4MSjLi38bTgLrd/LjSmofqRqyBiVKRyQSJgw8q8V74=
|
||||||
|
github.com/modelcontextprotocol/go-sdk v1.0.0/go.mod h1:nYtYQroQ2KQiM0/SbyEPUWQ6xs4B95gJjEalc9AQyOs=
|
||||||
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
|
github.com/mudler/cogito v0.3.0 h1:NbVAO3bLkK5oGSY0xq87jlz8C9OIsLW55s+8Hfzeu9s=
|
||||||
|
github.com/mudler/cogito v0.3.0/go.mod h1:abMwl+CUjCp87IufA2quZdZt0bbLaHHN79o17HbUKxU=
|
||||||
|
github.com/onsi/ginkgo/v2 v2.25.3 h1:Ty8+Yi/ayDAGtk4XxmmfUy4GabvM+MegeB4cDLRi6nw=
|
||||||
|
github.com/onsi/ginkgo/v2 v2.25.3/go.mod h1:43uiyQC4Ed2tkOzLsEYm7hnrb7UJTWHYNsuy3bG/snE=
|
||||||
|
github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A=
|
||||||
|
github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||||
|
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
|
||||||
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
|
||||||
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
|
||||||
|
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
|
||||||
|
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
|
||||||
|
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||||
|
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||||
|
github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM=
|
||||||
|
github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.5 h1:rtd9piuSMGeU8g1RMXjZs9y9luK5BwtnG7dZaQUJAsc=
|
||||||
|
github.com/shirou/gopsutil/v4 v4.25.5/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||||
|
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||||
|
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||||
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
|
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
|
||||||
|
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/testcontainers/testcontainers-go v0.38.0 h1:d7uEapLcv2P8AvH8ahLqDMMxda2W9gQN1nRbHS28HBw=
|
||||||
|
github.com/testcontainers/testcontainers-go v0.38.0/go.mod h1:C52c9MoHpWO+C4aqmgSU+hxlR5jlEayWtgYrb8Pzz1w=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
|
||||||
|
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
|
||||||
|
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
|
||||||
|
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
|
||||||
|
github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA=
|
||||||
|
github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
|
||||||
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
|
||||||
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
|
||||||
|
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||||
|
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||||
|
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||||
|
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||||
|
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||||
|
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||||
|
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
|
||||||
|
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||||
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
|
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||||
|
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||||
|
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||||
|
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||||
|
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||||
|
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
|
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||||
|
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||||
|
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
|
||||||
|
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
|
||||||
|
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||||
|
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
299
.github/gallery-agent/hfapi/client.go
vendored
Normal file
299
.github/gallery-agent/hfapi/client.go
vendored
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
package hfapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Model represents a model from the Hugging Face API
|
||||||
|
type Model struct {
|
||||||
|
ModelID string `json:"modelId"`
|
||||||
|
Author string `json:"author"`
|
||||||
|
Downloads int `json:"downloads"`
|
||||||
|
LastModified string `json:"lastModified"`
|
||||||
|
PipelineTag string `json:"pipelineTag"`
|
||||||
|
Private bool `json:"private"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
CreatedAt string `json:"createdAt"`
|
||||||
|
UpdatedAt string `json:"updatedAt"`
|
||||||
|
Sha string `json:"sha"`
|
||||||
|
Config map[string]interface{} `json:"config"`
|
||||||
|
ModelIndex string `json:"model_index"`
|
||||||
|
LibraryName string `json:"library_name"`
|
||||||
|
MaskToken string `json:"mask_token"`
|
||||||
|
TokenizerClass string `json:"tokenizer_class"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileInfo represents file information from HuggingFace
|
||||||
|
type FileInfo struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Oid string `json:"oid"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
LFS *LFSInfo `json:"lfs,omitempty"`
|
||||||
|
XetHash string `json:"xetHash,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LFSInfo represents LFS (Large File Storage) information
|
||||||
|
type LFSInfo struct {
|
||||||
|
Oid string `json:"oid"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
PointerSize int `json:"pointerSize"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelFile represents a file in a model repository
|
||||||
|
type ModelFile struct {
|
||||||
|
Path string
|
||||||
|
Size int64
|
||||||
|
SHA256 string
|
||||||
|
IsReadme bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelDetails represents detailed information about a model
|
||||||
|
type ModelDetails struct {
|
||||||
|
ModelID string
|
||||||
|
Author string
|
||||||
|
Files []ModelFile
|
||||||
|
ReadmeFile *ModelFile
|
||||||
|
ReadmeContent string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchParams represents the parameters for searching models
|
||||||
|
type SearchParams struct {
|
||||||
|
Sort string `json:"sort"`
|
||||||
|
Direction int `json:"direction"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
Search string `json:"search"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client represents a Hugging Face API client
|
||||||
|
type Client struct {
|
||||||
|
baseURL string
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new Hugging Face API client
|
||||||
|
func NewClient() *Client {
|
||||||
|
return &Client{
|
||||||
|
baseURL: "https://huggingface.co/api/models",
|
||||||
|
client: &http.Client{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchModels searches for models using the Hugging Face API
|
||||||
|
func (c *Client) SearchModels(params SearchParams) ([]Model, error) {
|
||||||
|
req, err := http.NewRequest("GET", c.baseURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add query parameters
|
||||||
|
q := req.URL.Query()
|
||||||
|
q.Add("sort", params.Sort)
|
||||||
|
q.Add("direction", fmt.Sprintf("%d", params.Direction))
|
||||||
|
q.Add("limit", fmt.Sprintf("%d", params.Limit))
|
||||||
|
q.Add("search", params.Search)
|
||||||
|
req.URL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
// Make the HTTP request
|
||||||
|
resp, err := c.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("failed to fetch models. Status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the response body
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the JSON response
|
||||||
|
var models []Model
|
||||||
|
if err := json.Unmarshal(body, &models); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLatest fetches the latest GGUF models
|
||||||
|
func (c *Client) GetLatest(searchTerm string, limit int) ([]Model, error) {
|
||||||
|
params := SearchParams{
|
||||||
|
Sort: "lastModified",
|
||||||
|
Direction: -1,
|
||||||
|
Limit: limit,
|
||||||
|
Search: searchTerm,
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.SearchModels(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BaseURL returns the current base URL
|
||||||
|
func (c *Client) BaseURL() string {
|
||||||
|
return c.baseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBaseURL sets a new base URL (useful for testing)
|
||||||
|
func (c *Client) SetBaseURL(url string) {
|
||||||
|
c.baseURL = url
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFiles lists all files in a HuggingFace repository
|
||||||
|
func (c *Client) ListFiles(repoID string) ([]FileInfo, error) {
|
||||||
|
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
|
||||||
|
url := fmt.Sprintf("%s/api/models/%s/tree/main", baseURL, repoID)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to make request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("failed to fetch files. Status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var files []FileInfo
|
||||||
|
if err := json.Unmarshal(body, &files); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse JSON response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return files, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFileSHA gets the SHA256 checksum for a specific file by searching through the file list
|
||||||
|
func (c *Client) GetFileSHA(repoID, fileName string) (string, error) {
|
||||||
|
files, err := c.ListFiles(repoID)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to list files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, file := range files {
|
||||||
|
if filepath.Base(file.Path) == fileName {
|
||||||
|
if file.LFS != nil && file.LFS.Oid != "" {
|
||||||
|
// The LFS OID contains the SHA256 hash
|
||||||
|
return file.LFS.Oid, nil
|
||||||
|
}
|
||||||
|
// If no LFS, return the regular OID
|
||||||
|
return file.Oid, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", fmt.Errorf("file %s not found", fileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelDetails gets detailed information about a model including files and checksums
|
||||||
|
func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) {
|
||||||
|
files, err := c.ListFiles(repoID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
details := &ModelDetails{
|
||||||
|
ModelID: repoID,
|
||||||
|
Author: strings.Split(repoID, "/")[0],
|
||||||
|
Files: make([]ModelFile, 0, len(files)),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each file
|
||||||
|
for _, file := range files {
|
||||||
|
fileName := filepath.Base(file.Path)
|
||||||
|
isReadme := strings.Contains(strings.ToLower(fileName), "readme")
|
||||||
|
|
||||||
|
// Extract SHA256 from LFS or use OID
|
||||||
|
sha256 := ""
|
||||||
|
if file.LFS != nil && file.LFS.Oid != "" {
|
||||||
|
sha256 = file.LFS.Oid
|
||||||
|
} else {
|
||||||
|
sha256 = file.Oid
|
||||||
|
}
|
||||||
|
|
||||||
|
modelFile := ModelFile{
|
||||||
|
Path: file.Path,
|
||||||
|
Size: file.Size,
|
||||||
|
SHA256: sha256,
|
||||||
|
IsReadme: isReadme,
|
||||||
|
}
|
||||||
|
|
||||||
|
details.Files = append(details.Files, modelFile)
|
||||||
|
|
||||||
|
// Set the readme file
|
||||||
|
if isReadme && details.ReadmeFile == nil {
|
||||||
|
details.ReadmeFile = &modelFile
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return details, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetReadmeContent gets the content of a README file
|
||||||
|
func (c *Client) GetReadmeContent(repoID, readmePath string) (string, error) {
|
||||||
|
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
|
||||||
|
url := fmt.Sprintf("%s/%s/raw/main/%s", baseURL, repoID, readmePath)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to make request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("failed to fetch readme content. Status code: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(body), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterFilesByQuantization filters files by quantization type
|
||||||
|
func FilterFilesByQuantization(files []ModelFile, quantization string) []ModelFile {
|
||||||
|
var filtered []ModelFile
|
||||||
|
for _, file := range files {
|
||||||
|
fileName := filepath.Base(file.Path)
|
||||||
|
if strings.Contains(strings.ToLower(fileName), strings.ToLower(quantization)) {
|
||||||
|
filtered = append(filtered, file)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
// FindPreferredModelFile finds the preferred model file based on quantization preferences
|
||||||
|
func FindPreferredModelFile(files []ModelFile, preferences []string) *ModelFile {
|
||||||
|
for _, preference := range preferences {
|
||||||
|
for i := range files {
|
||||||
|
fileName := filepath.Base(files[i].Path)
|
||||||
|
if strings.Contains(strings.ToLower(fileName), strings.ToLower(preference)) {
|
||||||
|
return &files[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
511
.github/gallery-agent/hfapi/client_test.go
vendored
Normal file
511
.github/gallery-agent/hfapi/client_test.go
vendored
Normal file
@@ -0,0 +1,511 @@
|
|||||||
|
package hfapi_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("HuggingFace API Client", func() {
|
||||||
|
var (
|
||||||
|
client *hfapi.Client
|
||||||
|
server *httptest.Server
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
client = hfapi.NewClient()
|
||||||
|
})
|
||||||
|
|
||||||
|
AfterEach(func() {
|
||||||
|
if server != nil {
|
||||||
|
server.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when creating a new client", func() {
|
||||||
|
It("should initialize with correct base URL", func() {
|
||||||
|
Expect(client).ToNot(BeNil())
|
||||||
|
Expect(client.BaseURL()).To(Equal("https://huggingface.co/api/models"))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when searching for models", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
// Mock response data
|
||||||
|
mockResponse := `[
|
||||||
|
{
|
||||||
|
"modelId": "test-model-1",
|
||||||
|
"author": "test-author",
|
||||||
|
"downloads": 1000,
|
||||||
|
"lastModified": "2024-01-01T00:00:00.000Z",
|
||||||
|
"pipelineTag": "text-generation",
|
||||||
|
"private": false,
|
||||||
|
"tags": ["gguf", "llama"],
|
||||||
|
"createdAt": "2024-01-01T00:00:00.000Z",
|
||||||
|
"updatedAt": "2024-01-01T00:00:00.000Z",
|
||||||
|
"sha": "abc123",
|
||||||
|
"config": {},
|
||||||
|
"model_index": "test-index",
|
||||||
|
"library_name": "transformers",
|
||||||
|
"mask_token": null,
|
||||||
|
"tokenizer_class": "LlamaTokenizer"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"modelId": "test-model-2",
|
||||||
|
"author": "test-author-2",
|
||||||
|
"downloads": 2000,
|
||||||
|
"lastModified": "2024-01-02T00:00:00.000Z",
|
||||||
|
"pipelineTag": "text-generation",
|
||||||
|
"private": false,
|
||||||
|
"tags": ["gguf", "mistral"],
|
||||||
|
"createdAt": "2024-01-02T00:00:00.000Z",
|
||||||
|
"updatedAt": "2024-01-02T00:00:00.000Z",
|
||||||
|
"sha": "def456",
|
||||||
|
"config": {},
|
||||||
|
"model_index": "test-index-2",
|
||||||
|
"library_name": "transformers",
|
||||||
|
"mask_token": null,
|
||||||
|
"tokenizer_class": "MistralTokenizer"
|
||||||
|
}
|
||||||
|
]`
|
||||||
|
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Verify request parameters
|
||||||
|
Expect(r.URL.Query().Get("sort")).To(Equal("lastModified"))
|
||||||
|
Expect(r.URL.Query().Get("direction")).To(Equal("-1"))
|
||||||
|
Expect(r.URL.Query().Get("limit")).To(Equal("30"))
|
||||||
|
Expect(r.URL.Query().Get("search")).To(Equal("GGUF"))
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(mockResponse))
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Override the client's base URL to use our mock server
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should successfully search for models", func() {
|
||||||
|
params := hfapi.SearchParams{
|
||||||
|
Sort: "lastModified",
|
||||||
|
Direction: -1,
|
||||||
|
Limit: 30,
|
||||||
|
Search: "GGUF",
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err := client.SearchModels(params)
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(models).To(HaveLen(2))
|
||||||
|
|
||||||
|
// Verify first model
|
||||||
|
Expect(models[0].ModelID).To(Equal("test-model-1"))
|
||||||
|
Expect(models[0].Author).To(Equal("test-author"))
|
||||||
|
Expect(models[0].Downloads).To(Equal(1000))
|
||||||
|
Expect(models[0].PipelineTag).To(Equal("text-generation"))
|
||||||
|
Expect(models[0].Private).To(BeFalse())
|
||||||
|
Expect(models[0].Tags).To(ContainElements("gguf", "llama"))
|
||||||
|
|
||||||
|
// Verify second model
|
||||||
|
Expect(models[1].ModelID).To(Equal("test-model-2"))
|
||||||
|
Expect(models[1].Author).To(Equal("test-author-2"))
|
||||||
|
Expect(models[1].Downloads).To(Equal(2000))
|
||||||
|
Expect(models[1].Tags).To(ContainElements("gguf", "mistral"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should handle empty search results", func() {
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("[]"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
|
||||||
|
params := hfapi.SearchParams{
|
||||||
|
Sort: "lastModified",
|
||||||
|
Direction: -1,
|
||||||
|
Limit: 30,
|
||||||
|
Search: "nonexistent",
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err := client.SearchModels(params)
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(models).To(HaveLen(0))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should handle HTTP errors", func() {
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
w.Write([]byte("Internal Server Error"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
|
||||||
|
params := hfapi.SearchParams{
|
||||||
|
Sort: "lastModified",
|
||||||
|
Direction: -1,
|
||||||
|
Limit: 30,
|
||||||
|
Search: "GGUF",
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err := client.SearchModels(params)
|
||||||
|
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("Status code: 500"))
|
||||||
|
Expect(models).To(BeNil())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should handle malformed JSON response", func() {
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("invalid json"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
|
||||||
|
params := hfapi.SearchParams{
|
||||||
|
Sort: "lastModified",
|
||||||
|
Direction: -1,
|
||||||
|
Limit: 30,
|
||||||
|
Search: "GGUF",
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err := client.SearchModels(params)
|
||||||
|
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("failed to parse JSON response"))
|
||||||
|
Expect(models).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when getting latest GGUF models", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
mockResponse := `[
|
||||||
|
{
|
||||||
|
"modelId": "latest-gguf-model",
|
||||||
|
"author": "gguf-author",
|
||||||
|
"downloads": 5000,
|
||||||
|
"lastModified": "2024-01-03T00:00:00.000Z",
|
||||||
|
"pipelineTag": "text-generation",
|
||||||
|
"private": false,
|
||||||
|
"tags": ["gguf", "latest"],
|
||||||
|
"createdAt": "2024-01-03T00:00:00.000Z",
|
||||||
|
"updatedAt": "2024-01-03T00:00:00.000Z",
|
||||||
|
"sha": "latest123",
|
||||||
|
"config": {},
|
||||||
|
"model_index": "latest-index",
|
||||||
|
"library_name": "transformers",
|
||||||
|
"mask_token": null,
|
||||||
|
"tokenizer_class": "LlamaTokenizer"
|
||||||
|
}
|
||||||
|
]`
|
||||||
|
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Verify the search parameters are correct for GGUF search
|
||||||
|
Expect(r.URL.Query().Get("search")).To(Equal("GGUF"))
|
||||||
|
Expect(r.URL.Query().Get("sort")).To(Equal("lastModified"))
|
||||||
|
Expect(r.URL.Query().Get("direction")).To(Equal("-1"))
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(mockResponse))
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should fetch latest GGUF models with correct parameters", func() {
|
||||||
|
models, err := client.GetLatest("GGUF", 10)
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(models).To(HaveLen(1))
|
||||||
|
Expect(models[0].ModelID).To(Equal("latest-gguf-model"))
|
||||||
|
Expect(models[0].Author).To(Equal("gguf-author"))
|
||||||
|
Expect(models[0].Downloads).To(Equal(5000))
|
||||||
|
Expect(models[0].Tags).To(ContainElements("gguf", "latest"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should use custom search term", func() {
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
Expect(r.URL.Query().Get("search")).To(Equal("custom-search"))
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("[]"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
|
||||||
|
models, err := client.GetLatest("custom-search", 5)
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(models).To(HaveLen(0))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when handling network errors", func() {
|
||||||
|
It("should handle connection failures gracefully", func() {
|
||||||
|
// Use an invalid URL to simulate connection failure
|
||||||
|
client.SetBaseURL("http://invalid-url-that-does-not-exist")
|
||||||
|
|
||||||
|
params := hfapi.SearchParams{
|
||||||
|
Sort: "lastModified",
|
||||||
|
Direction: -1,
|
||||||
|
Limit: 30,
|
||||||
|
Search: "GGUF",
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err := client.SearchModels(params)
|
||||||
|
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("failed to make request"))
|
||||||
|
Expect(models).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when listing files", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
mockFilesResponse := `[
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"path": "model-Q4_K_M.gguf",
|
||||||
|
"size": 1000000,
|
||||||
|
"oid": "abc123",
|
||||||
|
"lfs": {
|
||||||
|
"oid": "def456789",
|
||||||
|
"size": 1000000,
|
||||||
|
"pointerSize": 135
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"path": "README.md",
|
||||||
|
"size": 5000,
|
||||||
|
"oid": "readme123"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "file",
|
||||||
|
"path": "config.json",
|
||||||
|
"size": 1000,
|
||||||
|
"oid": "config123"
|
||||||
|
}
|
||||||
|
]`
|
||||||
|
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.Contains(r.URL.Path, "/tree/main") {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(mockFilesResponse))
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should list files successfully", func() {
|
||||||
|
files, err := client.ListFiles("test/model")
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(files).To(HaveLen(3))
|
||||||
|
|
||||||
|
Expect(files[0].Path).To(Equal("model-Q4_K_M.gguf"))
|
||||||
|
Expect(files[0].Size).To(Equal(int64(1000000)))
|
||||||
|
Expect(files[0].LFS).ToNot(BeNil())
|
||||||
|
Expect(files[0].LFS.Oid).To(Equal("def456789"))
|
||||||
|
|
||||||
|
Expect(files[1].Path).To(Equal("README.md"))
|
||||||
|
Expect(files[1].Size).To(Equal(int64(5000)))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when getting file SHA", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
mockFileInfoResponse := `{
|
||||||
|
"path": "model-Q4_K_M.gguf",
|
||||||
|
"size": 1000000,
|
||||||
|
"oid": "abc123",
|
||||||
|
"lfs": {
|
||||||
|
"oid": "sha256:def456",
|
||||||
|
"size": 1000000,
|
||||||
|
"pointer": "version https://git-lfs.github.com/spec/v1",
|
||||||
|
"sha256": "def456789"
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.Contains(r.URL.Path, "/paths-info") {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(mockFileInfoResponse))
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should get file SHA successfully", func() {
|
||||||
|
sha, err := client.GetFileSHA("test/model", "model-Q4_K_M.gguf")
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(sha).To(Equal("def456789"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should handle missing SHA gracefully", func() {
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(`{"path": "file.txt", "size": 100}`))
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
|
||||||
|
sha, err := client.GetFileSHA("test/model", "file.txt")
|
||||||
|
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(ContainSubstring("no SHA256 found"))
|
||||||
|
Expect(sha).To(Equal(""))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when getting model details", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
mockFilesResponse := `[
|
||||||
|
{
|
||||||
|
"path": "model-Q4_K_M.gguf",
|
||||||
|
"size": 1000000,
|
||||||
|
"oid": "abc123",
|
||||||
|
"lfs": {
|
||||||
|
"oid": "sha256:def456",
|
||||||
|
"size": 1000000,
|
||||||
|
"pointer": "version https://git-lfs.github.com/spec/v1",
|
||||||
|
"sha256": "def456789"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "README.md",
|
||||||
|
"size": 5000,
|
||||||
|
"oid": "readme123"
|
||||||
|
}
|
||||||
|
]`
|
||||||
|
|
||||||
|
mockFileInfoResponse := `{
|
||||||
|
"path": "model-Q4_K_M.gguf",
|
||||||
|
"size": 1000000,
|
||||||
|
"oid": "abc123",
|
||||||
|
"lfs": {
|
||||||
|
"oid": "sha256:def456",
|
||||||
|
"size": 1000000,
|
||||||
|
"pointer": "version https://git-lfs.github.com/spec/v1",
|
||||||
|
"sha256": "def456789"
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.Contains(r.URL.Path, "/tree/main") {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(mockFilesResponse))
|
||||||
|
} else if strings.Contains(r.URL.Path, "/paths-info") {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(mockFileInfoResponse))
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should get model details successfully", func() {
|
||||||
|
details, err := client.GetModelDetails("test/model")
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(details.ModelID).To(Equal("test/model"))
|
||||||
|
Expect(details.Author).To(Equal("test"))
|
||||||
|
Expect(details.Files).To(HaveLen(2))
|
||||||
|
|
||||||
|
Expect(details.ReadmeFile).ToNot(BeNil())
|
||||||
|
Expect(details.ReadmeFile.Path).To(Equal("README.md"))
|
||||||
|
Expect(details.ReadmeFile.IsReadme).To(BeTrue())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when getting README content", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
mockReadmeContent := "# Test Model\n\nThis is a test model for demonstration purposes."
|
||||||
|
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if strings.Contains(r.URL.Path, "/raw/main/") {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(mockReadmeContent))
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
client.SetBaseURL(server.URL)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should get README content successfully", func() {
|
||||||
|
content, err := client.GetReadmeContent("test/model", "README.md")
|
||||||
|
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(content).To(Equal("# Test Model\n\nThis is a test model for demonstration purposes."))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when filtering files", func() {
|
||||||
|
It("should filter files by quantization", func() {
|
||||||
|
files := []hfapi.ModelFile{
|
||||||
|
{Path: "model-Q4_K_M.gguf"},
|
||||||
|
{Path: "model-Q3_K_M.gguf"},
|
||||||
|
{Path: "README.md", IsReadme: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := hfapi.FilterFilesByQuantization(files, "Q4_K_M")
|
||||||
|
|
||||||
|
Expect(filtered).To(HaveLen(1))
|
||||||
|
Expect(filtered[0].Path).To(Equal("model-Q4_K_M.gguf"))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should find preferred model file", func() {
|
||||||
|
files := []hfapi.ModelFile{
|
||||||
|
{Path: "model-Q3_K_M.gguf"},
|
||||||
|
{Path: "model-Q4_K_M.gguf"},
|
||||||
|
{Path: "README.md", IsReadme: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
preferences := []string{"Q4_K_M", "Q3_K_M"}
|
||||||
|
preferred := hfapi.FindPreferredModelFile(files, preferences)
|
||||||
|
|
||||||
|
Expect(preferred).ToNot(BeNil())
|
||||||
|
Expect(preferred.Path).To(Equal("model-Q4_K_M.gguf"))
|
||||||
|
Expect(preferred.IsReadme).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should return nil if no preferred file found", func() {
|
||||||
|
files := []hfapi.ModelFile{
|
||||||
|
{Path: "model-Q2_K.gguf"},
|
||||||
|
{Path: "README.md", IsReadme: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
preferences := []string{"Q4_K_M", "Q3_K_M"}
|
||||||
|
preferred := hfapi.FindPreferredModelFile(files, preferences)
|
||||||
|
|
||||||
|
Expect(preferred).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
13
.github/gallery-agent/hfapi/hfapi_suite_test.go
vendored
Normal file
13
.github/gallery-agent/hfapi/hfapi_suite_test.go
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package hfapi_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHfapi(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "HuggingFace API Suite")
|
||||||
|
}
|
||||||
351
.github/gallery-agent/main.go
vendored
Normal file
351
.github/gallery-agent/main.go
vendored
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProcessedModelFile represents a processed model file with additional metadata
|
||||||
|
type ProcessedModelFile struct {
|
||||||
|
Path string `json:"path"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
SHA256 string `json:"sha256"`
|
||||||
|
IsReadme bool `json:"is_readme"`
|
||||||
|
FileType string `json:"file_type"` // "model", "readme", "other"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProcessedModel represents a processed model with all gathered metadata
|
||||||
|
type ProcessedModel struct {
|
||||||
|
ModelID string `json:"model_id"`
|
||||||
|
Author string `json:"author"`
|
||||||
|
Downloads int `json:"downloads"`
|
||||||
|
LastModified string `json:"last_modified"`
|
||||||
|
Files []ProcessedModelFile `json:"files"`
|
||||||
|
PreferredModelFile *ProcessedModelFile `json:"preferred_model_file,omitempty"`
|
||||||
|
ReadmeFile *ProcessedModelFile `json:"readme_file,omitempty"`
|
||||||
|
ReadmeContent string `json:"readme_content,omitempty"`
|
||||||
|
ReadmeContentPreview string `json:"readme_content_preview,omitempty"`
|
||||||
|
QuantizationPreferences []string `json:"quantization_preferences"`
|
||||||
|
ProcessingError string `json:"processing_error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SearchResult represents the complete result of searching and processing models
|
||||||
|
type SearchResult struct {
|
||||||
|
SearchTerm string `json:"search_term"`
|
||||||
|
Limit int `json:"limit"`
|
||||||
|
Quantization string `json:"quantization"`
|
||||||
|
TotalModelsFound int `json:"total_models_found"`
|
||||||
|
Models []ProcessedModel `json:"models"`
|
||||||
|
FormattedOutput string `json:"formatted_output"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddedModelSummary represents a summary of models added to the gallery
|
||||||
|
type AddedModelSummary struct {
|
||||||
|
SearchTerm string `json:"search_term"`
|
||||||
|
TotalFound int `json:"total_found"`
|
||||||
|
ModelsAdded int `json:"models_added"`
|
||||||
|
AddedModelIDs []string `json:"added_model_ids"`
|
||||||
|
AddedModelURLs []string `json:"added_model_urls"`
|
||||||
|
Quantization string `json:"quantization"`
|
||||||
|
ProcessingTime string `json:"processing_time"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
startTime := time.Now()
|
||||||
|
|
||||||
|
// Check for synthetic mode
|
||||||
|
syntheticMode := os.Getenv("SYNTHETIC_MODE")
|
||||||
|
if syntheticMode == "true" || syntheticMode == "1" {
|
||||||
|
fmt.Println("Running in SYNTHETIC MODE - generating random test data")
|
||||||
|
err := runSyntheticMode()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error in synthetic mode: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get configuration from environment variables
|
||||||
|
searchTerm := os.Getenv("SEARCH_TERM")
|
||||||
|
if searchTerm == "" {
|
||||||
|
searchTerm = "GGUF"
|
||||||
|
}
|
||||||
|
|
||||||
|
limitStr := os.Getenv("LIMIT")
|
||||||
|
if limitStr == "" {
|
||||||
|
limitStr = "5"
|
||||||
|
}
|
||||||
|
limit, err := strconv.Atoi(limitStr)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing LIMIT: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
quantization := os.Getenv("QUANTIZATION")
|
||||||
|
|
||||||
|
maxModels := os.Getenv("MAX_MODELS")
|
||||||
|
if maxModels == "" {
|
||||||
|
maxModels = "1"
|
||||||
|
}
|
||||||
|
maxModelsInt, err := strconv.Atoi(maxModels)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error parsing MAX_MODELS: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print configuration
|
||||||
|
fmt.Printf("Gallery Agent Configuration:\n")
|
||||||
|
fmt.Printf(" Search Term: %s\n", searchTerm)
|
||||||
|
fmt.Printf(" Limit: %d\n", limit)
|
||||||
|
fmt.Printf(" Quantization: %s\n", quantization)
|
||||||
|
fmt.Printf(" Max Models to Add: %d\n", maxModelsInt)
|
||||||
|
fmt.Printf(" Gallery Index Path: %s\n", os.Getenv("GALLERY_INDEX_PATH"))
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
result, err := searchAndProcessModels(searchTerm, limit, quantization)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(result.FormattedOutput)
|
||||||
|
|
||||||
|
// Use AI agent to select the most interesting models
|
||||||
|
fmt.Println("Using AI agent to select the most interesting models...")
|
||||||
|
models, err := selectMostInterestingModels(context.Background(), result)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error in model selection: %v\n", err)
|
||||||
|
// Continue with original result if selection fails
|
||||||
|
models = result.Models
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Print(models)
|
||||||
|
|
||||||
|
// Filter out models that already exist in the gallery
|
||||||
|
fmt.Println("Filtering out existing models...")
|
||||||
|
models, err = filterExistingModels(models)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error filtering existing models: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit to maxModelsInt after filtering
|
||||||
|
if len(models) > maxModelsInt {
|
||||||
|
models = models[:maxModelsInt]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track added models for summary
|
||||||
|
var addedModelIDs []string
|
||||||
|
var addedModelURLs []string
|
||||||
|
|
||||||
|
// Generate YAML entries and append to gallery/index.yaml
|
||||||
|
if len(models) > 0 {
|
||||||
|
for _, model := range models {
|
||||||
|
addedModelIDs = append(addedModelIDs, model.ModelID)
|
||||||
|
// Generate Hugging Face URL for the model
|
||||||
|
modelURL := fmt.Sprintf("https://huggingface.co/%s", model.ModelID)
|
||||||
|
addedModelURLs = append(addedModelURLs, modelURL)
|
||||||
|
}
|
||||||
|
fmt.Println("Generating YAML entries for selected models...")
|
||||||
|
err = generateYAMLForModels(context.Background(), models)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error generating YAML entries: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Println("No new models to add to the gallery.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and write summary
|
||||||
|
processingTime := time.Since(startTime).String()
|
||||||
|
summary := AddedModelSummary{
|
||||||
|
SearchTerm: searchTerm,
|
||||||
|
TotalFound: result.TotalModelsFound,
|
||||||
|
ModelsAdded: len(addedModelIDs),
|
||||||
|
AddedModelIDs: addedModelIDs,
|
||||||
|
AddedModelURLs: addedModelURLs,
|
||||||
|
Quantization: quantization,
|
||||||
|
ProcessingTime: processingTime,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write summary to file
|
||||||
|
summaryData, err := json.MarshalIndent(summary, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error marshaling summary: %v\n", err)
|
||||||
|
} else {
|
||||||
|
err = os.WriteFile("gallery-agent-summary.json", summaryData, 0644)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error writing summary file: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Summary written to gallery-agent-summary.json\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func searchAndProcessModels(searchTerm string, limit int, quantization string) (*SearchResult, error) {
|
||||||
|
client := hfapi.NewClient()
|
||||||
|
var outputBuilder strings.Builder
|
||||||
|
|
||||||
|
fmt.Println("Searching for models...")
|
||||||
|
// Initialize the result struct
|
||||||
|
result := &SearchResult{
|
||||||
|
SearchTerm: searchTerm,
|
||||||
|
Limit: limit,
|
||||||
|
Quantization: quantization,
|
||||||
|
Models: []ProcessedModel{},
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err := client.GetLatest(searchTerm, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to fetch models: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Models found:", len(models))
|
||||||
|
result.TotalModelsFound = len(models)
|
||||||
|
|
||||||
|
if len(models) == 0 {
|
||||||
|
outputBuilder.WriteString("No models found.\n")
|
||||||
|
result.FormattedOutput = outputBuilder.String()
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf("Found %d models matching '%s':\n\n", len(models), searchTerm))
|
||||||
|
|
||||||
|
// Process each model
|
||||||
|
for i, model := range models {
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf("%d. Processing Model: %s\n", i+1, model.ModelID))
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(" Author: %s\n", model.Author))
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(" Downloads: %d\n", model.Downloads))
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(" Last Modified: %s\n", model.LastModified))
|
||||||
|
|
||||||
|
// Initialize processed model struct
|
||||||
|
processedModel := ProcessedModel{
|
||||||
|
ModelID: model.ModelID,
|
||||||
|
Author: model.Author,
|
||||||
|
Downloads: model.Downloads,
|
||||||
|
LastModified: model.LastModified,
|
||||||
|
QuantizationPreferences: []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get detailed model information
|
||||||
|
details, err := client.GetModelDetails(model.ModelID)
|
||||||
|
if err != nil {
|
||||||
|
errorMsg := fmt.Sprintf(" Error getting model details: %v\n", err)
|
||||||
|
outputBuilder.WriteString(errorMsg)
|
||||||
|
processedModel.ProcessingError = err.Error()
|
||||||
|
result.Models = append(result.Models, processedModel)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Define quantization preferences (in order of preference)
|
||||||
|
quantizationPreferences := []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"}
|
||||||
|
|
||||||
|
// Find preferred model file
|
||||||
|
preferredModelFile := hfapi.FindPreferredModelFile(details.Files, quantizationPreferences)
|
||||||
|
|
||||||
|
// Process files
|
||||||
|
processedFiles := make([]ProcessedModelFile, len(details.Files))
|
||||||
|
for j, file := range details.Files {
|
||||||
|
fileType := "other"
|
||||||
|
if file.IsReadme {
|
||||||
|
fileType = "readme"
|
||||||
|
} else if preferredModelFile != nil && file.Path == preferredModelFile.Path {
|
||||||
|
fileType = "model"
|
||||||
|
}
|
||||||
|
|
||||||
|
processedFiles[j] = ProcessedModelFile{
|
||||||
|
Path: file.Path,
|
||||||
|
Size: file.Size,
|
||||||
|
SHA256: file.SHA256,
|
||||||
|
IsReadme: file.IsReadme,
|
||||||
|
FileType: fileType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
processedModel.Files = processedFiles
|
||||||
|
|
||||||
|
// Set preferred model file
|
||||||
|
if preferredModelFile != nil {
|
||||||
|
for _, file := range processedFiles {
|
||||||
|
if file.Path == preferredModelFile.Path {
|
||||||
|
processedModel.PreferredModelFile = &file
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print file information
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(" Files found: %d\n", len(details.Files)))
|
||||||
|
|
||||||
|
if preferredModelFile != nil {
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(" Preferred Model File: %s (SHA256: %s)\n",
|
||||||
|
preferredModelFile.Path,
|
||||||
|
preferredModelFile.SHA256))
|
||||||
|
} else {
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(" No model file found with quantization preferences: %v\n", quantizationPreferences))
|
||||||
|
}
|
||||||
|
|
||||||
|
if details.ReadmeFile != nil {
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(" README File: %s\n", details.ReadmeFile.Path))
|
||||||
|
|
||||||
|
// Find and set readme file
|
||||||
|
for _, file := range processedFiles {
|
||||||
|
if file.IsReadme {
|
||||||
|
processedModel.ReadmeFile = &file
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Getting real readme for", model.ModelID, "waiting...")
|
||||||
|
// Use agent to get the real readme and prepare the model description
|
||||||
|
readmeContent, err := getRealReadme(context.Background(), model.ModelID)
|
||||||
|
if err == nil {
|
||||||
|
processedModel.ReadmeContent = readmeContent
|
||||||
|
processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
|
||||||
|
processedModel.ReadmeContentPreview))
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Println("Real readme got", readmeContent)
|
||||||
|
// Get README content
|
||||||
|
// readmeContent, err := client.GetReadmeContent(model.ModelID, details.ReadmeFile.Path)
|
||||||
|
// if err == nil {
|
||||||
|
// processedModel.ReadmeContent = readmeContent
|
||||||
|
// processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
|
||||||
|
// outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
|
||||||
|
// processedModel.ReadmeContentPreview))
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print all files with their checksums
|
||||||
|
outputBuilder.WriteString(" All Files:\n")
|
||||||
|
for _, file := range processedFiles {
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(" - %s (%s, %d bytes", file.Path, file.FileType, file.Size))
|
||||||
|
if file.SHA256 != "" {
|
||||||
|
outputBuilder.WriteString(fmt.Sprintf(", SHA256: %s", file.SHA256))
|
||||||
|
}
|
||||||
|
outputBuilder.WriteString(")\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
outputBuilder.WriteString("\n")
|
||||||
|
result.Models = append(result.Models, processedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
result.FormattedOutput = outputBuilder.String()
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateString(s string, maxLen int) string {
|
||||||
|
if len(s) <= maxLen {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return s[:maxLen] + "..."
|
||||||
|
}
|
||||||
190
.github/gallery-agent/testing.go
vendored
Normal file
190
.github/gallery-agent/testing.go
vendored
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// runSyntheticMode generates synthetic test data and appends it to the gallery
|
||||||
|
func runSyntheticMode() error {
|
||||||
|
generator := NewSyntheticDataGenerator()
|
||||||
|
|
||||||
|
// Generate a random number of synthetic models (1-3)
|
||||||
|
numModels := generator.rand.Intn(3) + 1
|
||||||
|
fmt.Printf("Generating %d synthetic models for testing...\n", numModels)
|
||||||
|
|
||||||
|
var models []ProcessedModel
|
||||||
|
for i := 0; i < numModels; i++ {
|
||||||
|
model := generator.GenerateProcessedModel()
|
||||||
|
models = append(models, model)
|
||||||
|
fmt.Printf("Generated synthetic model: %s\n", model.ModelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate YAML entries and append to gallery/index.yaml
|
||||||
|
fmt.Println("Generating YAML entries for synthetic models...")
|
||||||
|
err := generateYAMLForModels(context.Background(), models)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error generating YAML entries: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Successfully added %d synthetic models to the gallery for testing!\n", len(models))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyntheticDataGenerator provides methods to generate synthetic test data
|
||||||
|
type SyntheticDataGenerator struct {
|
||||||
|
rand *rand.Rand
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSyntheticDataGenerator creates a new synthetic data generator
|
||||||
|
func NewSyntheticDataGenerator() *SyntheticDataGenerator {
|
||||||
|
return &SyntheticDataGenerator{
|
||||||
|
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateProcessedModelFile creates a synthetic ProcessedModelFile
|
||||||
|
func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile {
|
||||||
|
fileTypes := []string{"model", "readme", "other"}
|
||||||
|
fileType := fileTypes[g.rand.Intn(len(fileTypes))]
|
||||||
|
|
||||||
|
var path string
|
||||||
|
var isReadme bool
|
||||||
|
|
||||||
|
switch fileType {
|
||||||
|
case "model":
|
||||||
|
path = fmt.Sprintf("model-%s.gguf", g.randomString(8))
|
||||||
|
isReadme = false
|
||||||
|
case "readme":
|
||||||
|
path = "README.md"
|
||||||
|
isReadme = true
|
||||||
|
default:
|
||||||
|
path = fmt.Sprintf("file-%s.txt", g.randomString(6))
|
||||||
|
isReadme = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return ProcessedModelFile{
|
||||||
|
Path: path,
|
||||||
|
Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB
|
||||||
|
SHA256: g.randomSHA256(),
|
||||||
|
IsReadme: isReadme,
|
||||||
|
FileType: fileType,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateProcessedModel creates a synthetic ProcessedModel
|
||||||
|
func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel {
|
||||||
|
authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"}
|
||||||
|
modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"}
|
||||||
|
|
||||||
|
author := authors[g.rand.Intn(len(authors))]
|
||||||
|
modelName := modelNames[g.rand.Intn(len(modelNames))]
|
||||||
|
modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6))
|
||||||
|
|
||||||
|
// Generate files
|
||||||
|
numFiles := g.rand.Intn(5) + 2 // 2-6 files
|
||||||
|
files := make([]ProcessedModelFile, numFiles)
|
||||||
|
|
||||||
|
// Ensure at least one model file and one readme
|
||||||
|
hasModelFile := false
|
||||||
|
hasReadme := false
|
||||||
|
|
||||||
|
for i := 0; i < numFiles; i++ {
|
||||||
|
files[i] = g.GenerateProcessedModelFile()
|
||||||
|
if files[i].FileType == "model" {
|
||||||
|
hasModelFile = true
|
||||||
|
}
|
||||||
|
if files[i].FileType == "readme" {
|
||||||
|
hasReadme = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add required files if missing
|
||||||
|
if !hasModelFile {
|
||||||
|
modelFile := g.GenerateProcessedModelFile()
|
||||||
|
modelFile.FileType = "model"
|
||||||
|
modelFile.Path = fmt.Sprintf("%s-Q4_K_M.gguf", modelName)
|
||||||
|
files = append(files, modelFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasReadme {
|
||||||
|
readmeFile := g.GenerateProcessedModelFile()
|
||||||
|
readmeFile.FileType = "readme"
|
||||||
|
readmeFile.Path = "README.md"
|
||||||
|
readmeFile.IsReadme = true
|
||||||
|
files = append(files, readmeFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find preferred model file
|
||||||
|
var preferredModelFile *ProcessedModelFile
|
||||||
|
for i := range files {
|
||||||
|
if files[i].FileType == "model" {
|
||||||
|
preferredModelFile = &files[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find readme file
|
||||||
|
var readmeFile *ProcessedModelFile
|
||||||
|
for i := range files {
|
||||||
|
if files[i].FileType == "readme" {
|
||||||
|
readmeFile = &files[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
readmeContent := g.generateReadmeContent(modelName, author)
|
||||||
|
|
||||||
|
return ProcessedModel{
|
||||||
|
ModelID: modelID,
|
||||||
|
Author: author,
|
||||||
|
Downloads: g.rand.Intn(1000000) + 1000,
|
||||||
|
LastModified: g.randomDate(),
|
||||||
|
Files: files,
|
||||||
|
PreferredModelFile: preferredModelFile,
|
||||||
|
ReadmeFile: readmeFile,
|
||||||
|
ReadmeContent: readmeContent,
|
||||||
|
ReadmeContentPreview: truncateString(readmeContent, 200),
|
||||||
|
QuantizationPreferences: []string{"Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"},
|
||||||
|
ProcessingError: "",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper methods for synthetic data generation
|
||||||
|
func (g *SyntheticDataGenerator) randomString(length int) string {
|
||||||
|
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
b := make([]byte, length)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = charset[g.rand.Intn(len(charset))]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *SyntheticDataGenerator) randomSHA256() string {
|
||||||
|
const charset = "0123456789abcdef"
|
||||||
|
b := make([]byte, 64)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = charset[g.rand.Intn(len(charset))]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *SyntheticDataGenerator) randomDate() string {
|
||||||
|
now := time.Now()
|
||||||
|
daysAgo := g.rand.Intn(365) // Random date within last year
|
||||||
|
pastDate := now.AddDate(0, 0, -daysAgo)
|
||||||
|
return pastDate.Format("2006-01-02T15:04:05.000Z")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string) string {
|
||||||
|
templates := []string{
|
||||||
|
fmt.Sprintf("# %s Model\n\nThis is a %s model developed by %s. It's designed for various natural language processing tasks including text generation, question answering, and conversation.\n\n## Features\n\n- High-quality text generation\n- Efficient inference\n- Multiple quantization options\n- Easy to use with LocalAI\n\n## Usage\n\nUse this model with LocalAI for various AI tasks.", strings.Title(modelName), modelName, author),
|
||||||
|
fmt.Sprintf("# %s\n\nA powerful language model from %s. This model excels at understanding and generating human-like text across multiple domains.\n\n## Capabilities\n\n- Text completion\n- Code generation\n- Creative writing\n- Technical documentation\n\n## Model Details\n\n- Architecture: Transformer-based\n- Training: Large-scale supervised learning\n- Quantization: Available in multiple formats", strings.Title(modelName), author),
|
||||||
|
fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author),
|
||||||
|
}
|
||||||
|
|
||||||
|
return templates[g.rand.Intn(len(templates))]
|
||||||
|
}
|
||||||
46
.github/gallery-agent/tools.go
vendored
Normal file
46
.github/gallery-agent/tools.go
vendored
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
"github.com/tmc/langchaingo/jsonschema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get repository README from HF
|
||||||
|
type HFReadmeTool struct {
|
||||||
|
client *hfapi.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HFReadmeTool) Run(args map[string]any) (string, error) {
|
||||||
|
q, ok := args["repository"].(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("no query")
|
||||||
|
}
|
||||||
|
readme, err := s.client.GetReadmeContent(q, "README.md")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return readme, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *HFReadmeTool) Tool() openai.Tool {
|
||||||
|
return openai.Tool{
|
||||||
|
Type: openai.ToolTypeFunction,
|
||||||
|
Function: &openai.FunctionDefinition{
|
||||||
|
Name: "hf_readme",
|
||||||
|
Description: "A tool to get the README content of a huggingface repository",
|
||||||
|
Parameters: jsonschema.Definition{
|
||||||
|
Type: jsonschema.Object,
|
||||||
|
Properties: map[string]jsonschema.Definition{
|
||||||
|
"repository": {
|
||||||
|
Type: jsonschema.String,
|
||||||
|
Description: "The huggingface repository to get the README content of",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Required: []string{"repository"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
100
.github/workflows/backend.yml
vendored
100
.github/workflows/backend.yml
vendored
@@ -111,6 +111,18 @@ jobs:
|
|||||||
backend: "diffusers"
|
backend: "diffusers"
|
||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./backend"
|
context: "./backend"
|
||||||
|
- build-type: ''
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-cpu-chatterbox'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:22.04"
|
||||||
|
skip-drivers: 'true'
|
||||||
|
backend: "chatterbox"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./backend"
|
||||||
# CUDA 11 additional backends
|
# CUDA 11 additional backends
|
||||||
- build-type: 'cublas'
|
- build-type: 'cublas'
|
||||||
cuda-major-version: "11"
|
cuda-major-version: "11"
|
||||||
@@ -477,6 +489,18 @@ jobs:
|
|||||||
backend: "diffusers"
|
backend: "diffusers"
|
||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./backend"
|
context: "./backend"
|
||||||
|
- build-type: 'l4t'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/arm64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-nvidia-l4t-kokoro'
|
||||||
|
runs-on: 'ubuntu-24.04-arm'
|
||||||
|
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||||
|
skip-drivers: 'true'
|
||||||
|
backend: "kokoro"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./backend"
|
||||||
# SYCL additional backends
|
# SYCL additional backends
|
||||||
- build-type: 'intel'
|
- build-type: 'intel'
|
||||||
cuda-major-version: ""
|
cuda-major-version: ""
|
||||||
@@ -763,7 +787,7 @@ jobs:
|
|||||||
cuda-minor-version: ""
|
cuda-minor-version: ""
|
||||||
platforms: 'linux/amd64'
|
platforms: 'linux/amd64'
|
||||||
tag-latest: 'auto'
|
tag-latest: 'auto'
|
||||||
tag-suffix: '-gpu-hipblas-whisper'
|
tag-suffix: '-gpu-rocm-hipblas-whisper'
|
||||||
base-image: "rocm/dev-ubuntu-22.04:6.4.3"
|
base-image: "rocm/dev-ubuntu-22.04:6.4.3"
|
||||||
runs-on: 'ubuntu-latest'
|
runs-on: 'ubuntu-latest'
|
||||||
skip-drivers: 'false'
|
skip-drivers: 'false'
|
||||||
@@ -858,7 +882,7 @@ jobs:
|
|||||||
backend: "rfdetr"
|
backend: "rfdetr"
|
||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./backend"
|
context: "./backend"
|
||||||
- build-type: 'cublas'
|
- build-type: 'l4t'
|
||||||
cuda-major-version: "12"
|
cuda-major-version: "12"
|
||||||
cuda-minor-version: "0"
|
cuda-minor-version: "0"
|
||||||
platforms: 'linux/arm64'
|
platforms: 'linux/arm64'
|
||||||
@@ -931,6 +955,18 @@ jobs:
|
|||||||
backend: "exllama2"
|
backend: "exllama2"
|
||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./backend"
|
context: "./backend"
|
||||||
|
- build-type: 'l4t'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/arm64'
|
||||||
|
skip-drivers: 'true'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-nvidia-l4t-arm64-chatterbox'
|
||||||
|
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||||
|
runs-on: 'ubuntu-24.04-arm'
|
||||||
|
backend: "chatterbox"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./backend"
|
||||||
# runs out of space on the runner
|
# runs out of space on the runner
|
||||||
# - build-type: 'hipblas'
|
# - build-type: 'hipblas'
|
||||||
# cuda-major-version: ""
|
# cuda-major-version: ""
|
||||||
@@ -957,6 +993,55 @@ jobs:
|
|||||||
backend: "kitten-tts"
|
backend: "kitten-tts"
|
||||||
dockerfile: "./backend/Dockerfile.python"
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
context: "./backend"
|
context: "./backend"
|
||||||
|
# neutts
|
||||||
|
- build-type: ''
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64,linux/arm64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-cpu-neutts'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:22.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "neutts"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./backend"
|
||||||
|
- build-type: 'cublas'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-nvidia-cuda-12-neutts'
|
||||||
|
runs-on: 'ubuntu-latest'
|
||||||
|
base-image: "ubuntu:22.04"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "neutts"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./backend"
|
||||||
|
- build-type: 'hipblas'
|
||||||
|
cuda-major-version: ""
|
||||||
|
cuda-minor-version: ""
|
||||||
|
platforms: 'linux/amd64'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-gpu-rocm-hipblas-neutts'
|
||||||
|
runs-on: 'arc-runner-set'
|
||||||
|
base-image: "rocm/dev-ubuntu-22.04:6.4.3"
|
||||||
|
skip-drivers: 'false'
|
||||||
|
backend: "neutts"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./backend"
|
||||||
|
- build-type: 'l4t'
|
||||||
|
cuda-major-version: "12"
|
||||||
|
cuda-minor-version: "0"
|
||||||
|
platforms: 'linux/arm64'
|
||||||
|
skip-drivers: 'true'
|
||||||
|
tag-latest: 'auto'
|
||||||
|
tag-suffix: '-nvidia-l4t-arm64-neutts'
|
||||||
|
base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0"
|
||||||
|
runs-on: 'ubuntu-24.04-arm'
|
||||||
|
backend: "neutts"
|
||||||
|
dockerfile: "./backend/Dockerfile.python"
|
||||||
|
context: "./backend"
|
||||||
backend-jobs-darwin:
|
backend-jobs-darwin:
|
||||||
uses: ./.github/workflows/backend_build_darwin.yml
|
uses: ./.github/workflows/backend_build_darwin.yml
|
||||||
strategy:
|
strategy:
|
||||||
@@ -968,6 +1053,9 @@ jobs:
|
|||||||
- backend: "mlx"
|
- backend: "mlx"
|
||||||
tag-suffix: "-metal-darwin-arm64-mlx"
|
tag-suffix: "-metal-darwin-arm64-mlx"
|
||||||
build-type: "mps"
|
build-type: "mps"
|
||||||
|
- backend: "chatterbox"
|
||||||
|
tag-suffix: "-metal-darwin-arm64-chatterbox"
|
||||||
|
build-type: "mps"
|
||||||
- backend: "mlx-vlm"
|
- backend: "mlx-vlm"
|
||||||
tag-suffix: "-metal-darwin-arm64-mlx-vlm"
|
tag-suffix: "-metal-darwin-arm64-mlx-vlm"
|
||||||
build-type: "mps"
|
build-type: "mps"
|
||||||
@@ -1021,7 +1109,7 @@ jobs:
|
|||||||
make protogen-go
|
make protogen-go
|
||||||
make backends/llama-cpp-darwin
|
make backends/llama-cpp-darwin
|
||||||
- name: Upload llama-cpp.tar
|
- name: Upload llama-cpp.tar
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: llama-cpp-tar
|
name: llama-cpp-tar
|
||||||
path: backend-images/llama-cpp.tar
|
path: backend-images/llama-cpp.tar
|
||||||
@@ -1031,7 +1119,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Download llama-cpp.tar
|
- name: Download llama-cpp.tar
|
||||||
uses: actions/download-artifact@v5
|
uses: actions/download-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: llama-cpp-tar
|
name: llama-cpp-tar
|
||||||
path: .
|
path: .
|
||||||
@@ -1109,7 +1197,7 @@ jobs:
|
|||||||
export PLATFORMARCH=darwin/amd64
|
export PLATFORMARCH=darwin/amd64
|
||||||
make backends/llama-cpp-darwin
|
make backends/llama-cpp-darwin
|
||||||
- name: Upload llama-cpp.tar
|
- name: Upload llama-cpp.tar
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: llama-cpp-tar-x86
|
name: llama-cpp-tar-x86
|
||||||
path: backend-images/llama-cpp.tar
|
path: backend-images/llama-cpp.tar
|
||||||
@@ -1119,7 +1207,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Download llama-cpp.tar
|
- name: Download llama-cpp.tar
|
||||||
uses: actions/download-artifact@v5
|
uses: actions/download-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: llama-cpp-tar-x86
|
name: llama-cpp-tar-x86
|
||||||
path: .
|
path: .
|
||||||
|
|||||||
4
.github/workflows/backend_build_darwin.yml
vendored
4
.github/workflows/backend_build_darwin.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
|||||||
BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-${{ inputs.lang }}-backend
|
BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-${{ inputs.lang }}-backend
|
||||||
|
|
||||||
- name: Upload ${{ inputs.backend }}.tar
|
- name: Upload ${{ inputs.backend }}.tar
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: ${{ inputs.backend }}-tar
|
name: ${{ inputs.backend }}-tar
|
||||||
path: backend-images/${{ inputs.backend }}.tar
|
path: backend-images/${{ inputs.backend }}.tar
|
||||||
@@ -85,7 +85,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Download ${{ inputs.backend }}.tar
|
- name: Download ${{ inputs.backend }}.tar
|
||||||
uses: actions/download-artifact@v5
|
uses: actions/download-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: ${{ inputs.backend }}-tar
|
name: ${{ inputs.backend }}-tar
|
||||||
path: .
|
path: .
|
||||||
|
|||||||
10
.github/workflows/build-test.yaml
vendored
10
.github/workflows/build-test.yaml
vendored
@@ -17,7 +17,7 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: 1.23
|
go-version: 1.25
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
run: |
|
run: |
|
||||||
make dev-dist
|
make dev-dist
|
||||||
@@ -31,13 +31,13 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: 1.23
|
go-version: 1.25
|
||||||
- name: Build launcher for macOS ARM64
|
- name: Build launcher for macOS ARM64
|
||||||
run: |
|
run: |
|
||||||
make build-launcher-darwin
|
make build-launcher-darwin
|
||||||
ls -liah dist
|
ls -liah dist
|
||||||
- name: Upload macOS launcher artifacts
|
- name: Upload macOS launcher artifacts
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: launcher-macos
|
name: launcher-macos
|
||||||
path: dist/
|
path: dist/
|
||||||
@@ -53,14 +53,14 @@ jobs:
|
|||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: 1.23
|
go-version: 1.25
|
||||||
- name: Build launcher for Linux
|
- name: Build launcher for Linux
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
sudo apt-get update
|
||||||
sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
|
sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev
|
||||||
make build-launcher-linux
|
make build-launcher-linux
|
||||||
- name: Upload Linux launcher artifacts
|
- name: Upload Linux launcher artifacts
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: launcher-linux
|
name: launcher-linux
|
||||||
path: local-ai-launcher-linux.tar.xz
|
path: local-ai-launcher-linux.tar.xz
|
||||||
|
|||||||
126
.github/workflows/gallery-agent.yaml
vendored
Normal file
126
.github/workflows/gallery-agent.yaml
vendored
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
name: Gallery Agent
|
||||||
|
on:
|
||||||
|
|
||||||
|
schedule:
|
||||||
|
- cron: '0 */1 * * *' # Run every 4 hours
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
search_term:
|
||||||
|
description: 'Search term for models'
|
||||||
|
required: false
|
||||||
|
default: 'GGUF'
|
||||||
|
type: string
|
||||||
|
limit:
|
||||||
|
description: 'Maximum number of models to process'
|
||||||
|
required: false
|
||||||
|
default: '15'
|
||||||
|
type: string
|
||||||
|
quantization:
|
||||||
|
description: 'Preferred quantization format'
|
||||||
|
required: false
|
||||||
|
default: 'Q4_K_M'
|
||||||
|
type: string
|
||||||
|
max_models:
|
||||||
|
description: 'Maximum number of models to add to the gallery'
|
||||||
|
required: false
|
||||||
|
default: '1'
|
||||||
|
type: string
|
||||||
|
jobs:
|
||||||
|
gallery-agent:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Set up Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.21'
|
||||||
|
|
||||||
|
- name: Build gallery agent
|
||||||
|
run: |
|
||||||
|
cd .github/gallery-agent
|
||||||
|
go mod download
|
||||||
|
go build -o gallery-agent .
|
||||||
|
|
||||||
|
- name: Run gallery agent
|
||||||
|
env:
|
||||||
|
OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }}
|
||||||
|
OPENAI_KEY: ${{ secrets.OPENAI_KEY }}
|
||||||
|
OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }}
|
||||||
|
SEARCH_TERM: ${{ github.event.inputs.search_term || 'GGUF' }}
|
||||||
|
LIMIT: ${{ github.event.inputs.limit || '15' }}
|
||||||
|
QUANTIZATION: ${{ github.event.inputs.quantization || 'Q4_K_M' }}
|
||||||
|
MAX_MODELS: ${{ github.event.inputs.max_models || '1' }}
|
||||||
|
run: |
|
||||||
|
export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml
|
||||||
|
cd .github/gallery-agent
|
||||||
|
./gallery-agent
|
||||||
|
rm -rf gallery-agent
|
||||||
|
|
||||||
|
- name: Check for changes
|
||||||
|
id: check_changes
|
||||||
|
run: |
|
||||||
|
if git diff --quiet gallery/index.yaml; then
|
||||||
|
echo "changes=false" >> $GITHUB_OUTPUT
|
||||||
|
echo "No changes detected in gallery/index.yaml"
|
||||||
|
else
|
||||||
|
echo "changes=true" >> $GITHUB_OUTPUT
|
||||||
|
echo "Changes detected in gallery/index.yaml"
|
||||||
|
git diff gallery/index.yaml
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Read gallery agent summary
|
||||||
|
id: read_summary
|
||||||
|
if: steps.check_changes.outputs.changes == 'true'
|
||||||
|
run: |
|
||||||
|
if [ -f ".github/gallery-agent/gallery-agent-summary.json" ]; then
|
||||||
|
echo "summary_exists=true" >> $GITHUB_OUTPUT
|
||||||
|
# Extract summary data using jq
|
||||||
|
echo "search_term=$(jq -r '.search_term' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT
|
||||||
|
echo "total_found=$(jq -r '.total_found' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT
|
||||||
|
echo "models_added=$(jq -r '.models_added' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT
|
||||||
|
echo "quantization=$(jq -r '.quantization' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT
|
||||||
|
echo "processing_time=$(jq -r '.processing_time' .github/gallery-agent/gallery-agent-summary.json)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
# Create a formatted list of added models with URLs
|
||||||
|
added_models=$(jq -r 'range(0; .added_model_ids | length) as $i | "- [\(.added_model_ids[$i])](\(.added_model_urls[$i]))"' .github/gallery-agent/gallery-agent-summary.json | tr '\n' '\n')
|
||||||
|
echo "added_models<<EOF" >> $GITHUB_OUTPUT
|
||||||
|
echo "$added_models" >> $GITHUB_OUTPUT
|
||||||
|
echo "EOF" >> $GITHUB_OUTPUT
|
||||||
|
rm -f .github/gallery-agent/gallery-agent-summary.json
|
||||||
|
else
|
||||||
|
echo "summary_exists=false" >> $GITHUB_OUTPUT
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Create Pull Request
|
||||||
|
if: steps.check_changes.outputs.changes == 'true'
|
||||||
|
uses: peter-evans/create-pull-request@v7
|
||||||
|
with:
|
||||||
|
token: ${{ secrets.UPDATE_BOT_TOKEN }}
|
||||||
|
push-to-fork: ci-forks/LocalAI
|
||||||
|
commit-message: 'chore(model gallery): :robot: add new models via gallery agent'
|
||||||
|
title: 'chore(model gallery): :robot: add ${{ steps.read_summary.outputs.models_added || 0 }} new models via gallery agent'
|
||||||
|
# Branch has to be unique so PRs are not overriding each other
|
||||||
|
branch-suffix: timestamp
|
||||||
|
body: |
|
||||||
|
This PR was automatically created by the gallery agent workflow.
|
||||||
|
|
||||||
|
**Summary:**
|
||||||
|
- **Search Term:** ${{ steps.read_summary.outputs.search_term || github.event.inputs.search_term || 'GGUF' }}
|
||||||
|
- **Models Found:** ${{ steps.read_summary.outputs.total_found || 'N/A' }}
|
||||||
|
- **Models Added:** ${{ steps.read_summary.outputs.models_added || '0' }}
|
||||||
|
- **Quantization:** ${{ steps.read_summary.outputs.quantization || github.event.inputs.quantization || 'Q4_K_M' }}
|
||||||
|
- **Processing Time:** ${{ steps.read_summary.outputs.processing_time || 'N/A' }}
|
||||||
|
|
||||||
|
**Added Models:**
|
||||||
|
${{ steps.read_summary.outputs.added_models || '- No models added' }}
|
||||||
|
|
||||||
|
**Workflow Details:**
|
||||||
|
- Triggered by: `${{ github.event_name }}`
|
||||||
|
- Run ID: `${{ github.run_id }}`
|
||||||
|
- Commit: `${{ github.sha }}`
|
||||||
|
signoff: true
|
||||||
|
delete-branch: true
|
||||||
2
.github/workflows/labeler.yml
vendored
2
.github/workflows/labeler.yml
vendored
@@ -9,4 +9,4 @@ jobs:
|
|||||||
pull-requests: write
|
pull-requests: write
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/labeler@v5
|
- uses: actions/labeler@v6
|
||||||
5
.github/workflows/localaibot_automerge.yml
vendored
5
.github/workflows/localaibot_automerge.yml
vendored
@@ -6,11 +6,12 @@ permissions:
|
|||||||
contents: write
|
contents: write
|
||||||
pull-requests: write
|
pull-requests: write
|
||||||
packages: read
|
packages: read
|
||||||
|
issues: write # for Homebrew/actions/post-comment
|
||||||
|
actions: write # to dispatch publish workflow
|
||||||
jobs:
|
jobs:
|
||||||
dependabot:
|
dependabot:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: ${{ github.actor == 'localai-bot' }}
|
if: ${{ github.actor == 'localai-bot' && !contains(github.event.pull_request.title, 'chore(model gallery):') }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v5
|
||||||
|
|||||||
18
.github/workflows/notify-models.yaml
vendored
18
.github/workflows/notify-models.yaml
vendored
@@ -1,22 +1,27 @@
|
|||||||
name: Notifications for new models
|
name: Notifications for new models
|
||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request_target:
|
||||||
types:
|
types:
|
||||||
- closed
|
- closed
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: read
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
notify-discord:
|
notify-discord:
|
||||||
if: ${{ (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) }}
|
if: ${{ (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) }}
|
||||||
env:
|
env:
|
||||||
MODEL_NAME: gemma-3-12b-it
|
MODEL_NAME: gemma-3-12b-it-qat
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # needed to checkout all branches for this Action to work
|
fetch-depth: 0 # needed to checkout all branches for this Action to work
|
||||||
|
ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes
|
||||||
- uses: mudler/localai-github-action@v1
|
- uses: mudler/localai-github-action@v1
|
||||||
with:
|
with:
|
||||||
model: 'gemma-3-12b-it' # Any from models.localai.io, or from huggingface.com with: "huggingface://<repository>/file"
|
model: 'gemma-3-12b-it-qat' # Any from models.localai.io, or from huggingface.com with: "huggingface://<repository>/file"
|
||||||
# Check the PR diff using the current branch and the base branch of the PR
|
# Check the PR diff using the current branch and the base branch of the PR
|
||||||
- uses: GrantBirki/git-diff-action@v2.8.1
|
- uses: GrantBirki/git-diff-action@v2.8.1
|
||||||
id: git-diff-action
|
id: git-diff-action
|
||||||
@@ -79,7 +84,7 @@ jobs:
|
|||||||
args: ${{ steps.summarize.outputs.message }}
|
args: ${{ steps.summarize.outputs.message }}
|
||||||
- name: Setup tmate session if fails
|
- name: Setup tmate session if fails
|
||||||
if: ${{ failure() }}
|
if: ${{ failure() }}
|
||||||
uses: mxschmitt/action-tmate@v3.22
|
uses: mxschmitt/action-tmate@v3.23
|
||||||
with:
|
with:
|
||||||
detached: true
|
detached: true
|
||||||
connect-timeout-seconds: 180
|
connect-timeout-seconds: 180
|
||||||
@@ -87,12 +92,13 @@ jobs:
|
|||||||
notify-twitter:
|
notify-twitter:
|
||||||
if: ${{ (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) }}
|
if: ${{ (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) }}
|
||||||
env:
|
env:
|
||||||
MODEL_NAME: gemma-3-12b-it
|
MODEL_NAME: gemma-3-12b-it-qat
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v5
|
- uses: actions/checkout@v5
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # needed to checkout all branches for this Action to work
|
fetch-depth: 0 # needed to checkout all branches for this Action to work
|
||||||
|
ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes
|
||||||
- name: Start LocalAI
|
- name: Start LocalAI
|
||||||
run: |
|
run: |
|
||||||
echo "Starting LocalAI..."
|
echo "Starting LocalAI..."
|
||||||
@@ -161,7 +167,7 @@ jobs:
|
|||||||
TWITTER_ACCESS_TOKEN_SECRET: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }}
|
TWITTER_ACCESS_TOKEN_SECRET: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }}
|
||||||
- name: Setup tmate session if fails
|
- name: Setup tmate session if fails
|
||||||
if: ${{ failure() }}
|
if: ${{ failure() }}
|
||||||
uses: mxschmitt/action-tmate@v3.22
|
uses: mxschmitt/action-tmate@v3.23
|
||||||
with:
|
with:
|
||||||
detached: true
|
detached: true
|
||||||
connect-timeout-seconds: 180
|
connect-timeout-seconds: 180
|
||||||
|
|||||||
3
.github/workflows/notify-releases.yaml
vendored
3
.github/workflows/notify-releases.yaml
vendored
@@ -11,10 +11,11 @@ jobs:
|
|||||||
RELEASE_BODY: ${{ github.event.release.body }}
|
RELEASE_BODY: ${{ github.event.release.body }}
|
||||||
RELEASE_TITLE: ${{ github.event.release.name }}
|
RELEASE_TITLE: ${{ github.event.release.name }}
|
||||||
RELEASE_TAG_NAME: ${{ github.event.release.tag_name }}
|
RELEASE_TAG_NAME: ${{ github.event.release.tag_name }}
|
||||||
|
MODEL_NAME: gemma-3-12b-it-qat
|
||||||
steps:
|
steps:
|
||||||
- uses: mudler/localai-github-action@v1
|
- uses: mudler/localai-github-action@v1
|
||||||
with:
|
with:
|
||||||
model: 'gemma-3-12b-it' # Any from models.localai.io, or from huggingface.com with: "huggingface://<repository>/file"
|
model: 'gemma-3-12b-it-qat' # Any from models.localai.io, or from huggingface.com with: "huggingface://<repository>/file"
|
||||||
- name: Summarize
|
- name: Summarize
|
||||||
id: summarize
|
id: summarize
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
4
.github/workflows/release.yaml
vendored
4
.github/workflows/release.yaml
vendored
@@ -41,7 +41,7 @@ jobs:
|
|||||||
- name: Upload DMG to Release
|
- name: Upload DMG to Release
|
||||||
uses: softprops/action-gh-release@v2
|
uses: softprops/action-gh-release@v2
|
||||||
with:
|
with:
|
||||||
files: ./dist/LocalAI-Launcher.dmg
|
files: ./dist/LocalAI.dmg
|
||||||
launcher-build-linux:
|
launcher-build-linux:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
@@ -61,4 +61,4 @@ jobs:
|
|||||||
- name: Upload Linux launcher artifacts
|
- name: Upload Linux launcher artifacts
|
||||||
uses: softprops/action-gh-release@v2
|
uses: softprops/action-gh-release@v2
|
||||||
with:
|
with:
|
||||||
files: ./local-ai-launcher-linux.tar.xz
|
files: ./local-ai-launcher-linux.tar.xz
|
||||||
|
|||||||
4
.github/workflows/secscan.yaml
vendored
4
.github/workflows/secscan.yaml
vendored
@@ -18,13 +18,13 @@ jobs:
|
|||||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||||
- name: Run Gosec Security Scanner
|
- name: Run Gosec Security Scanner
|
||||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||||
uses: securego/gosec@v2.22.8
|
uses: securego/gosec@v2.22.9
|
||||||
with:
|
with:
|
||||||
# we let the report trigger content trigger a failure using the GitHub Security features.
|
# we let the report trigger content trigger a failure using the GitHub Security features.
|
||||||
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
args: '-no-fail -fmt sarif -out results.sarif ./...'
|
||||||
- name: Upload SARIF file
|
- name: Upload SARIF file
|
||||||
if: ${{ github.actor != 'dependabot[bot]' }}
|
if: ${{ github.actor != 'dependabot[bot]' }}
|
||||||
uses: github/codeql-action/upload-sarif@v3
|
uses: github/codeql-action/upload-sarif@v4
|
||||||
with:
|
with:
|
||||||
# Path to SARIF file relative to the root of the repository
|
# Path to SARIF file relative to the root of the repository
|
||||||
sarif_file: results.sarif
|
sarif_file: results.sarif
|
||||||
|
|||||||
2
.github/workflows/stalebot.yml
vendored
2
.github/workflows/stalebot.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
|||||||
stale:
|
stale:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9
|
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # v9
|
||||||
with:
|
with:
|
||||||
stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
|
stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.'
|
||||||
stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'
|
stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.'
|
||||||
|
|||||||
10
.github/workflows/test.yml
vendored
10
.github/workflows/test.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: ['1.21.x']
|
go-version: ['1.25.x']
|
||||||
steps:
|
steps:
|
||||||
- name: Free Disk Space (Ubuntu)
|
- name: Free Disk Space (Ubuntu)
|
||||||
uses: jlumbroso/free-disk-space@main
|
uses: jlumbroso/free-disk-space@main
|
||||||
@@ -124,7 +124,7 @@ jobs:
|
|||||||
PATH="$PATH:/root/go/bin" GO_TAGS="tts" make --jobs 5 --output-sync=target test
|
PATH="$PATH:/root/go/bin" GO_TAGS="tts" make --jobs 5 --output-sync=target test
|
||||||
- name: Setup tmate session if tests fail
|
- name: Setup tmate session if tests fail
|
||||||
if: ${{ failure() }}
|
if: ${{ failure() }}
|
||||||
uses: mxschmitt/action-tmate@v3.22
|
uses: mxschmitt/action-tmate@v3.23
|
||||||
with:
|
with:
|
||||||
detached: true
|
detached: true
|
||||||
connect-timeout-seconds: 180
|
connect-timeout-seconds: 180
|
||||||
@@ -183,7 +183,7 @@ jobs:
|
|||||||
PATH="$PATH:$HOME/go/bin" make backends/local-store backends/silero-vad backends/llama-cpp backends/whisper backends/piper backends/stablediffusion-ggml docker-build-aio e2e-aio
|
PATH="$PATH:$HOME/go/bin" make backends/local-store backends/silero-vad backends/llama-cpp backends/whisper backends/piper backends/stablediffusion-ggml docker-build-aio e2e-aio
|
||||||
- name: Setup tmate session if tests fail
|
- name: Setup tmate session if tests fail
|
||||||
if: ${{ failure() }}
|
if: ${{ failure() }}
|
||||||
uses: mxschmitt/action-tmate@v3.22
|
uses: mxschmitt/action-tmate@v3.23
|
||||||
with:
|
with:
|
||||||
detached: true
|
detached: true
|
||||||
connect-timeout-seconds: 180
|
connect-timeout-seconds: 180
|
||||||
@@ -193,7 +193,7 @@ jobs:
|
|||||||
runs-on: macOS-14
|
runs-on: macOS-14
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: ['1.21.x']
|
go-version: ['1.25.x']
|
||||||
steps:
|
steps:
|
||||||
- name: Clone
|
- name: Clone
|
||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v5
|
||||||
@@ -226,7 +226,7 @@ jobs:
|
|||||||
PATH="$PATH:$HOME/go/bin" BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DGGML_F16C=OFF -DGGML_AVX512=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF" make --jobs 4 --output-sync=target test
|
PATH="$PATH:$HOME/go/bin" BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DGGML_F16C=OFF -DGGML_AVX512=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF" make --jobs 4 --output-sync=target test
|
||||||
- name: Setup tmate session if tests fail
|
- name: Setup tmate session if tests fail
|
||||||
if: ${{ failure() }}
|
if: ${{ failure() }}
|
||||||
uses: mxschmitt/action-tmate@v3.22
|
uses: mxschmitt/action-tmate@v3.23
|
||||||
with:
|
with:
|
||||||
detached: true
|
detached: true
|
||||||
connect-timeout-seconds: 180
|
connect-timeout-seconds: 180
|
||||||
|
|||||||
10
Dockerfile
10
Dockerfile
@@ -78,6 +78,16 @@ RUN <<EOT bash
|
|||||||
fi
|
fi
|
||||||
EOT
|
EOT
|
||||||
|
|
||||||
|
# https://github.com/NVIDIA/Isaac-GR00T/issues/343
|
||||||
|
RUN <<EOT bash
|
||||||
|
if [ "${BUILD_TYPE}" = "cublas" ] && [ "${TARGETARCH}" = "arm64" ]; then
|
||||||
|
wget https://developer.download.nvidia.com/compute/cudss/0.6.0/local_installers/cudss-local-tegra-repo-ubuntu2204-0.6.0_0.6.0-1_arm64.deb && \
|
||||||
|
dpkg -i cudss-local-tegra-repo-ubuntu2204-0.6.0_0.6.0-1_arm64.deb && \
|
||||||
|
cp /var/cudss-local-tegra-repo-ubuntu2204-0.6.0/cudss-*-keyring.gpg /usr/share/keyrings/ && \
|
||||||
|
apt-get update && apt-get -y install cudss
|
||||||
|
fi
|
||||||
|
EOT
|
||||||
|
|
||||||
# If we are building with clblas support, we need the libraries for the builds
|
# If we are building with clblas support, we need the libraries for the builds
|
||||||
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
RUN if [ "${BUILD_TYPE}" = "clblas" ] && [ "${SKIP_DRIVERS}" = "false" ]; then \
|
||||||
apt-get update && \
|
apt-get update && \
|
||||||
|
|||||||
21
Makefile
21
Makefile
@@ -117,8 +117,8 @@ run: ## run local-ai
|
|||||||
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
|
||||||
|
|
||||||
test-models/testmodel.ggml:
|
test-models/testmodel.ggml:
|
||||||
mkdir test-models
|
mkdir -p test-models
|
||||||
mkdir test-dir
|
mkdir -p test-dir
|
||||||
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
|
wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml
|
||||||
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
|
wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en
|
||||||
wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert
|
wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert
|
||||||
@@ -369,10 +369,16 @@ backends/kitten-tts: docker-build-kitten-tts docker-save-kitten-tts build
|
|||||||
backends/kokoro: docker-build-kokoro docker-save-kokoro build
|
backends/kokoro: docker-build-kokoro docker-save-kokoro build
|
||||||
./local-ai backends install "ocifile://$(abspath ./backend-images/kokoro.tar)"
|
./local-ai backends install "ocifile://$(abspath ./backend-images/kokoro.tar)"
|
||||||
|
|
||||||
|
backends/chatterbox: docker-build-chatterbox docker-save-chatterbox build
|
||||||
|
./local-ai backends install "ocifile://$(abspath ./backend-images/chatterbox.tar)"
|
||||||
|
|
||||||
backends/llama-cpp-darwin: build
|
backends/llama-cpp-darwin: build
|
||||||
bash ./scripts/build/llama-cpp-darwin.sh
|
bash ./scripts/build/llama-cpp-darwin.sh
|
||||||
./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)"
|
./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)"
|
||||||
|
|
||||||
|
backends/neutts: docker-build-neutts docker-save-neutts build
|
||||||
|
./local-ai backends install "ocifile://$(abspath ./backend-images/neutts.tar)"
|
||||||
|
|
||||||
build-darwin-python-backend: build
|
build-darwin-python-backend: build
|
||||||
bash ./scripts/build/python-darwin.sh
|
bash ./scripts/build/python-darwin.sh
|
||||||
|
|
||||||
@@ -426,6 +432,15 @@ docker-build-kitten-tts:
|
|||||||
docker-save-kitten-tts: backend-images
|
docker-save-kitten-tts: backend-images
|
||||||
docker save local-ai-backend:kitten-tts -o backend-images/kitten-tts.tar
|
docker save local-ai-backend:kitten-tts -o backend-images/kitten-tts.tar
|
||||||
|
|
||||||
|
docker-save-chatterbox: backend-images
|
||||||
|
docker save local-ai-backend:chatterbox -o backend-images/chatterbox.tar
|
||||||
|
|
||||||
|
docker-build-neutts:
|
||||||
|
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:neutts -f backend/Dockerfile.python --build-arg BACKEND=neutts ./backend
|
||||||
|
|
||||||
|
docker-save-neutts: backend-images
|
||||||
|
docker save local-ai-backend:neutts -o backend-images/neutts.tar
|
||||||
|
|
||||||
docker-build-kokoro:
|
docker-build-kokoro:
|
||||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:kokoro -f backend/Dockerfile.python --build-arg BACKEND=kokoro ./backend
|
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:kokoro -f backend/Dockerfile.python --build-arg BACKEND=kokoro ./backend
|
||||||
|
|
||||||
@@ -493,7 +508,7 @@ docker-build-bark:
|
|||||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:bark -f backend/Dockerfile.python --build-arg BACKEND=bark .
|
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:bark -f backend/Dockerfile.python --build-arg BACKEND=bark .
|
||||||
|
|
||||||
docker-build-chatterbox:
|
docker-build-chatterbox:
|
||||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox .
|
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:chatterbox -f backend/Dockerfile.python --build-arg BACKEND=chatterbox ./backend
|
||||||
|
|
||||||
docker-build-exllama2:
|
docker-build-exllama2:
|
||||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:exllama2 -f backend/Dockerfile.python --build-arg BACKEND=exllama2 .
|
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:exllama2 -f backend/Dockerfile.python --build-arg BACKEND=exllama2 .
|
||||||
|
|||||||
28
README.md
28
README.md
@@ -43,7 +43,7 @@
|
|||||||
|
|
||||||
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
|
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
|
||||||
>
|
>
|
||||||
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🥽 Demo](https://demo.localai.io) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
|
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
|
||||||
[](https://t.me/localaiofficial_bot)
|
[](https://t.me/localaiofficial_bot)
|
||||||
|
|
||||||
[](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[](https://artifacthub.io/packages/search?repo=localai)
|
[](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[](https://artifacthub.io/packages/search?repo=localai)
|
||||||
@@ -110,8 +110,21 @@ curl https://localai.io/install.sh | sh
|
|||||||
|
|
||||||
For more installation options, see [Installer Options](https://localai.io/docs/advanced/installer/).
|
For more installation options, see [Installer Options](https://localai.io/docs/advanced/installer/).
|
||||||
|
|
||||||
|
### macOS Download:
|
||||||
|
|
||||||
|
<a href="https://github.com/mudler/LocalAI/releases/latest/download/LocalAI.dmg">
|
||||||
|
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
|
||||||
|
</a>
|
||||||
|
|
||||||
Or run with docker:
|
Or run with docker:
|
||||||
|
|
||||||
|
> **💡 Docker Run vs Docker Start**
|
||||||
|
>
|
||||||
|
> - `docker run` creates and starts a new container. If a container with the same name already exists, this command will fail.
|
||||||
|
> - `docker start` starts an existing container that was previously created with `docker run`.
|
||||||
|
>
|
||||||
|
> If you've already run LocalAI before and want to start it again, use: `docker start -i local-ai`
|
||||||
|
|
||||||
### CPU only image:
|
### CPU only image:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -191,6 +204,8 @@ For more information, see [💻 Getting started](https://localai.io/basics/getti
|
|||||||
|
|
||||||
## 📰 Latest project news
|
## 📰 Latest project news
|
||||||
|
|
||||||
|
- October 2025: 🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support added for agentic capabilities with external tools
|
||||||
|
- September 2025: New Launcher application for MacOS and Linux, extended support to many backends for Mac and Nvidia L4T devices. Models: Added MLX-Audio, WAN 2.2. WebUI improvements and Python-based backends now ships portable python environments.
|
||||||
- August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips ( with `development` suffix in the gallery ): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060
|
- August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips ( with `development` suffix in the gallery ): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060
|
||||||
- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
|
- July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr)
|
||||||
- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
|
- July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0)
|
||||||
@@ -229,7 +244,7 @@ Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3A
|
|||||||
- 🔍 [Object Detection](https://localai.io/features/object-detection/)
|
- 🔍 [Object Detection](https://localai.io/features/object-detection/)
|
||||||
- 📈 [Reranker API](https://localai.io/features/reranker/)
|
- 📈 [Reranker API](https://localai.io/features/reranker/)
|
||||||
- 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
|
- 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/)
|
||||||
- [Agentic capabilities](https://github.com/mudler/LocalAGI)
|
- 🆕🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - Agentic capabilities with external tools and [LocalAGI's Agentic capabilities](https://github.com/mudler/LocalAGI)
|
||||||
- 🔊 Voice activity detection (Silero-VAD support)
|
- 🔊 Voice activity detection (Silero-VAD support)
|
||||||
- 🌍 Integrated WebUI!
|
- 🌍 Integrated WebUI!
|
||||||
|
|
||||||
@@ -260,6 +275,7 @@ LocalAI supports a comprehensive range of AI backends with multiple acceleration
|
|||||||
| **piper** | Fast neural TTS system | CPU |
|
| **piper** | Fast neural TTS system | CPU |
|
||||||
| **kitten-tts** | Kitten TTS models | CPU |
|
| **kitten-tts** | Kitten TTS models | CPU |
|
||||||
| **silero-vad** | Voice Activity Detection | CPU |
|
| **silero-vad** | Voice Activity Detection | CPU |
|
||||||
|
| **neutts** | Text-to-speech with voice cloning | CUDA 12, ROCm, CPU |
|
||||||
|
|
||||||
### Image & Video Generation
|
### Image & Video Generation
|
||||||
| Backend | Description | Acceleration Support |
|
| Backend | Description | Acceleration Support |
|
||||||
@@ -281,7 +297,7 @@ LocalAI supports a comprehensive range of AI backends with multiple acceleration
|
|||||||
|-------------------|-------------------|------------------|
|
|-------------------|-------------------|------------------|
|
||||||
| **NVIDIA CUDA 11** | llama.cpp, whisper, stablediffusion, diffusers, rerankers, bark, chatterbox | Nvidia hardware |
|
| **NVIDIA CUDA 11** | llama.cpp, whisper, stablediffusion, diffusers, rerankers, bark, chatterbox | Nvidia hardware |
|
||||||
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
|
| **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware |
|
||||||
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark | AMD Graphics |
|
| **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, bark, neutts | AMD Graphics |
|
||||||
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark | Intel Arc, Intel iGPUs |
|
| **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, exllama2, coqui, kokoro, bark | Intel Arc, Intel iGPUs |
|
||||||
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, bark-cpp | Apple M1/M2/M3+ |
|
| **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, bark-cpp | Apple M1/M2/M3+ |
|
||||||
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
|
| **Vulkan** | llama.cpp, whisper, stablediffusion | Cross-platform GPUs |
|
||||||
@@ -298,6 +314,12 @@ WebUIs:
|
|||||||
- https://github.com/go-skynet/LocalAI-frontend
|
- https://github.com/go-skynet/LocalAI-frontend
|
||||||
- QA-Pilot(An interactive chat project that leverages LocalAI LLMs for rapid understanding and navigation of GitHub code repository) https://github.com/reid41/QA-Pilot
|
- QA-Pilot(An interactive chat project that leverages LocalAI LLMs for rapid understanding and navigation of GitHub code repository) https://github.com/reid41/QA-Pilot
|
||||||
|
|
||||||
|
Agentic Libraries:
|
||||||
|
- https://github.com/mudler/cogito
|
||||||
|
|
||||||
|
MCPs:
|
||||||
|
- https://github.com/mudler/MCPs
|
||||||
|
|
||||||
Model galleries
|
Model galleries
|
||||||
- https://github.com/go-skynet/model-gallery
|
- https://github.com/go-skynet/model-gallery
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ context_size: 4096
|
|||||||
f16: true
|
f16: true
|
||||||
backend: llama-cpp
|
backend: llama-cpp
|
||||||
mmap: true
|
mmap: true
|
||||||
mmproj: minicpm-v-2_6-mmproj-f16.gguf
|
mmproj: minicpm-v-4_5-mmproj-f16.gguf
|
||||||
name: gpt-4o
|
name: gpt-4o
|
||||||
parameters:
|
parameters:
|
||||||
model: minicpm-v-2_6-Q4_K_M.gguf
|
model: minicpm-v-4_5-Q4_K_M.gguf
|
||||||
stopwords:
|
stopwords:
|
||||||
- <|im_end|>
|
- <|im_end|>
|
||||||
- <dummy32000>
|
- <dummy32000>
|
||||||
@@ -42,9 +42,9 @@ template:
|
|||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
|
|
||||||
download_files:
|
download_files:
|
||||||
- filename: minicpm-v-2_6-Q4_K_M.gguf
|
- filename: minicpm-v-4_5-Q4_K_M.gguf
|
||||||
sha256: 3a4078d53b46f22989adbf998ce5a3fd090b6541f112d7e936eb4204a04100b1
|
sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
|
||||||
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/ggml-model-Q4_K_M.gguf
|
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
|
||||||
- filename: minicpm-v-2_6-mmproj-f16.gguf
|
- filename: minicpm-v-4_5-mmproj-f16.gguf
|
||||||
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/mmproj-model-f16.gguf
|
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
|
||||||
sha256: 4485f68a0f1aa404c391e788ea88ea653c100d8e98fe572698f701e5809711fd
|
sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
|
||||||
@@ -2,10 +2,10 @@ context_size: 4096
|
|||||||
backend: llama-cpp
|
backend: llama-cpp
|
||||||
f16: true
|
f16: true
|
||||||
mmap: true
|
mmap: true
|
||||||
mmproj: minicpm-v-2_6-mmproj-f16.gguf
|
mmproj: minicpm-v-4_5-mmproj-f16.gguf
|
||||||
name: gpt-4o
|
name: gpt-4o
|
||||||
parameters:
|
parameters:
|
||||||
model: minicpm-v-2_6-Q4_K_M.gguf
|
model: minicpm-v-4_5-Q4_K_M.gguf
|
||||||
stopwords:
|
stopwords:
|
||||||
- <|im_end|>
|
- <|im_end|>
|
||||||
- <dummy32000>
|
- <dummy32000>
|
||||||
@@ -42,9 +42,9 @@ template:
|
|||||||
<|im_start|>assistant
|
<|im_start|>assistant
|
||||||
|
|
||||||
download_files:
|
download_files:
|
||||||
- filename: minicpm-v-2_6-Q4_K_M.gguf
|
- filename: minicpm-v-4_5-Q4_K_M.gguf
|
||||||
sha256: 3a4078d53b46f22989adbf998ce5a3fd090b6541f112d7e936eb4204a04100b1
|
sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
|
||||||
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/ggml-model-Q4_K_M.gguf
|
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
|
||||||
- filename: minicpm-v-2_6-mmproj-f16.gguf
|
- filename: minicpm-v-4_5-mmproj-f16.gguf
|
||||||
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/mmproj-model-f16.gguf
|
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
|
||||||
sha256: 4485f68a0f1aa404c391e788ea88ea653c100d8e98fe572698f701e5809711fd
|
sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
|
||||||
@@ -2,10 +2,10 @@ context_size: 4096
|
|||||||
backend: llama-cpp
|
backend: llama-cpp
|
||||||
f16: true
|
f16: true
|
||||||
mmap: true
|
mmap: true
|
||||||
mmproj: minicpm-v-2_6-mmproj-f16.gguf
|
mmproj: minicpm-v-4_5-mmproj-f16.gguf
|
||||||
name: gpt-4o
|
name: gpt-4o
|
||||||
parameters:
|
parameters:
|
||||||
model: minicpm-v-2_6-Q4_K_M.gguf
|
model: minicpm-v-4_5-Q4_K_M.gguf
|
||||||
stopwords:
|
stopwords:
|
||||||
- <|im_end|>
|
- <|im_end|>
|
||||||
- <dummy32000>
|
- <dummy32000>
|
||||||
@@ -43,9 +43,9 @@ template:
|
|||||||
|
|
||||||
|
|
||||||
download_files:
|
download_files:
|
||||||
- filename: minicpm-v-2_6-Q4_K_M.gguf
|
- filename: minicpm-v-4_5-Q4_K_M.gguf
|
||||||
sha256: 3a4078d53b46f22989adbf998ce5a3fd090b6541f112d7e936eb4204a04100b1
|
sha256: c1c3c33100b15b4caf7319acce4e23c0eb0ce1cbd12f70e8d24f05aa67b7512f
|
||||||
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/ggml-model-Q4_K_M.gguf
|
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/ggml-model-Q4_K_M.gguf
|
||||||
- filename: minicpm-v-2_6-mmproj-f16.gguf
|
- filename: minicpm-v-4_5-mmproj-f16.gguf
|
||||||
uri: huggingface://openbmb/MiniCPM-V-2_6-gguf/mmproj-model-f16.gguf
|
uri: huggingface://openbmb/MiniCPM-V-4_5-gguf/mmproj-model-f16.gguf
|
||||||
sha256: 4485f68a0f1aa404c391e788ea88ea653c100d8e98fe572698f701e5809711fd
|
sha256: 7a7225a32e8d453aaa3d22d8c579b5bf833c253f784cdb05c99c9a76fd616df8
|
||||||
@@ -197,7 +197,7 @@ EOT
|
|||||||
|
|
||||||
|
|
||||||
# Copy libraries using a script to handle architecture differences
|
# Copy libraries using a script to handle architecture differences
|
||||||
RUN make -C /LocalAI/backend/cpp/llama-cpp package
|
RUN make -BC /LocalAI/backend/cpp/llama-cpp package
|
||||||
|
|
||||||
|
|
||||||
FROM scratch
|
FROM scratch
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ RUN apt-get update && \
|
|||||||
curl python3-pip \
|
curl python3-pip \
|
||||||
python-is-python3 \
|
python-is-python3 \
|
||||||
python3-dev llvm \
|
python3-dev llvm \
|
||||||
python3-venv make && \
|
python3-venv make cmake && \
|
||||||
apt-get clean && \
|
apt-get clean && \
|
||||||
rm -rf /var/lib/apt/lists/* && \
|
rm -rf /var/lib/apt/lists/* && \
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
|
|||||||
@@ -276,6 +276,7 @@ message TranscriptRequest {
|
|||||||
string language = 3;
|
string language = 3;
|
||||||
uint32 threads = 4;
|
uint32 threads = 4;
|
||||||
bool translate = 5;
|
bool translate = 5;
|
||||||
|
bool diarize = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TranscriptResult {
|
message TranscriptResult {
|
||||||
@@ -305,7 +306,7 @@ message GenerateImageRequest {
|
|||||||
// Diffusers
|
// Diffusers
|
||||||
string EnableParameters = 10;
|
string EnableParameters = 10;
|
||||||
int32 CLIPSkip = 11;
|
int32 CLIPSkip = 11;
|
||||||
|
|
||||||
// Reference images for models that support them (e.g., Flux Kontext)
|
// Reference images for models that support them (e.g., Flux Kontext)
|
||||||
repeated string ref_images = 12;
|
repeated string ref_images = 12;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
LLAMA_VERSION?=3de008208b9b8a33f49f979097a99b4d59e6e521
|
LLAMA_VERSION?=5a4ff43e7dd049e35942bc3d12361dab2f155544
|
||||||
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
|
||||||
|
|
||||||
CMAKE_ARGS?=
|
CMAKE_ARGS?=
|
||||||
@@ -14,7 +14,7 @@ CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=OFF
|
|||||||
|
|
||||||
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
|
||||||
ifeq ($(NATIVE),false)
|
ifeq ($(NATIVE),false)
|
||||||
CMAKE_ARGS+=-DGGML_NATIVE=OFF
|
CMAKE_ARGS+=-DGGML_NATIVE=OFF -DLLAMA_OPENSSL=OFF
|
||||||
endif
|
endif
|
||||||
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
|
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
|
||||||
ifeq ($(BUILD_TYPE),cublas)
|
ifeq ($(BUILD_TYPE),cublas)
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ static void start_llama_server(server_context& ctx_server) {
|
|||||||
ctx_server.queue_tasks.start_loop();
|
ctx_server.queue_tasks.start_loop();
|
||||||
}
|
}
|
||||||
|
|
||||||
json parse_options(bool streaming, const backend::PredictOptions* predict)
|
json parse_options(bool streaming, const backend::PredictOptions* predict, const server_context& ctx_server)
|
||||||
{
|
{
|
||||||
|
|
||||||
// Create now a json data from the prediction options instead
|
// Create now a json data from the prediction options instead
|
||||||
@@ -147,6 +147,28 @@ json parse_options(bool streaming, const backend::PredictOptions* predict)
|
|||||||
// data["n_probs"] = predict->nprobs();
|
// data["n_probs"] = predict->nprobs();
|
||||||
//TODO: images,
|
//TODO: images,
|
||||||
|
|
||||||
|
// Serialize grammar triggers from server context to JSON array
|
||||||
|
if (!ctx_server.params_base.sampling.grammar_triggers.empty()) {
|
||||||
|
json grammar_triggers = json::array();
|
||||||
|
for (const auto& trigger : ctx_server.params_base.sampling.grammar_triggers) {
|
||||||
|
json trigger_json;
|
||||||
|
trigger_json["value"] = trigger.value;
|
||||||
|
// Always serialize as WORD type since upstream converts WORD to TOKEN internally
|
||||||
|
trigger_json["type"] = static_cast<int>(COMMON_GRAMMAR_TRIGGER_TYPE_WORD);
|
||||||
|
grammar_triggers.push_back(trigger_json);
|
||||||
|
}
|
||||||
|
data["grammar_triggers"] = grammar_triggers;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize preserved tokens from server context to JSON array
|
||||||
|
if (!ctx_server.params_base.sampling.preserved_tokens.empty()) {
|
||||||
|
json preserved_tokens = json::array();
|
||||||
|
for (const auto& token : ctx_server.params_base.sampling.preserved_tokens) {
|
||||||
|
preserved_tokens.push_back(common_token_to_piece(ctx_server.ctx, token));
|
||||||
|
}
|
||||||
|
data["preserved_tokens"] = preserved_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,7 +229,7 @@ static void add_rpc_devices(std::string servers) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void params_parse(const backend::ModelOptions* request,
|
static void params_parse(server_context& ctx_server, const backend::ModelOptions* request,
|
||||||
common_params & params) {
|
common_params & params) {
|
||||||
|
|
||||||
// this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
|
// this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809
|
||||||
@@ -231,6 +253,7 @@ static void params_parse(const backend::ModelOptions* request,
|
|||||||
params.cpuparams.n_threads = request->threads();
|
params.cpuparams.n_threads = request->threads();
|
||||||
params.n_gpu_layers = request->ngpulayers();
|
params.n_gpu_layers = request->ngpulayers();
|
||||||
params.n_batch = request->nbatch();
|
params.n_batch = request->nbatch();
|
||||||
|
params.n_ubatch = request->nbatch(); // fixes issue with reranking models being limited to 512 tokens (the default n_ubatch size); allows for setting the maximum input amount of tokens thereby avoiding this error "input is too large to process. increase the physical batch size"
|
||||||
// Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1
|
// Set params.n_parallel by environment variable (LLAMA_PARALLEL), defaults to 1
|
||||||
//params.n_parallel = 1;
|
//params.n_parallel = 1;
|
||||||
const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
|
const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
|
||||||
@@ -268,6 +291,11 @@ static void params_parse(const backend::ModelOptions* request,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!params.kv_overrides.empty()) {
|
||||||
|
params.kv_overrides.emplace_back();
|
||||||
|
params.kv_overrides.back().key[0] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Add yarn
|
// TODO: Add yarn
|
||||||
|
|
||||||
if (!request->tensorsplit().empty()) {
|
if (!request->tensorsplit().empty()) {
|
||||||
@@ -346,14 +374,14 @@ static void params_parse(const backend::ModelOptions* request,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (request->grammartriggers_size() > 0) {
|
if (request->grammartriggers_size() > 0) {
|
||||||
params.sampling.grammar_lazy = true;
|
//params.sampling.grammar_lazy = true;
|
||||||
|
// Store grammar trigger words for processing after model is loaded
|
||||||
for (int i = 0; i < request->grammartriggers_size(); i++) {
|
for (int i = 0; i < request->grammartriggers_size(); i++) {
|
||||||
|
const auto & word = request->grammartriggers(i).word();
|
||||||
common_grammar_trigger trigger;
|
common_grammar_trigger trigger;
|
||||||
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
|
trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
|
||||||
trigger.value = request->grammartriggers(i).word();
|
trigger.value = word;
|
||||||
// trigger.at_start = request->grammartriggers(i).at_start();
|
params.sampling.grammar_triggers.push_back(std::move(trigger));
|
||||||
params.sampling.grammar_triggers.push_back(trigger);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -376,7 +404,7 @@ public:
|
|||||||
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
|
grpc::Status LoadModel(ServerContext* context, const backend::ModelOptions* request, backend::Result* result) {
|
||||||
// Implement LoadModel RPC
|
// Implement LoadModel RPC
|
||||||
common_params params;
|
common_params params;
|
||||||
params_parse(request, params);
|
params_parse(ctx_server, request, params);
|
||||||
|
|
||||||
common_init();
|
common_init();
|
||||||
|
|
||||||
@@ -395,6 +423,39 @@ public:
|
|||||||
return Status::CANCELLED;
|
return Status::CANCELLED;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process grammar triggers now that vocab is available
|
||||||
|
if (!params.sampling.grammar_triggers.empty()) {
|
||||||
|
std::vector<common_grammar_trigger> processed_triggers;
|
||||||
|
for (const auto& trigger : params.sampling.grammar_triggers) {
|
||||||
|
if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
|
||||||
|
auto ids = common_tokenize(ctx_server.vocab, trigger.value, /* add_special= */ false, /* parse_special= */ true);
|
||||||
|
if (ids.size() == 1) {
|
||||||
|
auto token = ids[0];
|
||||||
|
// Add the token to preserved_tokens if not already present
|
||||||
|
if (params.sampling.preserved_tokens.find(token) == params.sampling.preserved_tokens.end()) {
|
||||||
|
params.sampling.preserved_tokens.insert(token);
|
||||||
|
LOG_INF("Added grammar trigger token to preserved tokens: %d (`%s`)\n", token, trigger.value.c_str());
|
||||||
|
}
|
||||||
|
LOG_INF("Grammar trigger token: %d (`%s`)\n", token, trigger.value.c_str());
|
||||||
|
common_grammar_trigger processed_trigger;
|
||||||
|
processed_trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
|
||||||
|
processed_trigger.value = trigger.value;
|
||||||
|
processed_trigger.token = token;
|
||||||
|
processed_triggers.push_back(std::move(processed_trigger));
|
||||||
|
} else {
|
||||||
|
LOG_INF("Grammar trigger word: `%s`\n", trigger.value.c_str());
|
||||||
|
processed_triggers.push_back(trigger);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
processed_triggers.push_back(trigger);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Update the grammar triggers in params_base
|
||||||
|
ctx_server.params_base.sampling.grammar_triggers = std::move(processed_triggers);
|
||||||
|
// Also update preserved_tokens in params_base
|
||||||
|
ctx_server.params_base.sampling.preserved_tokens = params.sampling.preserved_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
//ctx_server.init();
|
//ctx_server.init();
|
||||||
result->set_message("Loading succeeded");
|
result->set_message("Loading succeeded");
|
||||||
result->set_success(true);
|
result->set_success(true);
|
||||||
@@ -405,7 +466,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
|
||||||
json data = parse_options(true, request);
|
json data = parse_options(true, request, ctx_server);
|
||||||
|
|
||||||
|
|
||||||
//Raise error if embeddings is set to true
|
//Raise error if embeddings is set to true
|
||||||
@@ -468,12 +529,12 @@ public:
|
|||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = ctx_server.queue_tasks.get_new_id();
|
||||||
task.index = i;
|
task.index = i;
|
||||||
|
|
||||||
task.prompt_tokens = std::move(inputs[i]);
|
task.tokens = std::move(inputs[i]);
|
||||||
task.params = server_task::params_from_json_cmpl(
|
task.params = server_task::params_from_json_cmpl(
|
||||||
ctx_server.ctx,
|
ctx_server.ctx,
|
||||||
ctx_server.params_base,
|
ctx_server.params_base,
|
||||||
data);
|
data);
|
||||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
task.id_slot = json_value(data, "id_slot", -1);
|
||||||
|
|
||||||
// OAI-compat
|
// OAI-compat
|
||||||
task.params.oaicompat = OAICOMPAT_TYPE_NONE;
|
task.params.oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
@@ -555,7 +616,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
|
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) {
|
||||||
json data = parse_options(true, request);
|
json data = parse_options(true, request, ctx_server);
|
||||||
|
|
||||||
data["stream"] = false;
|
data["stream"] = false;
|
||||||
//Raise error if embeddings is set to true
|
//Raise error if embeddings is set to true
|
||||||
@@ -623,12 +684,12 @@ public:
|
|||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = ctx_server.queue_tasks.get_new_id();
|
||||||
task.index = i;
|
task.index = i;
|
||||||
|
|
||||||
task.prompt_tokens = std::move(inputs[i]);
|
task.tokens = std::move(inputs[i]);
|
||||||
task.params = server_task::params_from_json_cmpl(
|
task.params = server_task::params_from_json_cmpl(
|
||||||
ctx_server.ctx,
|
ctx_server.ctx,
|
||||||
ctx_server.params_base,
|
ctx_server.params_base,
|
||||||
data);
|
data);
|
||||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
task.id_slot = json_value(data, "id_slot", -1);
|
||||||
|
|
||||||
// OAI-compat
|
// OAI-compat
|
||||||
task.params.oaicompat = OAICOMPAT_TYPE_NONE;
|
task.params.oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
@@ -690,7 +751,7 @@ public:
|
|||||||
|
|
||||||
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) {
|
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) {
|
||||||
|
|
||||||
json body = parse_options(false, request);
|
json body = parse_options(false, request, ctx_server);
|
||||||
|
|
||||||
body["stream"] = false;
|
body["stream"] = false;
|
||||||
|
|
||||||
@@ -701,7 +762,7 @@ public:
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
// for the shape of input/content, see tokenize_input_prompts()
|
// for the shape of input/content, see tokenize_input_prompts()
|
||||||
json prompt = body.at("prompt");
|
json prompt = body.at("embeddings");
|
||||||
|
|
||||||
|
|
||||||
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
|
||||||
@@ -712,6 +773,7 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int embd_normalize = 2; // default to Euclidean/L2 norm
|
||||||
// create and queue the task
|
// create and queue the task
|
||||||
json responses = json::array();
|
json responses = json::array();
|
||||||
bool error = false;
|
bool error = false;
|
||||||
@@ -723,11 +785,10 @@ public:
|
|||||||
|
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = ctx_server.queue_tasks.get_new_id();
|
||||||
task.index = i;
|
task.index = i;
|
||||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
task.tokens = std::move(tokenized_prompts[i]);
|
||||||
|
|
||||||
// OAI-compat
|
|
||||||
task.params.oaicompat = OAICOMPAT_TYPE_EMBEDDING;
|
|
||||||
|
|
||||||
|
task.params.oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
|
task.params.embd_normalize = embd_normalize;
|
||||||
tasks.push_back(std::move(task));
|
tasks.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -743,9 +804,8 @@ public:
|
|||||||
responses.push_back(res->to_json());
|
responses.push_back(res->to_json());
|
||||||
}
|
}
|
||||||
}, [&](const json & error_data) {
|
}, [&](const json & error_data) {
|
||||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, error_data.value("content", ""));
|
error = true;
|
||||||
}, [&]() {
|
}, [&]() {
|
||||||
// NOTE: we should try to check when the writer is closed here
|
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -755,12 +815,36 @@ public:
|
|||||||
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
|
return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> embeddings = responses[0].value("embedding", std::vector<float>());
|
std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl;
|
||||||
// loop the vector and set the embeddings results
|
|
||||||
for (int i = 0; i < embeddings.size(); i++) {
|
// Process the responses and extract embeddings
|
||||||
embeddingResult->add_embeddings(embeddings[i]);
|
for (const auto & response_elem : responses) {
|
||||||
|
// Check if the response has an "embedding" field
|
||||||
|
if (response_elem.contains("embedding")) {
|
||||||
|
json embedding_data = json_value(response_elem, "embedding", json::array());
|
||||||
|
|
||||||
|
if (embedding_data.is_array() && !embedding_data.empty()) {
|
||||||
|
for (const auto & embedding_vector : embedding_data) {
|
||||||
|
if (embedding_vector.is_array()) {
|
||||||
|
for (const auto & embedding_value : embedding_vector) {
|
||||||
|
embeddingResult->add_embeddings(embedding_value.get<float>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Check if the response itself contains the embedding data directly
|
||||||
|
if (response_elem.is_array()) {
|
||||||
|
for (const auto & embedding_value : response_elem) {
|
||||||
|
embeddingResult->add_embeddings(embedding_value.get<float>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return grpc::Status::OK;
|
return grpc::Status::OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -778,11 +862,6 @@ public:
|
|||||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
|
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tokenize the query
|
|
||||||
auto tokenized_query = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, request->query(), /* add_special */ false, true);
|
|
||||||
if (tokenized_query.size() != 1) {
|
|
||||||
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must contain only a single prompt");
|
|
||||||
}
|
|
||||||
// Create and queue the task
|
// Create and queue the task
|
||||||
json responses = json::array();
|
json responses = json::array();
|
||||||
bool error = false;
|
bool error = false;
|
||||||
@@ -794,14 +873,13 @@ public:
|
|||||||
documents.push_back(request->documents(i));
|
documents.push_back(request->documents(i));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, documents, /* add_special */ false, true);
|
tasks.reserve(documents.size());
|
||||||
tasks.reserve(tokenized_docs.size());
|
for (size_t i = 0; i < documents.size(); i++) {
|
||||||
for (size_t i = 0; i < tokenized_docs.size(); i++) {
|
auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, request->query(), documents[i]);
|
||||||
auto tmp = format_rerank(ctx_server.vocab, tokenized_query[0], tokenized_docs[i]);
|
|
||||||
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
|
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = ctx_server.queue_tasks.get_new_id();
|
||||||
task.index = i;
|
task.index = i;
|
||||||
task.prompt_tokens = std::move(tmp);
|
task.tokens = std::move(tmp);
|
||||||
tasks.push_back(std::move(task));
|
tasks.push_back(std::move(task));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -854,7 +932,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
|
grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response) {
|
||||||
json body = parse_options(false, request);
|
json body = parse_options(false, request, ctx_server);
|
||||||
body["stream"] = false;
|
body["stream"] = false;
|
||||||
|
|
||||||
json tokens_response = json::array();
|
json tokens_response = json::array();
|
||||||
|
|||||||
2
backend/go/stablediffusion-ggml/.gitignore
vendored
2
backend/go/stablediffusion-ggml/.gitignore
vendored
@@ -1,4 +1,6 @@
|
|||||||
package/
|
package/
|
||||||
sources/
|
sources/
|
||||||
|
.cache/
|
||||||
|
build/
|
||||||
libgosd.so
|
libgosd.so
|
||||||
stablediffusion-ggml
|
stablediffusion-ggml
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# stablediffusion.cpp (ggml)
|
# stablediffusion.cpp (ggml)
|
||||||
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
STABLEDIFFUSION_GGML_REPO?=https://github.com/leejet/stable-diffusion.cpp
|
||||||
STABLEDIFFUSION_GGML_VERSION?=4c6475f9176bf99271ccf5a2817b30a490b83db0
|
STABLEDIFFUSION_GGML_VERSION?=0ebe6fe118f125665939b27c89f34ed38716bff8
|
||||||
|
|
||||||
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
CMAKE_ARGS+=-DGGML_MAX_NAME=128
|
||||||
|
|
||||||
|
|||||||
@@ -4,17 +4,11 @@
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <time.h>
|
#include <time.h>
|
||||||
#include <iostream>
|
|
||||||
#include <random>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include "gosd.h"
|
#include "gosd.h"
|
||||||
|
|
||||||
// #include "preprocessing.hpp"
|
|
||||||
#include "flux.hpp"
|
|
||||||
#include "stable-diffusion.h"
|
|
||||||
|
|
||||||
#define STB_IMAGE_IMPLEMENTATION
|
#define STB_IMAGE_IMPLEMENTATION
|
||||||
#define STB_IMAGE_STATIC
|
#define STB_IMAGE_STATIC
|
||||||
#include "stb_image.h"
|
#include "stb_image.h"
|
||||||
@@ -29,7 +23,7 @@
|
|||||||
|
|
||||||
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
|
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
|
||||||
const char* sample_method_str[] = {
|
const char* sample_method_str[] = {
|
||||||
"euler_a",
|
"default",
|
||||||
"euler",
|
"euler",
|
||||||
"heun",
|
"heun",
|
||||||
"dpm2",
|
"dpm2",
|
||||||
@@ -41,19 +35,27 @@ const char* sample_method_str[] = {
|
|||||||
"lcm",
|
"lcm",
|
||||||
"ddim_trailing",
|
"ddim_trailing",
|
||||||
"tcd",
|
"tcd",
|
||||||
|
"euler_a",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static_assert(std::size(sample_method_str) == SAMPLE_METHOD_COUNT, "sample method mismatch");
|
||||||
|
|
||||||
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
|
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
|
||||||
const char* schedule_str[] = {
|
const char* schedulers[] = {
|
||||||
"default",
|
"default",
|
||||||
"discrete",
|
"discrete",
|
||||||
"karras",
|
"karras",
|
||||||
"exponential",
|
"exponential",
|
||||||
"ays",
|
"ays",
|
||||||
"gits",
|
"gits",
|
||||||
|
"smoothstep",
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static_assert(std::size(schedulers) == SCHEDULE_COUNT, "schedulers mismatch");
|
||||||
|
|
||||||
sd_ctx_t* sd_c;
|
sd_ctx_t* sd_c;
|
||||||
|
// Moved from the context (load time) to generation time params
|
||||||
|
scheduler_t scheduler = scheduler_t::DEFAULT;
|
||||||
|
|
||||||
sample_method_t sample_method;
|
sample_method_t sample_method;
|
||||||
|
|
||||||
@@ -105,7 +107,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
|||||||
const char *clip_g_path = "";
|
const char *clip_g_path = "";
|
||||||
const char *t5xxl_path = "";
|
const char *t5xxl_path = "";
|
||||||
const char *vae_path = "";
|
const char *vae_path = "";
|
||||||
const char *scheduler = "";
|
const char *scheduler_str = "";
|
||||||
const char *sampler = "";
|
const char *sampler = "";
|
||||||
char *lora_dir = model_path;
|
char *lora_dir = model_path;
|
||||||
bool lora_dir_allocated = false;
|
bool lora_dir_allocated = false;
|
||||||
@@ -133,7 +135,7 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
|||||||
vae_path = optval;
|
vae_path = optval;
|
||||||
}
|
}
|
||||||
if (!strcmp(optname, "scheduler")) {
|
if (!strcmp(optname, "scheduler")) {
|
||||||
scheduler = optval;
|
scheduler_str = optval;
|
||||||
}
|
}
|
||||||
if (!strcmp(optname, "sampler")) {
|
if (!strcmp(optname, "sampler")) {
|
||||||
sampler = optval;
|
sampler = optval;
|
||||||
@@ -166,26 +168,17 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
|||||||
}
|
}
|
||||||
if (sample_method_found == -1) {
|
if (sample_method_found == -1) {
|
||||||
fprintf(stderr, "Invalid sample method, default to EULER_A!\n");
|
fprintf(stderr, "Invalid sample method, default to EULER_A!\n");
|
||||||
sample_method_found = EULER_A;
|
sample_method_found = sample_method_t::SAMPLE_METHOD_DEFAULT;
|
||||||
}
|
}
|
||||||
sample_method = (sample_method_t)sample_method_found;
|
sample_method = (sample_method_t)sample_method_found;
|
||||||
|
|
||||||
int schedule_found = -1;
|
|
||||||
for (int d = 0; d < SCHEDULE_COUNT; d++) {
|
for (int d = 0; d < SCHEDULE_COUNT; d++) {
|
||||||
if (!strcmp(scheduler, schedule_str[d])) {
|
if (!strcmp(scheduler_str, schedulers[d])) {
|
||||||
schedule_found = d;
|
scheduler = (scheduler_t)d;
|
||||||
fprintf (stderr, "Found scheduler: %s\n", scheduler);
|
fprintf (stderr, "Found scheduler: %s\n", scheduler_str);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (schedule_found == -1) {
|
|
||||||
fprintf (stderr, "Invalid scheduler! using DEFAULT\n");
|
|
||||||
schedule_found = DEFAULT;
|
|
||||||
}
|
|
||||||
|
|
||||||
schedule_t schedule = (schedule_t)schedule_found;
|
|
||||||
|
|
||||||
fprintf (stderr, "Creating context\n");
|
fprintf (stderr, "Creating context\n");
|
||||||
sd_ctx_params_t ctx_params;
|
sd_ctx_params_t ctx_params;
|
||||||
sd_ctx_params_init(&ctx_params);
|
sd_ctx_params_init(&ctx_params);
|
||||||
@@ -199,13 +192,10 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
|||||||
ctx_params.control_net_path = "";
|
ctx_params.control_net_path = "";
|
||||||
ctx_params.lora_model_dir = lora_dir;
|
ctx_params.lora_model_dir = lora_dir;
|
||||||
ctx_params.embedding_dir = "";
|
ctx_params.embedding_dir = "";
|
||||||
ctx_params.stacked_id_embed_dir = "";
|
|
||||||
ctx_params.vae_decode_only = false;
|
ctx_params.vae_decode_only = false;
|
||||||
ctx_params.vae_tiling = false;
|
|
||||||
ctx_params.free_params_immediately = false;
|
ctx_params.free_params_immediately = false;
|
||||||
ctx_params.n_threads = threads;
|
ctx_params.n_threads = threads;
|
||||||
ctx_params.rng_type = STD_DEFAULT_RNG;
|
ctx_params.rng_type = STD_DEFAULT_RNG;
|
||||||
ctx_params.schedule = schedule;
|
|
||||||
sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);
|
sd_ctx_t* sd_ctx = new_sd_ctx(&ctx_params);
|
||||||
|
|
||||||
if (sd_ctx == NULL) {
|
if (sd_ctx == NULL) {
|
||||||
@@ -228,7 +218,49 @@ int load_model(const char *model, char *model_path, char* options[], int threads
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int64_t seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
|
void sd_tiling_params_set_enabled(sd_tiling_params_t *params, bool enabled) {
|
||||||
|
params->enabled = enabled;
|
||||||
|
}
|
||||||
|
|
||||||
|
void sd_tiling_params_set_tile_sizes(sd_tiling_params_t *params, int tile_size_x, int tile_size_y) {
|
||||||
|
params->tile_size_x = tile_size_x;
|
||||||
|
params->tile_size_y = tile_size_y;
|
||||||
|
}
|
||||||
|
|
||||||
|
void sd_tiling_params_set_rel_sizes(sd_tiling_params_t *params, float rel_size_x, float rel_size_y) {
|
||||||
|
params->rel_size_x = rel_size_x;
|
||||||
|
params->rel_size_y = rel_size_y;
|
||||||
|
}
|
||||||
|
|
||||||
|
void sd_tiling_params_set_target_overlap(sd_tiling_params_t *params, float target_overlap) {
|
||||||
|
params->target_overlap = target_overlap;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd_tiling_params_t* sd_img_gen_params_get_vae_tiling_params(sd_img_gen_params_t *params) {
|
||||||
|
return ¶ms->vae_tiling_params;
|
||||||
|
}
|
||||||
|
|
||||||
|
sd_img_gen_params_t* sd_img_gen_params_new(void) {
|
||||||
|
sd_img_gen_params_t *params = (sd_img_gen_params_t *)std::malloc(sizeof(sd_img_gen_params_t));
|
||||||
|
sd_img_gen_params_init(params);
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
|
||||||
|
void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt) {
|
||||||
|
params->prompt = prompt;
|
||||||
|
params->negative_prompt = negative_prompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height) {
|
||||||
|
params->width = width;
|
||||||
|
params->height = height;
|
||||||
|
}
|
||||||
|
|
||||||
|
void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed) {
|
||||||
|
params->seed = seed;
|
||||||
|
}
|
||||||
|
|
||||||
|
int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count) {
|
||||||
|
|
||||||
sd_image_t* results;
|
sd_image_t* results;
|
||||||
|
|
||||||
@@ -236,20 +268,15 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
|
|||||||
|
|
||||||
fprintf (stderr, "Generating image\n");
|
fprintf (stderr, "Generating image\n");
|
||||||
|
|
||||||
sd_img_gen_params_t p;
|
p->sample_params.guidance.txt_cfg = cfg_scale;
|
||||||
sd_img_gen_params_init(&p);
|
p->sample_params.guidance.slg.layers = skip_layers.data();
|
||||||
|
p->sample_params.guidance.slg.layer_count = skip_layers.size();
|
||||||
|
p->sample_params.sample_method = sample_method;
|
||||||
|
p->sample_params.sample_steps = steps;
|
||||||
|
p->sample_params.scheduler = scheduler;
|
||||||
|
|
||||||
p.prompt = text;
|
int width = p->width;
|
||||||
p.negative_prompt = negativeText;
|
int height = p->height;
|
||||||
p.guidance.txt_cfg = cfg_scale;
|
|
||||||
p.guidance.slg.layers = skip_layers.data();
|
|
||||||
p.guidance.slg.layer_count = skip_layers.size();
|
|
||||||
p.width = width;
|
|
||||||
p.height = height;
|
|
||||||
p.sample_method = sample_method;
|
|
||||||
p.sample_steps = steps;
|
|
||||||
p.seed = seed;
|
|
||||||
p.input_id_images_path = "";
|
|
||||||
|
|
||||||
// Handle input image for img2img
|
// Handle input image for img2img
|
||||||
bool has_input_image = (src_image != NULL && strlen(src_image) > 0);
|
bool has_input_image = (src_image != NULL && strlen(src_image) > 0);
|
||||||
@@ -298,13 +325,13 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
|
|||||||
input_image_buffer = resized_image_buffer;
|
input_image_buffer = resized_image_buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
p.init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
|
p->init_image = {(uint32_t)width, (uint32_t)height, 3, input_image_buffer};
|
||||||
p.strength = strength;
|
p->strength = strength;
|
||||||
fprintf(stderr, "Using img2img with strength: %.2f\n", strength);
|
fprintf(stderr, "Using img2img with strength: %.2f\n", strength);
|
||||||
} else {
|
} else {
|
||||||
// No input image, use empty image for text-to-image
|
// No input image, use empty image for text-to-image
|
||||||
p.init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
|
p->init_image = {(uint32_t)width, (uint32_t)height, 3, NULL};
|
||||||
p.strength = 0.0f;
|
p->strength = 0.0f;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle mask image for inpainting
|
// Handle mask image for inpainting
|
||||||
@@ -344,12 +371,12 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
|
|||||||
mask_image_buffer = resized_mask_buffer;
|
mask_image_buffer = resized_mask_buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
p.mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
|
p->mask_image = {(uint32_t)width, (uint32_t)height, 1, mask_image_buffer};
|
||||||
fprintf(stderr, "Using inpainting with mask\n");
|
fprintf(stderr, "Using inpainting with mask\n");
|
||||||
} else {
|
} else {
|
||||||
// No mask image, create default full mask
|
// No mask image, create default full mask
|
||||||
default_mask_image_vec.resize(width * height, 255);
|
default_mask_image_vec.resize(width * height, 255);
|
||||||
p.mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
|
p->mask_image = {(uint32_t)width, (uint32_t)height, 1, default_mask_image_vec.data()};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle reference images
|
// Handle reference images
|
||||||
@@ -407,13 +434,15 @@ int gen_image(char *text, char *negativeText, int width, int height, int steps,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!ref_images_vec.empty()) {
|
if (!ref_images_vec.empty()) {
|
||||||
p.ref_images = ref_images_vec.data();
|
p->ref_images = ref_images_vec.data();
|
||||||
p.ref_images_count = ref_images_vec.size();
|
p->ref_images_count = ref_images_vec.size();
|
||||||
fprintf(stderr, "Using %zu reference images\n", ref_images_vec.size());
|
fprintf(stderr, "Using %zu reference images\n", ref_images_vec.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
results = generate_image(sd_c, &p);
|
results = generate_image(sd_c, p);
|
||||||
|
|
||||||
|
std::free(p);
|
||||||
|
|
||||||
if (results == NULL) {
|
if (results == NULL) {
|
||||||
fprintf (stderr, "NO results\n");
|
fprintf (stderr, "NO results\n");
|
||||||
|
|||||||
@@ -22,7 +22,18 @@ type SDGGML struct {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
LoadModel func(model, model_apth string, options []uintptr, threads int32, diff int) int
|
LoadModel func(model, model_apth string, options []uintptr, threads int32, diff int) int
|
||||||
GenImage func(text, negativeText string, width, height, steps int, seed int64, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []string, refImagesCount int) int
|
GenImage func(params uintptr, steps int, dst string, cfgScale float32, srcImage string, strength float32, maskImage string, refImages []string, refImagesCount int) int
|
||||||
|
|
||||||
|
TilingParamsSetEnabled func(params uintptr, enabled bool)
|
||||||
|
TilingParamsSetTileSizes func(params uintptr, tileSizeX int, tileSizeY int)
|
||||||
|
TilingParamsSetRelSizes func(params uintptr, relSizeX float32, relSizeY float32)
|
||||||
|
TilingParamsSetTargetOverlap func(params uintptr, targetOverlap float32)
|
||||||
|
|
||||||
|
ImgGenParamsNew func() uintptr
|
||||||
|
ImgGenParamsSetPrompts func(params uintptr, prompt string, negativePrompt string)
|
||||||
|
ImgGenParamsSetDimensions func(params uintptr, width int, height int)
|
||||||
|
ImgGenParamsSetSeed func(params uintptr, seed int64)
|
||||||
|
ImgGenParamsGetVaeTilingParams func(params uintptr) uintptr
|
||||||
)
|
)
|
||||||
|
|
||||||
// Copied from Purego internal/strings
|
// Copied from Purego internal/strings
|
||||||
@@ -120,7 +131,15 @@ func (sd *SDGGML) GenerateImage(opts *pb.GenerateImageRequest) error {
|
|||||||
// Default strength for img2img (0.75 is a good default)
|
// Default strength for img2img (0.75 is a good default)
|
||||||
strength := float32(0.75)
|
strength := float32(0.75)
|
||||||
|
|
||||||
ret := GenImage(t, negative, int(opts.Width), int(opts.Height), int(opts.Step), int64(opts.Seed), dst, sd.cfgScale, srcImage, strength, maskImage, refImages, refImagesCount)
|
// free'd by GenImage
|
||||||
|
p := ImgGenParamsNew()
|
||||||
|
ImgGenParamsSetPrompts(p, t, negative)
|
||||||
|
ImgGenParamsSetDimensions(p, int(opts.Width), int(opts.Height))
|
||||||
|
ImgGenParamsSetSeed(p, int64(opts.Seed))
|
||||||
|
vaep := ImgGenParamsGetVaeTilingParams(p)
|
||||||
|
TilingParamsSetEnabled(vaep, false)
|
||||||
|
|
||||||
|
ret := GenImage(p, int(opts.Step), dst, sd.cfgScale, srcImage, strength, maskImage, refImages, refImagesCount)
|
||||||
if ret != 0 {
|
if ret != 0 {
|
||||||
return fmt.Errorf("inference failed")
|
return fmt.Errorf("inference failed")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,23 @@
|
|||||||
|
#include <cstdint>
|
||||||
|
#include "stable-diffusion.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
void sd_tiling_params_set_enabled(sd_tiling_params_t *params, bool enabled);
|
||||||
|
void sd_tiling_params_set_tile_sizes(sd_tiling_params_t *params, int tile_size_x, int tile_size_y);
|
||||||
|
void sd_tiling_params_set_rel_sizes(sd_tiling_params_t *params, float rel_size_x, float rel_size_y);
|
||||||
|
void sd_tiling_params_set_target_overlap(sd_tiling_params_t *params, float target_overlap);
|
||||||
|
sd_tiling_params_t* sd_img_gen_params_get_vae_tiling_params(sd_img_gen_params_t *params);
|
||||||
|
|
||||||
|
sd_img_gen_params_t* sd_img_gen_params_new(void);
|
||||||
|
void sd_img_gen_params_set_prompts(sd_img_gen_params_t *params, const char *prompt, const char *negative_prompt);
|
||||||
|
void sd_img_gen_params_set_dimensions(sd_img_gen_params_t *params, int width, int height);
|
||||||
|
void sd_img_gen_params_set_seed(sd_img_gen_params_t *params, int64_t seed);
|
||||||
|
|
||||||
int load_model(const char *model, char *model_path, char* options[], int threads, int diffusionModel);
|
int load_model(const char *model, char *model_path, char* options[], int threads, int diffusionModel);
|
||||||
int gen_image(char *text, char *negativeText, int width, int height, int steps, int64_t seed, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count);
|
int gen_image(sd_img_gen_params_t *p, int steps, char *dst, float cfg_scale, char *src_image, float strength, char *mask_image, char **ref_images, int ref_images_count);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -11,14 +11,35 @@ var (
|
|||||||
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
addr = flag.String("addr", "localhost:50051", "the address to connect to")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type LibFuncs struct {
|
||||||
|
FuncPtr any
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
gosd, err := purego.Dlopen("./libgosd.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
gosd, err := purego.Dlopen("./libgosd.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
purego.RegisterLibFunc(&LoadModel, gosd, "load_model")
|
libFuncs := []LibFuncs{
|
||||||
purego.RegisterLibFunc(&GenImage, gosd, "gen_image")
|
{&LoadModel, "load_model"},
|
||||||
|
{&GenImage, "gen_image"},
|
||||||
|
{&TilingParamsSetEnabled, "sd_tiling_params_set_enabled"},
|
||||||
|
{&TilingParamsSetTileSizes, "sd_tiling_params_set_tile_sizes"},
|
||||||
|
{&TilingParamsSetRelSizes, "sd_tiling_params_set_rel_sizes"},
|
||||||
|
{&TilingParamsSetTargetOverlap, "sd_tiling_params_set_target_overlap"},
|
||||||
|
|
||||||
|
{&ImgGenParamsNew, "sd_img_gen_params_new"},
|
||||||
|
{&ImgGenParamsSetPrompts, "sd_img_gen_params_set_prompts"},
|
||||||
|
{&ImgGenParamsSetDimensions, "sd_img_gen_params_set_dimensions"},
|
||||||
|
{&ImgGenParamsSetSeed, "sd_img_gen_params_set_seed"},
|
||||||
|
{&ImgGenParamsGetVaeTilingParams, "sd_img_gen_params_get_vae_tiling_params"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, lf := range libFuncs {
|
||||||
|
purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
|
||||||
|
}
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ JOBS?=$(shell nproc --ignore=1)
|
|||||||
|
|
||||||
# whisper.cpp version
|
# whisper.cpp version
|
||||||
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
|
||||||
WHISPER_CPP_VERSION?=7745fcf32846006128f16de429cfe1677c963b30
|
WHISPER_CPP_VERSION?=f16c12f3f55f5bd3d6ac8cf2f31ab90a42c884d5
|
||||||
|
SO_TARGET?=libgowhisper.so
|
||||||
|
|
||||||
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF
|
||||||
|
|
||||||
@@ -57,15 +58,18 @@ sources/whisper.cpp:
|
|||||||
git checkout $(WHISPER_CPP_VERSION) && \
|
git checkout $(WHISPER_CPP_VERSION) && \
|
||||||
git submodule update --init --recursive --depth 1 --single-branch
|
git submodule update --init --recursive --depth 1 --single-branch
|
||||||
|
|
||||||
libgowhisper.so: sources/whisper.cpp CMakeLists.txt gowhisper.cpp gowhisper.h
|
# Detect OS
|
||||||
mkdir -p build && \
|
UNAME_S := $(shell uname -s)
|
||||||
cd build && \
|
|
||||||
cmake .. $(CMAKE_ARGS) && \
|
|
||||||
cmake --build . --config Release -j$(JOBS) && \
|
|
||||||
cd .. && \
|
|
||||||
mv build/libgowhisper.so ./
|
|
||||||
|
|
||||||
whisper: main.go gowhisper.go libgowhisper.so
|
# Only build CPU variants on Linux
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
VARIANT_TARGETS = libgowhisper-avx.so libgowhisper-avx2.so libgowhisper-avx512.so libgowhisper-fallback.so
|
||||||
|
else
|
||||||
|
# On non-Linux (e.g., Darwin), build only fallback variant
|
||||||
|
VARIANT_TARGETS = libgowhisper-fallback.so
|
||||||
|
endif
|
||||||
|
|
||||||
|
whisper: main.go gowhisper.go $(VARIANT_TARGETS)
|
||||||
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o whisper ./
|
CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o whisper ./
|
||||||
|
|
||||||
package: whisper
|
package: whisper
|
||||||
@@ -73,5 +77,46 @@ package: whisper
|
|||||||
|
|
||||||
build: package
|
build: package
|
||||||
|
|
||||||
clean:
|
clean: purge
|
||||||
rm -rf libgowhisper.o build whisper
|
rm -rf libgowhisper*.so sources/whisper.cpp whisper
|
||||||
|
|
||||||
|
purge:
|
||||||
|
rm -rf build*
|
||||||
|
|
||||||
|
# Build all variants (Linux only)
|
||||||
|
ifeq ($(UNAME_S),Linux)
|
||||||
|
libgowhisper-avx.so: sources/whisper.cpp
|
||||||
|
$(MAKE) purge
|
||||||
|
$(info ${GREEN}I whisper build info:avx${RESET})
|
||||||
|
SO_TARGET=libgowhisper-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) libgowhisper-custom
|
||||||
|
rm -rfv build*
|
||||||
|
|
||||||
|
libgowhisper-avx2.so: sources/whisper.cpp
|
||||||
|
$(MAKE) purge
|
||||||
|
$(info ${GREEN}I whisper build info:avx2${RESET})
|
||||||
|
SO_TARGET=libgowhisper-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) libgowhisper-custom
|
||||||
|
rm -rfv build*
|
||||||
|
|
||||||
|
libgowhisper-avx512.so: sources/whisper.cpp
|
||||||
|
$(MAKE) purge
|
||||||
|
$(info ${GREEN}I whisper build info:avx512${RESET})
|
||||||
|
SO_TARGET=libgowhisper-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) libgowhisper-custom
|
||||||
|
rm -rfv build*
|
||||||
|
endif
|
||||||
|
|
||||||
|
# Build fallback variant (all platforms)
|
||||||
|
libgowhisper-fallback.so: sources/whisper.cpp
|
||||||
|
$(MAKE) purge
|
||||||
|
$(info ${GREEN}I whisper build info:fallback${RESET})
|
||||||
|
SO_TARGET=libgowhisper-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off" $(MAKE) libgowhisper-custom
|
||||||
|
rm -rfv build*
|
||||||
|
|
||||||
|
libgowhisper-custom: CMakeLists.txt gowhisper.cpp gowhisper.h
|
||||||
|
mkdir -p build-$(SO_TARGET) && \
|
||||||
|
cd build-$(SO_TARGET) && \
|
||||||
|
cmake .. $(CMAKE_ARGS) && \
|
||||||
|
cmake --build . --config Release -j$(JOBS) && \
|
||||||
|
cd .. && \
|
||||||
|
mv build-$(SO_TARGET)/libgowhisper.so ./$(SO_TARGET)
|
||||||
|
|
||||||
|
all: whisper package
|
||||||
|
|||||||
@@ -7,34 +7,35 @@ static struct whisper_vad_context *vctx;
|
|||||||
static struct whisper_context *ctx;
|
static struct whisper_context *ctx;
|
||||||
static std::vector<float> flat_segs;
|
static std::vector<float> flat_segs;
|
||||||
|
|
||||||
static void ggml_log_cb(enum ggml_log_level level, const char* log, void* data) {
|
static void ggml_log_cb(enum ggml_log_level level, const char *log,
|
||||||
const char* level_str;
|
void *data) {
|
||||||
|
const char *level_str;
|
||||||
|
|
||||||
if (!log) {
|
if (!log) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (level) {
|
switch (level) {
|
||||||
case GGML_LOG_LEVEL_DEBUG:
|
case GGML_LOG_LEVEL_DEBUG:
|
||||||
level_str = "DEBUG";
|
level_str = "DEBUG";
|
||||||
break;
|
break;
|
||||||
case GGML_LOG_LEVEL_INFO:
|
case GGML_LOG_LEVEL_INFO:
|
||||||
level_str = "INFO";
|
level_str = "INFO";
|
||||||
break;
|
break;
|
||||||
case GGML_LOG_LEVEL_WARN:
|
case GGML_LOG_LEVEL_WARN:
|
||||||
level_str = "WARN";
|
level_str = "WARN";
|
||||||
break;
|
break;
|
||||||
case GGML_LOG_LEVEL_ERROR:
|
case GGML_LOG_LEVEL_ERROR:
|
||||||
level_str = "ERROR";
|
level_str = "ERROR";
|
||||||
break;
|
break;
|
||||||
default: /* Potential future-proofing */
|
default: /* Potential future-proofing */
|
||||||
level_str = "?????";
|
level_str = "?????";
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
fprintf(stderr, "[%-5s] ", level_str);
|
fprintf(stderr, "[%-5s] ", level_str);
|
||||||
fputs(log, stderr);
|
fputs(log, stderr);
|
||||||
fflush(stderr);
|
fflush(stderr);
|
||||||
}
|
}
|
||||||
|
|
||||||
int load_model(const char *const model_path) {
|
int load_model(const char *const model_path) {
|
||||||
@@ -105,8 +106,8 @@ int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
int transcribe(uint32_t threads, char *lang, bool translate, float pcmf32[],
|
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
|
||||||
size_t pcmf32_len, size_t *segs_out_len) {
|
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len) {
|
||||||
whisper_full_params wparams =
|
whisper_full_params wparams =
|
||||||
whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||||
|
|
||||||
@@ -120,6 +121,9 @@ int transcribe(uint32_t threads, char *lang, bool translate, float pcmf32[],
|
|||||||
wparams.translate = translate;
|
wparams.translate = translate;
|
||||||
wparams.debug_mode = true;
|
wparams.debug_mode = true;
|
||||||
wparams.print_progress = true;
|
wparams.print_progress = true;
|
||||||
|
wparams.tdrz_enable = tdrz;
|
||||||
|
|
||||||
|
fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
|
||||||
|
|
||||||
if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
|
if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
|
||||||
fprintf(stderr, "error: transcription failed\n");
|
fprintf(stderr, "error: transcription failed\n");
|
||||||
@@ -144,3 +148,7 @@ int n_tokens(int i) { return whisper_full_n_tokens(ctx, i); }
|
|||||||
int32_t get_token_id(int i, int j) {
|
int32_t get_token_id(int i, int j) {
|
||||||
return whisper_full_get_token_id(ctx, i, j);
|
return whisper_full_get_token_id(ctx, i, j);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool get_segment_speaker_turn_next(int i) {
|
||||||
|
return whisper_full_get_segment_speaker_turn_next(ctx, i);
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,15 +14,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
CppLoadModel func(modelPath string) int
|
CppLoadModel func(modelPath string) int
|
||||||
CppLoadModelVAD func(modelPath string) int
|
CppLoadModelVAD func(modelPath string) int
|
||||||
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
|
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
|
||||||
CppTranscribe func(threads uint32, lang string, translate bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int
|
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int
|
||||||
CppGetSegmentText func(i int) string
|
CppGetSegmentText func(i int) string
|
||||||
CppGetSegmentStart func(i int) int64
|
CppGetSegmentStart func(i int) int64
|
||||||
CppGetSegmentEnd func(i int) int64
|
CppGetSegmentEnd func(i int) int64
|
||||||
CppNTokens func(i int) int
|
CppNTokens func(i int) int
|
||||||
CppGetTokenID func(i int, j int) int
|
CppGetTokenID func(i int, j int) int
|
||||||
|
CppGetSegmentSpeakerTurnNext func(i int) bool
|
||||||
)
|
)
|
||||||
|
|
||||||
type Whisper struct {
|
type Whisper struct {
|
||||||
@@ -122,7 +123,7 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
|||||||
segsLen := uintptr(0xdeadbeef)
|
segsLen := uintptr(0xdeadbeef)
|
||||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||||
|
|
||||||
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, data, uintptr(len(data)), segsLenPtr); ret != 0 {
|
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr); ret != 0 {
|
||||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,6 +135,10 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
|||||||
txt := strings.Clone(CppGetSegmentText(i))
|
txt := strings.Clone(CppGetSegmentText(i))
|
||||||
tokens := make([]int32, CppNTokens(i))
|
tokens := make([]int32, CppNTokens(i))
|
||||||
|
|
||||||
|
if opts.Diarize && CppGetSegmentSpeakerTurnNext(i) {
|
||||||
|
txt += " [SPEAKER_TURN]"
|
||||||
|
}
|
||||||
|
|
||||||
for j := range tokens {
|
for j := range tokens {
|
||||||
tokens[j] = int32(CppGetTokenID(i, j))
|
tokens[j] = int32(CppGetTokenID(i, j))
|
||||||
}
|
}
|
||||||
@@ -151,6 +156,6 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
|||||||
|
|
||||||
return pb.TranscriptResult{
|
return pb.TranscriptResult{
|
||||||
Segments: segments,
|
Segments: segments,
|
||||||
Text: strings.TrimSpace(text),
|
Text: strings.TrimSpace(text),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,12 @@ int load_model(const char *const model_path);
|
|||||||
int load_model_vad(const char *const model_path);
|
int load_model_vad(const char *const model_path);
|
||||||
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
|
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
|
||||||
size_t *segs_out_len);
|
size_t *segs_out_len);
|
||||||
int transcribe(uint32_t threads, char *lang, bool translate, float pcmf32[],
|
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
|
||||||
size_t pcmf32_len, size_t *segs_out_len);
|
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len);
|
||||||
const char *get_segment_text(int i);
|
const char *get_segment_text(int i);
|
||||||
int64_t get_segment_t0(int i);
|
int64_t get_segment_t0(int i);
|
||||||
int64_t get_segment_t1(int i);
|
int64_t get_segment_t1(int i);
|
||||||
int n_tokens(int i);
|
int n_tokens(int i);
|
||||||
int32_t get_token_id(int i, int j);
|
int32_t get_token_id(int i, int j);
|
||||||
|
bool get_segment_speaker_turn_next(int i);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package main
|
|||||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"os"
|
||||||
|
|
||||||
"github.com/ebitengine/purego"
|
"github.com/ebitengine/purego"
|
||||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
@@ -18,7 +19,13 @@ type LibFuncs struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
gosd, err := purego.Dlopen("./libgowhisper.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
// Get library name from environment variable, default to fallback
|
||||||
|
libName := os.Getenv("WHISPER_LIBRARY")
|
||||||
|
if libName == "" {
|
||||||
|
libName = "./libgowhisper-fallback.so"
|
||||||
|
}
|
||||||
|
|
||||||
|
gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -33,6 +40,7 @@ func main() {
|
|||||||
{&CppGetSegmentEnd, "get_segment_t1"},
|
{&CppGetSegmentEnd, "get_segment_t1"},
|
||||||
{&CppNTokens, "n_tokens"},
|
{&CppNTokens, "n_tokens"},
|
||||||
{&CppGetTokenID, "get_token_id"},
|
{&CppGetTokenID, "get_token_id"},
|
||||||
|
{&CppGetSegmentSpeakerTurnNext, "get_segment_speaker_turn_next"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, lf := range libFuncs {
|
for _, lf := range libFuncs {
|
||||||
|
|||||||
@@ -10,7 +10,8 @@ CURDIR=$(dirname "$(realpath $0)")
|
|||||||
# Create lib directory
|
# Create lib directory
|
||||||
mkdir -p $CURDIR/package/lib
|
mkdir -p $CURDIR/package/lib
|
||||||
|
|
||||||
cp -avf $CURDIR/whisper $CURDIR/libgowhisper.so $CURDIR/package/
|
cp -avf $CURDIR/whisper $CURDIR/package/
|
||||||
|
cp -fv $CURDIR/libgowhisper-*.so $CURDIR/package/
|
||||||
cp -fv $CURDIR/run.sh $CURDIR/package/
|
cp -fv $CURDIR/run.sh $CURDIR/package/
|
||||||
|
|
||||||
# Detect architecture and copy appropriate libraries
|
# Detect architecture and copy appropriate libraries
|
||||||
|
|||||||
@@ -1,14 +1,52 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
set -ex
|
set -ex
|
||||||
|
|
||||||
|
# Get the absolute current dir where the script is located
|
||||||
CURDIR=$(dirname "$(realpath $0)")
|
CURDIR=$(dirname "$(realpath $0)")
|
||||||
|
|
||||||
|
cd /
|
||||||
|
|
||||||
|
echo "CPU info:"
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
grep -e "model\sname" /proc/cpuinfo | head -1
|
||||||
|
grep -e "flags" /proc/cpuinfo | head -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
LIBRARY="$CURDIR/libgowhisper-fallback.so"
|
||||||
|
|
||||||
|
if [ "$(uname)" != "Darwin" ]; then
|
||||||
|
if grep -q -e "\savx\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX found OK"
|
||||||
|
if [ -e $CURDIR/libgowhisper-avx.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libgowhisper-avx.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if grep -q -e "\savx2\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX2 found OK"
|
||||||
|
if [ -e $CURDIR/libgowhisper-avx2.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libgowhisper-avx2.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check avx 512
|
||||||
|
if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
|
||||||
|
echo "CPU: AVX512F found OK"
|
||||||
|
if [ -e $CURDIR/libgowhisper-avx512.so ]; then
|
||||||
|
LIBRARY="$CURDIR/libgowhisper-avx512.so"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH
|
||||||
|
export WHISPER_LIBRARY=$LIBRARY
|
||||||
|
|
||||||
# If there is a lib/ld.so, use it
|
# If there is a lib/ld.so, use it
|
||||||
if [ -f $CURDIR/lib/ld.so ]; then
|
if [ -f $CURDIR/lib/ld.so ]; then
|
||||||
echo "Using lib/ld.so"
|
echo "Using lib/ld.so"
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
exec $CURDIR/lib/ld.so $CURDIR/whisper "$@"
|
exec $CURDIR/lib/ld.so $CURDIR/whisper "$@"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
echo "Using library: $LIBRARY"
|
||||||
exec $CURDIR/whisper "$@"
|
exec $CURDIR/whisper "$@"
|
||||||
@@ -270,6 +270,7 @@
|
|||||||
nvidia: "cuda12-kokoro"
|
nvidia: "cuda12-kokoro"
|
||||||
intel: "intel-kokoro"
|
intel: "intel-kokoro"
|
||||||
amd: "rocm-kokoro"
|
amd: "rocm-kokoro"
|
||||||
|
nvidia-l4t: "nvidia-l4t-kokoro"
|
||||||
- &coqui
|
- &coqui
|
||||||
urls:
|
urls:
|
||||||
- https://github.com/idiap/coqui-ai-TTS
|
- https://github.com/idiap/coqui-ai-TTS
|
||||||
@@ -350,6 +351,9 @@
|
|||||||
alias: "chatterbox"
|
alias: "chatterbox"
|
||||||
capabilities:
|
capabilities:
|
||||||
nvidia: "cuda12-chatterbox"
|
nvidia: "cuda12-chatterbox"
|
||||||
|
metal: "metal-chatterbox"
|
||||||
|
default: "cpu-chatterbox"
|
||||||
|
nvidia-l4t: "nvidia-l4t-arm64-chatterbox"
|
||||||
- &piper
|
- &piper
|
||||||
name: "piper"
|
name: "piper"
|
||||||
uri: "quay.io/go-skynet/local-ai-backends:latest-piper"
|
uri: "quay.io/go-skynet/local-ai-backends:latest-piper"
|
||||||
@@ -423,6 +427,68 @@
|
|||||||
- text-to-speech
|
- text-to-speech
|
||||||
- TTS
|
- TTS
|
||||||
license: apache-2.0
|
license: apache-2.0
|
||||||
|
- &neutts
|
||||||
|
name: "neutts"
|
||||||
|
urls:
|
||||||
|
- https://github.com/neuphonic/neutts-air
|
||||||
|
description: |
|
||||||
|
NeuTTS Air is the world’s first super-realistic, on-device, TTS speech language model with instant voice cloning. Built off a 0.5B LLM backbone, NeuTTS Air brings natural-sounding speech, real-time performance, built-in security and speaker cloning to your local device - unlocking a new category of embedded voice agents, assistants, toys, and compliance-safe apps.
|
||||||
|
tags:
|
||||||
|
- text-to-speech
|
||||||
|
- TTS
|
||||||
|
license: apache-2.0
|
||||||
|
capabilities:
|
||||||
|
default: "cpu-neutts"
|
||||||
|
nvidia: "cuda12-neutts"
|
||||||
|
amd: "rocm-neutts"
|
||||||
|
nvidia-l4t: "nvidia-l4t-neutts"
|
||||||
|
- !!merge <<: *neutts
|
||||||
|
name: "neutts-development"
|
||||||
|
capabilities:
|
||||||
|
default: "cpu-neutts-development"
|
||||||
|
nvidia: "cuda12-neutts-development"
|
||||||
|
amd: "rocm-neutts-development"
|
||||||
|
nvidia-l4t: "nvidia-l4t-neutts-development"
|
||||||
|
- !!merge <<: *neutts
|
||||||
|
name: "cpu-neutts"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-neutts"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cpu-neutts
|
||||||
|
- !!merge <<: *neutts
|
||||||
|
name: "cuda12-neutts"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-neutts"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-nvidia-cuda-12-neutts
|
||||||
|
- !!merge <<: *neutts
|
||||||
|
name: "rocm-neutts"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-neutts"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-rocm-hipblas-neutts
|
||||||
|
- !!merge <<: *neutts
|
||||||
|
name: "nvidia-l4t-neutts"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-neutts"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-nvidia-l4t-arm64-neutts
|
||||||
|
- !!merge <<: *neutts
|
||||||
|
name: "cpu-neutts-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-neutts"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cpu-neutts
|
||||||
|
- !!merge <<: *neutts
|
||||||
|
name: "cuda12-neutts-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-neutts"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-gpu-nvidia-cuda-12-neutts
|
||||||
|
- !!merge <<: *neutts
|
||||||
|
name: "rocm-neutts-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-neutts"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-gpu-rocm-hipblas-neutts
|
||||||
|
- !!merge <<: *neutts
|
||||||
|
name: "nvidia-l4t-neutts-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-neutts"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-nvidia-l4t-arm64-neutts
|
||||||
- !!merge <<: *mlx
|
- !!merge <<: *mlx
|
||||||
name: "mlx-development"
|
name: "mlx-development"
|
||||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx"
|
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx"
|
||||||
@@ -1047,6 +1113,7 @@
|
|||||||
nvidia: "cuda12-kokoro-development"
|
nvidia: "cuda12-kokoro-development"
|
||||||
intel: "intel-kokoro-development"
|
intel: "intel-kokoro-development"
|
||||||
amd: "rocm-kokoro-development"
|
amd: "rocm-kokoro-development"
|
||||||
|
nvidia-l4t: "nvidia-l4t-kokoro-development"
|
||||||
- !!merge <<: *kokoro
|
- !!merge <<: *kokoro
|
||||||
name: "cuda11-kokoro-development"
|
name: "cuda11-kokoro-development"
|
||||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-11-kokoro"
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-11-kokoro"
|
||||||
@@ -1072,6 +1139,16 @@
|
|||||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-kokoro"
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-kokoro"
|
||||||
mirrors:
|
mirrors:
|
||||||
- localai/localai-backends:master-gpu-intel-kokoro
|
- localai/localai-backends:master-gpu-intel-kokoro
|
||||||
|
- !!merge <<: *kokoro
|
||||||
|
name: "nvidia-l4t-kokoro"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-l4t-kokoro"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-nvidia-l4t-kokoro
|
||||||
|
- !!merge <<: *kokoro
|
||||||
|
name: "nvidia-l4t-kokoro-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-l4t-kokoro"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-gpu-nvidia-l4t-kokoro
|
||||||
- !!merge <<: *kokoro
|
- !!merge <<: *kokoro
|
||||||
name: "cuda11-kokoro"
|
name: "cuda11-kokoro"
|
||||||
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-11-kokoro"
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-11-kokoro"
|
||||||
@@ -1223,6 +1300,39 @@
|
|||||||
name: "chatterbox-development"
|
name: "chatterbox-development"
|
||||||
capabilities:
|
capabilities:
|
||||||
nvidia: "cuda12-chatterbox-development"
|
nvidia: "cuda12-chatterbox-development"
|
||||||
|
metal: "metal-chatterbox-development"
|
||||||
|
default: "cpu-chatterbox-development"
|
||||||
|
nvidia-l4t: "nvidia-l4t-arm64-chatterbox"
|
||||||
|
- !!merge <<: *chatterbox
|
||||||
|
name: "cpu-chatterbox"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-chatterbox"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-cpu-chatterbox
|
||||||
|
- !!merge <<: *chatterbox
|
||||||
|
name: "cpu-chatterbox-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-cpu-chatterbox"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-cpu-chatterbox
|
||||||
|
- !!merge <<: *chatterbox
|
||||||
|
name: "nvidia-l4t-arm64-chatterbox"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-l4t-arm64-chatterbox"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-gpu-nvidia-l4t-arm64-chatterbox
|
||||||
|
- !!merge <<: *chatterbox
|
||||||
|
name: "nvidia-l4t-arm64-chatterbox-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-l4t-arm64-chatterbox"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-gpu-nvidia-l4t-arm64-chatterbox
|
||||||
|
- !!merge <<: *chatterbox
|
||||||
|
name: "metal-chatterbox"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-chatterbox"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:latest-metal-darwin-arm64-chatterbox
|
||||||
|
- !!merge <<: *chatterbox
|
||||||
|
name: "metal-chatterbox-development"
|
||||||
|
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-chatterbox"
|
||||||
|
mirrors:
|
||||||
|
- localai/localai-backends:master-metal-darwin-arm64-chatterbox
|
||||||
- !!merge <<: *chatterbox
|
- !!merge <<: *chatterbox
|
||||||
name: "cuda12-chatterbox-development"
|
name: "cuda12-chatterbox-development"
|
||||||
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-chatterbox"
|
uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-chatterbox"
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
bark==0.1.5
|
bark==0.1.5
|
||||||
grpcio==1.74.0
|
grpcio==1.76.0
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
This is an extra gRPC server of LocalAI for Bark TTS
|
This is an extra gRPC server of LocalAI for Chatterbox TTS
|
||||||
"""
|
"""
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
import time
|
import time
|
||||||
@@ -14,15 +14,98 @@ import backend_pb2_grpc
|
|||||||
import torch
|
import torch
|
||||||
import torchaudio as ta
|
import torchaudio as ta
|
||||||
from chatterbox.tts import ChatterboxTTS
|
from chatterbox.tts import ChatterboxTTS
|
||||||
|
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
||||||
import grpc
|
import grpc
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
def is_float(s):
|
||||||
|
"""Check if a string can be converted to float."""
|
||||||
|
try:
|
||||||
|
float(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
def is_int(s):
|
||||||
|
"""Check if a string can be converted to int."""
|
||||||
|
try:
|
||||||
|
int(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def split_text_at_word_boundary(text, max_length=250):
|
||||||
|
"""
|
||||||
|
Split text at word boundaries without truncating words.
|
||||||
|
Returns a list of text chunks.
|
||||||
|
"""
|
||||||
|
if not text or len(text) <= max_length:
|
||||||
|
return [text]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
words = text.split()
|
||||||
|
current_chunk = ""
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
# Check if adding this word would exceed the limit
|
||||||
|
if len(current_chunk) + len(word) + 1 <= max_length:
|
||||||
|
if current_chunk:
|
||||||
|
current_chunk += " " + word
|
||||||
|
else:
|
||||||
|
current_chunk = word
|
||||||
|
else:
|
||||||
|
# If current chunk is not empty, add it to chunks
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk)
|
||||||
|
current_chunk = word
|
||||||
|
else:
|
||||||
|
# If a single word is longer than max_length, we have to include it anyway
|
||||||
|
chunks.append(word)
|
||||||
|
current_chunk = ""
|
||||||
|
|
||||||
|
# Add the last chunk if it's not empty
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(current_chunk)
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def merge_audio_files(audio_files, output_path, sample_rate):
|
||||||
|
"""
|
||||||
|
Merge multiple audio files into a single audio file.
|
||||||
|
"""
|
||||||
|
if not audio_files:
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(audio_files) == 1:
|
||||||
|
# If only one file, just copy it
|
||||||
|
import shutil
|
||||||
|
shutil.copy2(audio_files[0], output_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load all audio files
|
||||||
|
waveforms = []
|
||||||
|
for audio_file in audio_files:
|
||||||
|
waveform, sr = ta.load(audio_file)
|
||||||
|
if sr != sample_rate:
|
||||||
|
# Resample if necessary
|
||||||
|
resampler = ta.transforms.Resample(sr, sample_rate)
|
||||||
|
waveform = resampler(waveform)
|
||||||
|
waveforms.append(waveform)
|
||||||
|
|
||||||
|
# Concatenate all waveforms
|
||||||
|
merged_waveform = torch.cat(waveforms, dim=1)
|
||||||
|
|
||||||
|
# Save the merged audio
|
||||||
|
ta.save(output_path, merged_waveform, sample_rate)
|
||||||
|
|
||||||
|
# Clean up temporary files
|
||||||
|
for audio_file in audio_files:
|
||||||
|
if os.path.exists(audio_file):
|
||||||
|
os.remove(audio_file)
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||||
COQUI_LANGUAGE = os.environ.get('COQUI_LANGUAGE', None)
|
|
||||||
|
|
||||||
# Implement the BackendServicer class with the service methods
|
# Implement the BackendServicer class with the service methods
|
||||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
@@ -47,6 +130,28 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
if not torch.cuda.is_available() and request.CUDA:
|
if not torch.cuda.is_available() and request.CUDA:
|
||||||
return backend_pb2.Result(success=False, message="CUDA is not available")
|
return backend_pb2.Result(success=False, message="CUDA is not available")
|
||||||
|
|
||||||
|
|
||||||
|
options = request.Options
|
||||||
|
|
||||||
|
# empty dict
|
||||||
|
self.options = {}
|
||||||
|
|
||||||
|
# The options are a list of strings in this form optname:optvalue
|
||||||
|
# We are storing all the options in a dict so we can use it later when
|
||||||
|
# generating the images
|
||||||
|
for opt in options:
|
||||||
|
if ":" not in opt:
|
||||||
|
continue
|
||||||
|
key, value = opt.split(":")
|
||||||
|
# if value is a number, convert it to the appropriate type
|
||||||
|
if is_float(value):
|
||||||
|
value = float(value)
|
||||||
|
elif is_int(value):
|
||||||
|
value = int(value)
|
||||||
|
elif value.lower() in ["true", "false"]:
|
||||||
|
value = value.lower() == "true"
|
||||||
|
self.options[key] = value
|
||||||
|
|
||||||
self.AudioPath = None
|
self.AudioPath = None
|
||||||
|
|
||||||
if os.path.isabs(request.AudioPath):
|
if os.path.isabs(request.AudioPath):
|
||||||
@@ -56,10 +161,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
modelFileBase = os.path.dirname(request.ModelFile)
|
modelFileBase = os.path.dirname(request.ModelFile)
|
||||||
# modify LoraAdapter to be relative to modelFileBase
|
# modify LoraAdapter to be relative to modelFileBase
|
||||||
self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
|
self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print("Preparing models, please wait", file=sys.stderr)
|
print("Preparing models, please wait", file=sys.stderr)
|
||||||
self.model = ChatterboxTTS.from_pretrained(device=device)
|
if "multilingual" in self.options:
|
||||||
|
# remove key from options
|
||||||
|
del self.options["multilingual"]
|
||||||
|
self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
|
||||||
|
else:
|
||||||
|
self.model = ChatterboxTTS.from_pretrained(device=device)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
# Implement your logic here for the LoadModel service
|
# Implement your logic here for the LoadModel service
|
||||||
@@ -68,14 +177,43 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
|
|
||||||
def TTS(self, request, context):
|
def TTS(self, request, context):
|
||||||
try:
|
try:
|
||||||
# Generate audio using ChatterboxTTS
|
kwargs = {}
|
||||||
|
|
||||||
|
if "language" in self.options:
|
||||||
|
kwargs["language_id"] = self.options["language"]
|
||||||
if self.AudioPath is not None:
|
if self.AudioPath is not None:
|
||||||
wav = self.model.generate(request.text, audio_prompt_path=self.AudioPath)
|
kwargs["audio_prompt_path"] = self.AudioPath
|
||||||
|
|
||||||
|
# add options to kwargs
|
||||||
|
kwargs.update(self.options)
|
||||||
|
|
||||||
|
# Check if text exceeds 250 characters
|
||||||
|
# (chatterbox does not support long text)
|
||||||
|
# https://github.com/resemble-ai/chatterbox/issues/60
|
||||||
|
# https://github.com/resemble-ai/chatterbox/issues/110
|
||||||
|
if len(request.text) > 250:
|
||||||
|
# Split text at word boundaries
|
||||||
|
text_chunks = split_text_at_word_boundary(request.text, max_length=250)
|
||||||
|
print(f"Splitting text into chunks of 250 characters: {len(text_chunks)}", file=sys.stderr)
|
||||||
|
# Generate audio for each chunk
|
||||||
|
temp_audio_files = []
|
||||||
|
for i, chunk in enumerate(text_chunks):
|
||||||
|
# Generate audio for this chunk
|
||||||
|
wav = self.model.generate(chunk, **kwargs)
|
||||||
|
|
||||||
|
# Create temporary file for this chunk
|
||||||
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
|
||||||
|
temp_file.close()
|
||||||
|
ta.save(temp_file.name, wav, self.model.sr)
|
||||||
|
temp_audio_files.append(temp_file.name)
|
||||||
|
|
||||||
|
# Merge all audio files
|
||||||
|
merge_audio_files(temp_audio_files, request.dst, self.model.sr)
|
||||||
else:
|
else:
|
||||||
wav = self.model.generate(request.text)
|
# Generate audio using ChatterboxTTS for short text
|
||||||
|
wav = self.model.generate(request.text, **kwargs)
|
||||||
# Save the generated audio
|
# Save the generated audio
|
||||||
ta.save(request.dst, wav, self.model.sr)
|
ta.save(request.dst, wav, self.model.sr)
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
|
|||||||
@@ -15,5 +15,6 @@ fi
|
|||||||
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||||
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||||
fi
|
fi
|
||||||
|
EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation"
|
||||||
|
|
||||||
installRequirements
|
installRequirements
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
accelerate
|
accelerate
|
||||||
torch==2.6.0
|
torch
|
||||||
torchaudio==2.6.0
|
torchaudio
|
||||||
transformers==4.46.3
|
transformers
|
||||||
chatterbox-tts
|
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
|
||||||
|
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
|
||||||
|
#chatterbox-tts==0.1.4
|
||||||
@@ -2,5 +2,6 @@
|
|||||||
torch==2.6.0+cu118
|
torch==2.6.0+cu118
|
||||||
torchaudio==2.6.0+cu118
|
torchaudio==2.6.0+cu118
|
||||||
transformers==4.46.3
|
transformers==4.46.3
|
||||||
chatterbox-tts
|
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
|
||||||
|
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
|
||||||
accelerate
|
accelerate
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
torch==2.6.0
|
torch
|
||||||
torchaudio==2.6.0
|
torchaudio
|
||||||
transformers==4.46.3
|
transformers
|
||||||
chatterbox-tts
|
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
|
||||||
|
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
|
||||||
accelerate
|
accelerate
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
--extra-index-url https://download.pytorch.org/whl/rocm6.0
|
||||||
torch==2.6.0+rocm6.1
|
torch==2.6.0+rocm6.1
|
||||||
torchaudio==2.6.0+rocm6.1
|
torchaudio==2.6.0+rocm6.1
|
||||||
transformers==4.46.3
|
transformers
|
||||||
chatterbox-tts
|
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
|
||||||
|
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
|
||||||
accelerate
|
accelerate
|
||||||
|
|||||||
@@ -2,10 +2,10 @@
|
|||||||
intel-extension-for-pytorch==2.3.110+xpu
|
intel-extension-for-pytorch==2.3.110+xpu
|
||||||
torch==2.3.1+cxx11.abi
|
torch==2.3.1+cxx11.abi
|
||||||
torchaudio==2.3.1+cxx11.abi
|
torchaudio==2.3.1+cxx11.abi
|
||||||
transformers==4.46.3
|
transformers
|
||||||
chatterbox-tts
|
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
|
||||||
|
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
|
||||||
accelerate
|
accelerate
|
||||||
oneccl_bind_pt==2.3.100+xpu
|
oneccl_bind_pt==2.3.100+xpu
|
||||||
optimum[openvino]
|
optimum[openvino]
|
||||||
setuptools
|
setuptools
|
||||||
accelerate
|
|
||||||
6
backend/python/chatterbox/requirements-l4t.txt
Normal file
6
backend/python/chatterbox/requirements-l4t.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
|
||||||
|
torch
|
||||||
|
torchaudio
|
||||||
|
transformers
|
||||||
|
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
|
||||||
|
accelerate
|
||||||
@@ -2,4 +2,5 @@ grpcio==1.71.0
|
|||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
packaging
|
packaging
|
||||||
setuptools
|
setuptools
|
||||||
|
poetry
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
grpcio==1.74.0
|
grpcio==1.76.0
|
||||||
protobuf
|
protobuf
|
||||||
grpcio-tools
|
grpcio-tools
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
grpcio==1.74.0
|
grpcio==1.76.0
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
packaging==24.1
|
packaging==24.1
|
||||||
@@ -66,11 +66,20 @@ from diffusers.schedulers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
def is_float(s):
|
def is_float(s):
|
||||||
|
"""Check if a string can be converted to float."""
|
||||||
try:
|
try:
|
||||||
float(s)
|
float(s)
|
||||||
return True
|
return True
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
def is_int(s):
|
||||||
|
"""Check if a string can be converted to int."""
|
||||||
|
try:
|
||||||
|
int(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
|
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
|
||||||
# Credits to https://github.com/neggles
|
# Credits to https://github.com/neggles
|
||||||
@@ -177,10 +186,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
key, value = opt.split(":")
|
key, value = opt.split(":")
|
||||||
# if value is a number, convert it to the appropriate type
|
# if value is a number, convert it to the appropriate type
|
||||||
if is_float(value):
|
if is_float(value):
|
||||||
if value.is_integer():
|
value = float(value)
|
||||||
value = int(value)
|
elif is_int(value):
|
||||||
else:
|
value = int(value)
|
||||||
value = float(value)
|
elif value.lower() in ["true", "false"]:
|
||||||
|
value = value.lower() == "true"
|
||||||
self.options[key] = value
|
self.options[key] = value
|
||||||
|
|
||||||
# From options, extract if present "torch_dtype" and set it to the appropriate type
|
# From options, extract if present "torch_dtype" and set it to the appropriate type
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
setuptools
|
setuptools
|
||||||
grpcio==1.74.0
|
grpcio==1.76.0
|
||||||
pillow
|
pillow
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
This method tests if the server starts up successfully
|
This method tests if the server starts up successfully
|
||||||
"""
|
"""
|
||||||
time.sleep(10)
|
time.sleep(20)
|
||||||
try:
|
try:
|
||||||
self.setUp()
|
self.setUp()
|
||||||
with grpc.insecure_channel("localhost:50051") as channel:
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
@@ -48,7 +48,7 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
This method tests if the model is loaded successfully
|
This method tests if the model is loaded successfully
|
||||||
"""
|
"""
|
||||||
time.sleep(10)
|
time.sleep(20)
|
||||||
try:
|
try:
|
||||||
self.setUp()
|
self.setUp()
|
||||||
with grpc.insecure_channel("localhost:50051") as channel:
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
@@ -66,7 +66,7 @@ class TestBackendServicer(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
This method tests if the backend can generate images
|
This method tests if the backend can generate images
|
||||||
"""
|
"""
|
||||||
time.sleep(10)
|
time.sleep(20)
|
||||||
try:
|
try:
|
||||||
self.setUp()
|
self.setUp()
|
||||||
with grpc.insecure_channel("localhost:50051") as channel:
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
grpcio==1.74.0
|
grpcio==1.76.0
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
wheel
|
wheel
|
||||||
|
|||||||
@@ -64,15 +64,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
# Generate audio using Kokoro pipeline
|
# Generate audio using Kokoro pipeline
|
||||||
generator = self.pipeline(request.text, voice=voice)
|
generator = self.pipeline(request.text, voice=voice)
|
||||||
|
|
||||||
# Get the first (and typically only) audio segment
|
speechs = []
|
||||||
|
# Get all the audio segment
|
||||||
for i, (gs, ps, audio) in enumerate(generator):
|
for i, (gs, ps, audio) in enumerate(generator):
|
||||||
# Save audio to the destination file
|
speechs.append(audio)
|
||||||
sf.write(request.dst, audio, 24000)
|
|
||||||
print(f"Generated audio segment {i}: gs={gs}, ps={ps}", file=sys.stderr)
|
print(f"Generated audio segment {i}: gs={gs}, ps={ps}", file=sys.stderr)
|
||||||
# For now, we only process the first segment
|
# Merges the audio segments and writes them to the destination
|
||||||
# If you need to handle multiple segments, you might want to modify this
|
speech = torch.cat(speechs, dim=0)
|
||||||
break
|
sf.write(request.dst, speech, 24000)
|
||||||
|
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
|
|
||||||
|
|||||||
7
backend/python/kokoro/requirements-l4t.txt
Normal file
7
backend/python/kokoro/requirements-l4t.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
|
||||||
|
torch
|
||||||
|
torchaudio
|
||||||
|
transformers
|
||||||
|
accelerate
|
||||||
|
kokoro
|
||||||
|
soundfile
|
||||||
@@ -20,6 +20,21 @@ import soundfile as sf
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
def is_float(s):
|
||||||
|
"""Check if a string can be converted to float."""
|
||||||
|
try:
|
||||||
|
float(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
def is_int(s):
|
||||||
|
"""Check if a string can be converted to int."""
|
||||||
|
try:
|
||||||
|
int(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
@@ -32,14 +47,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
This backend provides TTS (Text-to-Speech) functionality using MLX-Audio.
|
This backend provides TTS (Text-to-Speech) functionality using MLX-Audio.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _is_float(self, s):
|
|
||||||
"""Check if a string can be converted to float."""
|
|
||||||
try:
|
|
||||||
float(s)
|
|
||||||
return True
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def Health(self, request, context):
|
def Health(self, request, context):
|
||||||
"""
|
"""
|
||||||
Returns a health check message.
|
Returns a health check message.
|
||||||
@@ -80,11 +87,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
|
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
|
||||||
|
|
||||||
# Convert numeric values to appropriate types
|
# Convert numeric values to appropriate types
|
||||||
if self._is_float(value):
|
if is_float(value):
|
||||||
if float(value).is_integer():
|
value = float(value)
|
||||||
value = int(value)
|
elif is_int(value):
|
||||||
else:
|
value = int(value)
|
||||||
value = float(value)
|
|
||||||
elif value.lower() in ["true", "false"]:
|
elif value.lower() in ["true", "false"]:
|
||||||
value = value.lower() == "true"
|
value = value.lower() == "true"
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,21 @@ import io
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
|
def is_float(s):
|
||||||
|
"""Check if a string can be converted to float."""
|
||||||
|
try:
|
||||||
|
float(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
def is_int(s):
|
||||||
|
"""Check if a string can be converted to int."""
|
||||||
|
try:
|
||||||
|
int(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
@@ -32,14 +47,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
A gRPC servicer that implements the Backend service defined in backend.proto.
|
A gRPC servicer that implements the Backend service defined in backend.proto.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _is_float(self, s):
|
|
||||||
"""Check if a string can be converted to float."""
|
|
||||||
try:
|
|
||||||
float(s)
|
|
||||||
return True
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def Health(self, request, context):
|
def Health(self, request, context):
|
||||||
"""
|
"""
|
||||||
Returns a health check message.
|
Returns a health check message.
|
||||||
@@ -79,12 +86,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
continue
|
continue
|
||||||
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
|
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
|
||||||
|
|
||||||
# Convert numeric values to appropriate types
|
if is_float(value):
|
||||||
if self._is_float(value):
|
value = float(value)
|
||||||
if float(value).is_integer():
|
elif is_int(value):
|
||||||
value = int(value)
|
value = int(value)
|
||||||
else:
|
|
||||||
value = float(value)
|
|
||||||
elif value.lower() in ["true", "false"]:
|
elif value.lower() in ["true", "false"]:
|
||||||
value = value.lower() == "true"
|
value = value.lower() == "true"
|
||||||
|
|
||||||
|
|||||||
@@ -24,20 +24,27 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
|||||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||||
|
|
||||||
|
def is_float(s):
|
||||||
|
"""Check if a string can be converted to float."""
|
||||||
|
try:
|
||||||
|
float(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
def is_int(s):
|
||||||
|
"""Check if a string can be converted to int."""
|
||||||
|
try:
|
||||||
|
int(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
# Implement the BackendServicer class with the service methods
|
# Implement the BackendServicer class with the service methods
|
||||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
"""
|
"""
|
||||||
A gRPC servicer that implements the Backend service defined in backend.proto.
|
A gRPC servicer that implements the Backend service defined in backend.proto.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _is_float(self, s):
|
|
||||||
"""Check if a string can be converted to float."""
|
|
||||||
try:
|
|
||||||
float(s)
|
|
||||||
return True
|
|
||||||
except ValueError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def Health(self, request, context):
|
def Health(self, request, context):
|
||||||
"""
|
"""
|
||||||
Returns a health check message.
|
Returns a health check message.
|
||||||
@@ -78,11 +85,10 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
|||||||
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
|
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
|
||||||
|
|
||||||
# Convert numeric values to appropriate types
|
# Convert numeric values to appropriate types
|
||||||
if self._is_float(value):
|
if is_float(value):
|
||||||
if float(value).is_integer():
|
value = float(value)
|
||||||
value = int(value)
|
elif is_int(value):
|
||||||
else:
|
value = int(value)
|
||||||
value = float(value)
|
|
||||||
elif value.lower() in ["true", "false"]:
|
elif value.lower() in ["true", "false"]:
|
||||||
value = value.lower() == "true"
|
value = value.lower() == "true"
|
||||||
|
|
||||||
|
|||||||
23
backend/python/neutts/Makefile
Normal file
23
backend/python/neutts/Makefile
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
.PHONY: neutts
|
||||||
|
neutts:
|
||||||
|
bash install.sh
|
||||||
|
|
||||||
|
.PHONY: run
|
||||||
|
run: neutts
|
||||||
|
@echo "Running neutts..."
|
||||||
|
bash run.sh
|
||||||
|
@echo "neutts run."
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
test: neutts
|
||||||
|
@echo "Testing neutts..."
|
||||||
|
bash test.sh
|
||||||
|
@echo "neutts tested."
|
||||||
|
|
||||||
|
.PHONY: protogen-clean
|
||||||
|
protogen-clean:
|
||||||
|
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||||
|
|
||||||
|
.PHONY: clean
|
||||||
|
clean: protogen-clean
|
||||||
|
rm -rf venv __pycache__
|
||||||
162
backend/python/neutts/backend.py
Normal file
162
backend/python/neutts/backend.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
This is an extra gRPC server of LocalAI for NeuTTSAir
|
||||||
|
"""
|
||||||
|
from concurrent import futures
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
import torch
|
||||||
|
from neuttsair.neutts import NeuTTSAir
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
def is_float(s):
|
||||||
|
"""Check if a string can be converted to float."""
|
||||||
|
try:
|
||||||
|
float(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
def is_int(s):
|
||||||
|
"""Check if a string can be converted to int."""
|
||||||
|
try:
|
||||||
|
int(s)
|
||||||
|
return True
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||||
|
|
||||||
|
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||||
|
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||||
|
|
||||||
|
# Implement the BackendServicer class with the service methods
|
||||||
|
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||||
|
"""
|
||||||
|
BackendServicer is the class that implements the gRPC service
|
||||||
|
"""
|
||||||
|
def Health(self, request, context):
|
||||||
|
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||||
|
def LoadModel(self, request, context):
|
||||||
|
|
||||||
|
# Get device
|
||||||
|
# device = "cuda" if request.CUDA else "cpu"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print("CUDA is available", file=sys.stderr)
|
||||||
|
device = "cuda"
|
||||||
|
else:
|
||||||
|
print("CUDA is not available", file=sys.stderr)
|
||||||
|
device = "cpu"
|
||||||
|
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||||
|
if mps_available:
|
||||||
|
device = "mps"
|
||||||
|
if not torch.cuda.is_available() and request.CUDA:
|
||||||
|
return backend_pb2.Result(success=False, message="CUDA is not available")
|
||||||
|
|
||||||
|
|
||||||
|
options = request.Options
|
||||||
|
|
||||||
|
# empty dict
|
||||||
|
self.options = {}
|
||||||
|
self.ref_text = None
|
||||||
|
|
||||||
|
# The options are a list of strings in this form optname:optvalue
|
||||||
|
# We are storing all the options in a dict so we can use it later when
|
||||||
|
# generating the images
|
||||||
|
for opt in options:
|
||||||
|
if ":" not in opt:
|
||||||
|
continue
|
||||||
|
key, value = opt.split(":")
|
||||||
|
# if value is a number, convert it to the appropriate type
|
||||||
|
if is_float(value):
|
||||||
|
value = float(value)
|
||||||
|
elif is_int(value):
|
||||||
|
value = int(value)
|
||||||
|
elif value.lower() in ["true", "false"]:
|
||||||
|
value = value.lower() == "true"
|
||||||
|
self.options[key] = value
|
||||||
|
|
||||||
|
codec_repo = "neuphonic/neucodec"
|
||||||
|
if "codec_repo" in self.options:
|
||||||
|
codec_repo = self.options["codec_repo"]
|
||||||
|
del self.options["codec_repo"]
|
||||||
|
if "ref_text" in self.options:
|
||||||
|
self.ref_text = self.options["ref_text"]
|
||||||
|
del self.options["ref_text"]
|
||||||
|
|
||||||
|
self.AudioPath = None
|
||||||
|
|
||||||
|
if os.path.isabs(request.AudioPath):
|
||||||
|
self.AudioPath = request.AudioPath
|
||||||
|
elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
|
||||||
|
# get base path of modelFile
|
||||||
|
modelFileBase = os.path.dirname(request.ModelFile)
|
||||||
|
# modify LoraAdapter to be relative to modelFileBase
|
||||||
|
self.AudioPath = os.path.join(modelFileBase, request.AudioPath)
|
||||||
|
try:
|
||||||
|
print("Preparing models, please wait", file=sys.stderr)
|
||||||
|
self.model = NeuTTSAir(backbone_repo=request.Model, backbone_device=device, codec_repo=codec_repo, codec_device=device)
|
||||||
|
except Exception as err:
|
||||||
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
|
# Implement your logic here for the LoadModel service
|
||||||
|
# Replace this with your desired response
|
||||||
|
return backend_pb2.Result(message="Model loaded successfully", success=True)
|
||||||
|
|
||||||
|
def TTS(self, request, context):
|
||||||
|
try:
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
# add options to kwargs
|
||||||
|
kwargs.update(self.options)
|
||||||
|
|
||||||
|
ref_codes = self.model.encode_reference(self.AudioPath)
|
||||||
|
|
||||||
|
wav = self.model.infer(request.text, ref_codes, self.ref_text)
|
||||||
|
|
||||||
|
sf.write(request.dst, wav, 24000)
|
||||||
|
except Exception as err:
|
||||||
|
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||||
|
return backend_pb2.Result(success=True)
|
||||||
|
|
||||||
|
def serve(address):
|
||||||
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||||
|
options=[
|
||||||
|
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
|
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
|
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||||
|
])
|
||||||
|
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||||
|
server.add_insecure_port(address)
|
||||||
|
server.start()
|
||||||
|
print("Server started. Listening on: " + address, file=sys.stderr)
|
||||||
|
|
||||||
|
# Define the signal handler function
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
print("Received termination signal. Shutting down...")
|
||||||
|
server.stop(0)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Set the signal handlers for SIGINT and SIGTERM
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
signal.signal(signal.SIGTERM, signal_handler)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
time.sleep(_ONE_DAY_IN_SECONDS)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
server.stop(0)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
serve(args.addr)
|
||||||
33
backend/python/neutts/install.sh
Executable file
33
backend/python/neutts/install.sh
Executable file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
|
||||||
|
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
|
||||||
|
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
|
||||||
|
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
|
||||||
|
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
|
||||||
|
EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "x${BUILD_TYPE}" == "xcublas" ] || [ "x${BUILD_TYPE}" == "xl4t" ]; then
|
||||||
|
export CMAKE_ARGS="-DGGML_CUDA=on"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "x${BUILD_TYPE}" == "xhipblas" ]; then
|
||||||
|
export CMAKE_ARGS="-DGGML_HIPBLAS=on"
|
||||||
|
fi
|
||||||
|
|
||||||
|
EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation"
|
||||||
|
|
||||||
|
git clone https://github.com/neuphonic/neutts-air neutts-air
|
||||||
|
|
||||||
|
cp -rfv neutts-air/neuttsair ./
|
||||||
|
|
||||||
|
installRequirements
|
||||||
2
backend/python/neutts/requirements-after.txt
Normal file
2
backend/python/neutts/requirements-after.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
datasets==4.1.1
|
||||||
|
torchtune==0.6.1
|
||||||
10
backend/python/neutts/requirements-cpu.txt
Normal file
10
backend/python/neutts/requirements-cpu.txt
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/cpu
|
||||||
|
accelerate
|
||||||
|
torch==2.8.0
|
||||||
|
transformers==4.56.1
|
||||||
|
librosa==0.11.0
|
||||||
|
neucodec>=0.0.4
|
||||||
|
phonemizer==3.3.0
|
||||||
|
soundfile==0.13.1
|
||||||
|
resemble-perth==1.0.1
|
||||||
|
llama-cpp-python
|
||||||
8
backend/python/neutts/requirements-cublas12.txt
Normal file
8
backend/python/neutts/requirements-cublas12.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
librosa==0.11.0
|
||||||
|
neucodec>=0.0.4
|
||||||
|
phonemizer==3.3.0
|
||||||
|
soundfile==0.13.1
|
||||||
|
torch==2.8.0
|
||||||
|
transformers==4.56.1
|
||||||
|
resemble-perth==1.0.1
|
||||||
|
accelerate
|
||||||
10
backend/python/neutts/requirements-hipblas.txt
Normal file
10
backend/python/neutts/requirements-hipblas.txt
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
--extra-index-url https://download.pytorch.org/whl/rocm6.3
|
||||||
|
torch==2.8.0+rocm6.3
|
||||||
|
transformers==4.56.1
|
||||||
|
accelerate
|
||||||
|
librosa==0.11.0
|
||||||
|
neucodec>=0.0.4
|
||||||
|
phonemizer==3.3.0
|
||||||
|
soundfile==0.13.1
|
||||||
|
resemble-perth==1.0.1
|
||||||
|
llama-cpp-python
|
||||||
10
backend/python/neutts/requirements-l4t.txt
Normal file
10
backend/python/neutts/requirements-l4t.txt
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
--extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/
|
||||||
|
torch
|
||||||
|
transformers
|
||||||
|
accelerate
|
||||||
|
librosa==0.11.0
|
||||||
|
neucodec>=0.0.4
|
||||||
|
phonemizer==3.3.0
|
||||||
|
soundfile==0.13.1
|
||||||
|
resemble-perth==1.0.1
|
||||||
|
llama-cpp-python
|
||||||
7
backend/python/neutts/requirements.txt
Normal file
7
backend/python/neutts/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
grpcio==1.71.0
|
||||||
|
protobuf
|
||||||
|
certifi
|
||||||
|
packaging
|
||||||
|
setuptools
|
||||||
|
numpy==2.2.6
|
||||||
|
scikit_build_core
|
||||||
10
backend/python/neutts/run.sh
Executable file
10
backend/python/neutts/run.sh
Executable file
@@ -0,0 +1,10 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
startBackend $@
|
||||||
82
backend/python/neutts/test.py
Normal file
82
backend/python/neutts/test.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""
|
||||||
|
A test script to test the gRPC service
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
import backend_pb2
|
||||||
|
import backend_pb2_grpc
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackendServicer(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
TestBackendServicer is the class that tests the gRPC service
|
||||||
|
"""
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
This method sets up the gRPC service by starting the server
|
||||||
|
"""
|
||||||
|
self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
|
||||||
|
time.sleep(30)
|
||||||
|
|
||||||
|
def tearDown(self) -> None:
|
||||||
|
"""
|
||||||
|
This method tears down the gRPC service by terminating the server
|
||||||
|
"""
|
||||||
|
self.service.terminate()
|
||||||
|
self.service.wait()
|
||||||
|
|
||||||
|
def test_server_startup(self):
|
||||||
|
"""
|
||||||
|
This method tests if the server starts up successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.Health(backend_pb2.HealthMessage())
|
||||||
|
self.assertEqual(response.message, b'OK')
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("Server failed to start")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_load_model(self):
|
||||||
|
"""
|
||||||
|
This method tests if the model is loaded successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.LoadModel(backend_pb2.ModelOptions())
|
||||||
|
print(response)
|
||||||
|
self.assertTrue(response.success)
|
||||||
|
self.assertEqual(response.message, "Model loaded successfully")
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("LoadModel service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
|
|
||||||
|
def test_tts(self):
|
||||||
|
"""
|
||||||
|
This method tests if the embeddings are generated successfully
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.setUp()
|
||||||
|
with grpc.insecure_channel("localhost:50051") as channel:
|
||||||
|
stub = backend_pb2_grpc.BackendStub(channel)
|
||||||
|
response = stub.LoadModel(backend_pb2.ModelOptions())
|
||||||
|
self.assertTrue(response.success)
|
||||||
|
tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
|
||||||
|
tts_response = stub.TTS(tts_request)
|
||||||
|
self.assertIsNotNone(tts_response)
|
||||||
|
except Exception as err:
|
||||||
|
print(err)
|
||||||
|
self.fail("TTS service failed")
|
||||||
|
finally:
|
||||||
|
self.tearDown()
|
||||||
11
backend/python/neutts/test.sh
Executable file
11
backend/python/neutts/test.sh
Executable file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
backend_dir=$(dirname $0)
|
||||||
|
if [ -d $backend_dir/common ]; then
|
||||||
|
source $backend_dir/common/libbackend.sh
|
||||||
|
else
|
||||||
|
source $backend_dir/../common/libbackend.sh
|
||||||
|
fi
|
||||||
|
|
||||||
|
runUnittests
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
grpcio==1.74.0
|
grpcio==1.76.0
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
grpcio==1.74.0
|
grpcio==1.76.0
|
||||||
protobuf==6.32.0
|
protobuf==6.32.0
|
||||||
certifi
|
certifi
|
||||||
setuptools
|
setuptools
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
grpcio==1.74.0
|
grpcio==1.76.0
|
||||||
protobuf
|
protobuf
|
||||||
certifi
|
certifi
|
||||||
setuptools
|
setuptools
|
||||||
@@ -31,6 +31,7 @@ type Config struct {
|
|||||||
StartOnBoot bool `json:"start_on_boot"`
|
StartOnBoot bool `json:"start_on_boot"`
|
||||||
LogLevel string `json:"log_level"`
|
LogLevel string `json:"log_level"`
|
||||||
EnvironmentVars map[string]string `json:"environment_vars"`
|
EnvironmentVars map[string]string `json:"environment_vars"`
|
||||||
|
ShowWelcome *bool `json:"show_welcome"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Launcher represents the main launcher application
|
// Launcher represents the main launcher application
|
||||||
@@ -148,6 +149,13 @@ func (l *Launcher) Initialize() error {
|
|||||||
log.Printf("Initializing empty EnvironmentVars map")
|
log.Printf("Initializing empty EnvironmentVars map")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set default welcome window preference
|
||||||
|
if l.config.ShowWelcome == nil {
|
||||||
|
true := true
|
||||||
|
l.config.ShowWelcome = &true
|
||||||
|
log.Printf("Setting default ShowWelcome: true")
|
||||||
|
}
|
||||||
|
|
||||||
// Create directories
|
// Create directories
|
||||||
os.MkdirAll(l.config.ModelsPath, 0755)
|
os.MkdirAll(l.config.ModelsPath, 0755)
|
||||||
os.MkdirAll(l.config.BackendsPath, 0755)
|
os.MkdirAll(l.config.BackendsPath, 0755)
|
||||||
|
|||||||
@@ -48,6 +48,14 @@ var _ = Describe("Launcher", func() {
|
|||||||
config := launcherInstance.GetConfig()
|
config := launcherInstance.GetConfig()
|
||||||
Expect(config.ModelsPath).ToNot(BeEmpty())
|
Expect(config.ModelsPath).ToNot(BeEmpty())
|
||||||
Expect(config.BackendsPath).ToNot(BeEmpty())
|
Expect(config.BackendsPath).ToNot(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("should set default ShowWelcome to true", func() {
|
||||||
|
err := launcherInstance.Initialize()
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
config := launcherInstance.GetConfig()
|
||||||
|
Expect(config.ShowWelcome).To(BeTrue())
|
||||||
Expect(config.Address).To(Equal("127.0.0.1:8080"))
|
Expect(config.Address).To(Equal("127.0.0.1:8080"))
|
||||||
Expect(config.LogLevel).To(Equal("info"))
|
Expect(config.LogLevel).To(Equal("info"))
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -177,6 +177,9 @@ func (sm *SystrayManager) recreateMenu() {
|
|||||||
fyne.NewMenuItem("Settings", func() {
|
fyne.NewMenuItem("Settings", func() {
|
||||||
sm.showSettings()
|
sm.showSettings()
|
||||||
}),
|
}),
|
||||||
|
fyne.NewMenuItem("Show Welcome Window", func() {
|
||||||
|
sm.showWelcomeWindow()
|
||||||
|
}),
|
||||||
fyne.NewMenuItem("Open Data Folder", func() {
|
fyne.NewMenuItem("Open Data Folder", func() {
|
||||||
sm.openDataFolder()
|
sm.openDataFolder()
|
||||||
}),
|
}),
|
||||||
@@ -243,6 +246,13 @@ func (sm *SystrayManager) showSettings() {
|
|||||||
sm.window.RequestFocus()
|
sm.window.RequestFocus()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// showWelcomeWindow shows the welcome window
|
||||||
|
func (sm *SystrayManager) showWelcomeWindow() {
|
||||||
|
if sm.launcher.GetUI() != nil {
|
||||||
|
sm.launcher.GetUI().ShowWelcomeWindow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// openDataFolder opens the data folder in file manager
|
// openDataFolder opens the data folder in file manager
|
||||||
func (sm *SystrayManager) openDataFolder() {
|
func (sm *SystrayManager) openDataFolder() {
|
||||||
dataPath := sm.launcher.GetDataPath()
|
dataPath := sm.launcher.GetDataPath()
|
||||||
|
|||||||
@@ -675,3 +675,121 @@ func (ui *LauncherUI) UpdateRunningState(isRunning bool) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ShowWelcomeWindow displays the welcome window with helpful information
|
||||||
|
func (ui *LauncherUI) ShowWelcomeWindow() {
|
||||||
|
if ui.launcher == nil || ui.launcher.window == nil {
|
||||||
|
log.Printf("Cannot show welcome window: launcher or window is nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fyne.DoAndWait(func() {
|
||||||
|
// Create welcome window
|
||||||
|
welcomeWindow := ui.launcher.app.NewWindow("Welcome to LocalAI Launcher")
|
||||||
|
welcomeWindow.Resize(fyne.NewSize(600, 500))
|
||||||
|
welcomeWindow.CenterOnScreen()
|
||||||
|
welcomeWindow.SetCloseIntercept(func() {
|
||||||
|
welcomeWindow.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
// Title
|
||||||
|
titleLabel := widget.NewLabel("Welcome to LocalAI Launcher!")
|
||||||
|
titleLabel.TextStyle = fyne.TextStyle{Bold: true}
|
||||||
|
titleLabel.Alignment = fyne.TextAlignCenter
|
||||||
|
|
||||||
|
// Welcome message
|
||||||
|
welcomeText := `LocalAI Launcher makes it easy to run LocalAI on your system.
|
||||||
|
|
||||||
|
What you can do:
|
||||||
|
• Start and stop LocalAI server
|
||||||
|
• Configure models and backends paths
|
||||||
|
• Set environment variables
|
||||||
|
• Check for updates automatically
|
||||||
|
• Access LocalAI WebUI when running
|
||||||
|
|
||||||
|
Getting Started:
|
||||||
|
1. Configure your models and backends paths
|
||||||
|
2. Click "Start LocalAI" to begin
|
||||||
|
3. Use "Open WebUI" to access the interface
|
||||||
|
4. Check the system tray for quick access`
|
||||||
|
|
||||||
|
welcomeLabel := widget.NewLabel(welcomeText)
|
||||||
|
welcomeLabel.Wrapping = fyne.TextWrapWord
|
||||||
|
|
||||||
|
// Useful links section
|
||||||
|
linksTitle := widget.NewLabel("Useful Links:")
|
||||||
|
linksTitle.TextStyle = fyne.TextStyle{Bold: true}
|
||||||
|
|
||||||
|
// Create link buttons
|
||||||
|
docsButton := widget.NewButton("📚 Documentation", func() {
|
||||||
|
ui.openURL("https://localai.io/docs/")
|
||||||
|
})
|
||||||
|
|
||||||
|
githubButton := widget.NewButton("🐙 GitHub Repository", func() {
|
||||||
|
ui.openURL("https://github.com/mudler/LocalAI")
|
||||||
|
})
|
||||||
|
|
||||||
|
modelsButton := widget.NewButton("🤖 Model Gallery", func() {
|
||||||
|
ui.openURL("https://localai.io/models/")
|
||||||
|
})
|
||||||
|
|
||||||
|
communityButton := widget.NewButton("💬 Community", func() {
|
||||||
|
ui.openURL("https://discord.gg/XgwjKptP7Z")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Checkbox to disable welcome window
|
||||||
|
dontShowAgainCheck := widget.NewCheck("Don't show this welcome window again", func(checked bool) {
|
||||||
|
if ui.launcher != nil {
|
||||||
|
config := ui.launcher.GetConfig()
|
||||||
|
v := !checked
|
||||||
|
config.ShowWelcome = &v
|
||||||
|
ui.launcher.SetConfig(config)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
config := ui.launcher.GetConfig()
|
||||||
|
if config.ShowWelcome != nil {
|
||||||
|
dontShowAgainCheck.SetChecked(*config.ShowWelcome)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close button
|
||||||
|
closeButton := widget.NewButton("Get Started", func() {
|
||||||
|
welcomeWindow.Close()
|
||||||
|
})
|
||||||
|
closeButton.Importance = widget.HighImportance
|
||||||
|
|
||||||
|
// Layout
|
||||||
|
linksContainer := container.NewVBox(
|
||||||
|
linksTitle,
|
||||||
|
docsButton,
|
||||||
|
githubButton,
|
||||||
|
modelsButton,
|
||||||
|
communityButton,
|
||||||
|
)
|
||||||
|
|
||||||
|
content := container.NewVBox(
|
||||||
|
titleLabel,
|
||||||
|
widget.NewSeparator(),
|
||||||
|
welcomeLabel,
|
||||||
|
widget.NewSeparator(),
|
||||||
|
linksContainer,
|
||||||
|
widget.NewSeparator(),
|
||||||
|
dontShowAgainCheck,
|
||||||
|
widget.NewSeparator(),
|
||||||
|
closeButton,
|
||||||
|
)
|
||||||
|
|
||||||
|
welcomeWindow.SetContent(content)
|
||||||
|
welcomeWindow.Show()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// openURL opens a URL in the default browser
|
||||||
|
func (ui *LauncherUI) openURL(urlString string) {
|
||||||
|
parsedURL, err := url.Parse(urlString)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to parse URL %s: %v", urlString, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fyne.CurrentApp().OpenURL(parsedURL)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,14 +2,12 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"fyne.io/fyne/v2"
|
"fyne.io/fyne/v2"
|
||||||
"fyne.io/fyne/v2/app"
|
"fyne.io/fyne/v2/app"
|
||||||
"fyne.io/fyne/v2/driver/desktop"
|
"fyne.io/fyne/v2/driver/desktop"
|
||||||
coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal"
|
coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal"
|
||||||
|
"github.com/mudler/LocalAI/pkg/signals"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -42,7 +40,12 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Setup signal handling for graceful shutdown
|
// Setup signal handling for graceful shutdown
|
||||||
setupSignalHandling(launcher)
|
signals.RegisterGracefulTerminationHandler(func() {
|
||||||
|
// Perform cleanup
|
||||||
|
if err := launcher.Shutdown(); err != nil {
|
||||||
|
log.Printf("Error during shutdown: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Initialize the launcher state
|
// Initialize the launcher state
|
||||||
go func() {
|
go func() {
|
||||||
@@ -55,32 +58,15 @@ func main() {
|
|||||||
// Load configuration into UI
|
// Load configuration into UI
|
||||||
launcher.GetUI().LoadConfiguration()
|
launcher.GetUI().LoadConfiguration()
|
||||||
launcher.GetUI().UpdateStatus("Ready")
|
launcher.GetUI().UpdateStatus("Ready")
|
||||||
|
|
||||||
|
// Show welcome window if configured to do so
|
||||||
|
config := launcher.GetConfig()
|
||||||
|
if *config.ShowWelcome {
|
||||||
|
launcher.GetUI().ShowWelcomeWindow()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Run the application in background (window only shown when "Settings" is clicked)
|
// Run the application in background (window only shown when "Settings" is clicked)
|
||||||
myApp.Run()
|
myApp.Run()
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupSignalHandling sets up signal handlers for graceful shutdown
|
|
||||||
func setupSignalHandling(launcher *coreLauncher.Launcher) {
|
|
||||||
// Create a channel to receive OS signals
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
|
||||||
|
|
||||||
// Register for interrupt and terminate signals
|
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
|
|
||||||
// Handle signals in a separate goroutine
|
|
||||||
go func() {
|
|
||||||
sig := <-sigChan
|
|
||||||
log.Printf("Received signal %v, shutting down gracefully...", sig)
|
|
||||||
|
|
||||||
// Perform cleanup
|
|
||||||
if err := launcher.Shutdown(); err != nil {
|
|
||||||
log.Printf("Error during shutdown: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exit the application
|
|
||||||
os.Exit(0)
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/alecthomas/kong"
|
"github.com/alecthomas/kong"
|
||||||
"github.com/joho/godotenv"
|
"github.com/joho/godotenv"
|
||||||
@@ -24,15 +22,7 @@ func main() {
|
|||||||
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
|
||||||
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
zerolog.SetGlobalLevel(zerolog.InfoLevel)
|
||||||
|
|
||||||
// Catch signals from the OS requesting us to exit
|
// handle loading environment variables from .env files
|
||||||
go func() {
|
|
||||||
c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked
|
|
||||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
|
||||||
<-c
|
|
||||||
os.Exit(1)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// handle loading environment variabled from .env files
|
|
||||||
envFiles := []string{".env", "localai.env"}
|
envFiles := []string{".env", "localai.env"}
|
||||||
homeDir, err := os.UserHomeDir()
|
homeDir, err := os.UserHomeDir()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|||||||
@@ -129,7 +129,6 @@ func grpcModelOpts(c config.ModelConfig) *pb.ModelOptions {
|
|||||||
triggers = append(triggers, &pb.GrammarTrigger{
|
triggers = append(triggers, &pb.GrammarTrigger{
|
||||||
Word: t.Word,
|
Word: t.Word,
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &pb.ModelOptions{
|
return &pb.ModelOptions{
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ func SoundGeneration(
|
|||||||
|
|
||||||
// return RPC error if any
|
// return RPC error if any
|
||||||
if !res.Success {
|
if !res.Success {
|
||||||
return "", nil, fmt.Errorf(res.Message)
|
return "", nil, fmt.Errorf("error during sound generation: %s", res.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
return filePath, res, err
|
return filePath, res, err
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/mudler/LocalAI/pkg/model"
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
func ModelTranscription(audio, language string, translate bool, diarize bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||||
|
|
||||||
if modelConfig.Backend == "" {
|
if modelConfig.Backend == "" {
|
||||||
modelConfig.Backend = model.WhisperBackend
|
modelConfig.Backend = model.WhisperBackend
|
||||||
@@ -34,6 +34,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
|
|||||||
Dst: audio,
|
Dst: audio,
|
||||||
Language: language,
|
Language: language,
|
||||||
Translate: translate,
|
Translate: translate,
|
||||||
|
Diarize: diarize,
|
||||||
Threads: uint32(*modelConfig.Threads),
|
Threads: uint32(*modelConfig.Threads),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func ModelTTS(
|
|||||||
|
|
||||||
// return RPC error if any
|
// return RPC error if any
|
||||||
if !res.Success {
|
if !res.Success {
|
||||||
return "", nil, fmt.Errorf(res.Message)
|
return "", nil, fmt.Errorf("error during TTS: %s", res.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
return filePath, res, err
|
return filePath, res, err
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
"github.com/mudler/LocalAI/core/explorer"
|
"github.com/mudler/LocalAI/core/explorer"
|
||||||
"github.com/mudler/LocalAI/core/http"
|
"github.com/mudler/LocalAI/core/http"
|
||||||
|
"github.com/mudler/LocalAI/pkg/signals"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ExplorerCMD struct {
|
type ExplorerCMD struct {
|
||||||
@@ -45,5 +47,11 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
|
|||||||
|
|
||||||
appHTTP := http.Explorer(db)
|
appHTTP := http.Explorer(db)
|
||||||
|
|
||||||
|
signals.RegisterGracefulTerminationHandler(func() {
|
||||||
|
if err := appHTTP.Shutdown(); err != nil {
|
||||||
|
log.Error().Err(err).Msg("error during shutdown")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
return appHTTP.Listen(e.Address)
|
return appHTTP.Listen(e.Address)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
|
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
"github.com/mudler/LocalAI/core/p2p"
|
"github.com/mudler/LocalAI/core/p2p"
|
||||||
|
"github.com/mudler/LocalAI/pkg/signals"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FederatedCLI struct {
|
type FederatedCLI struct {
|
||||||
@@ -19,5 +20,11 @@ func (f *FederatedCLI) Run(ctx *cliContext.Context) error {
|
|||||||
|
|
||||||
fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker)
|
fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker)
|
||||||
|
|
||||||
return fs.Start(context.Background())
|
c, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
signals.RegisterGracefulTerminationHandler(func() {
|
||||||
|
cancel()
|
||||||
|
})
|
||||||
|
|
||||||
|
return fs.Start(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/mudler/LocalAI/core/http"
|
"github.com/mudler/LocalAI/core/http"
|
||||||
"github.com/mudler/LocalAI/core/p2p"
|
"github.com/mudler/LocalAI/core/p2p"
|
||||||
"github.com/mudler/LocalAI/internal"
|
"github.com/mudler/LocalAI/internal"
|
||||||
|
"github.com/mudler/LocalAI/pkg/signals"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@@ -126,6 +127,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
|
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
|
||||||
config.WithLoadToMemory(r.LoadToMemory),
|
config.WithLoadToMemory(r.LoadToMemory),
|
||||||
config.WithMachineTag(r.MachineTag),
|
config.WithMachineTag(r.MachineTag),
|
||||||
|
config.WithAPIAddress(r.Address),
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.DisableMetricsEndpoint {
|
if r.DisableMetricsEndpoint {
|
||||||
@@ -224,5 +226,11 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
signals.RegisterGracefulTerminationHandler(func() {
|
||||||
|
if err := app.ModelLoader().StopAllGRPC(); err != nil {
|
||||||
|
log.Error().Err(err).Msg("error while stopping all grpc backends")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
return appHTTP.Listen(r.Address)
|
return appHTTP.Listen(r.Address)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ type TranscriptCMD struct {
|
|||||||
Model string `short:"m" required:"" help:"Model name to run the TTS"`
|
Model string `short:"m" required:"" help:"Model name to run the TTS"`
|
||||||
Language string `short:"l" help:"Language of the audio file"`
|
Language string `short:"l" help:"Language of the audio file"`
|
||||||
Translate bool `short:"c" help:"Translate the transcription to english"`
|
Translate bool `short:"c" help:"Translate the transcription to english"`
|
||||||
|
Diarize bool `short:"d" help:"Mark speaker turns"`
|
||||||
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
|
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
|
||||||
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
||||||
}
|
}
|
||||||
@@ -56,7 +57,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, ml, c, opts)
|
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, ml, c, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package worker
|
|||||||
|
|
||||||
type WorkerFlags struct {
|
type WorkerFlags struct {
|
||||||
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"`
|
||||||
|
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
|
||||||
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/usr/share/localai/backends" help:"Path containing system backends used for inferencing" group:"backends"`
|
||||||
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package worker
|
package worker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -9,7 +10,9 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
"github.com/mudler/LocalAI/core/gallery"
|
"github.com/mudler/LocalAI/core/gallery"
|
||||||
|
"github.com/mudler/LocalAI/pkg/model"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
@@ -20,9 +23,10 @@ type LLamaCPP struct {
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
llamaCPPRPCBinaryName = "llama-cpp-rpc-server"
|
llamaCPPRPCBinaryName = "llama-cpp-rpc-server"
|
||||||
|
llamaCPPGalleryName = "llama-cpp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func findLLamaCPPBackend(systemState *system.SystemState) (string, error) {
|
func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (string, error) {
|
||||||
backends, err := gallery.ListSystemBackends(systemState)
|
backends, err := gallery.ListSystemBackends(systemState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Msgf("Failed listing system backends: %s", err)
|
log.Warn().Msgf("Failed listing system backends: %s", err)
|
||||||
@@ -30,9 +34,19 @@ func findLLamaCPPBackend(systemState *system.SystemState) (string, error) {
|
|||||||
}
|
}
|
||||||
log.Debug().Msgf("System backends: %v", backends)
|
log.Debug().Msgf("System backends: %v", backends)
|
||||||
|
|
||||||
backend, ok := backends.Get("llama-cpp")
|
backend, ok := backends.Get(llamaCPPGalleryName)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", errors.New("llama-cpp backend not found, install it first")
|
ml := model.NewModelLoader(systemState, true)
|
||||||
|
var gals []config.Gallery
|
||||||
|
if err := json.Unmarshal([]byte(galleries), &gals); err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed loading galleries")
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
backendPath := filepath.Dir(backend.RunFile)
|
backendPath := filepath.Dir(backend.RunFile)
|
||||||
|
|
||||||
@@ -61,7 +75,7 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
grpcProcess, err := findLLamaCPPBackend(systemState)
|
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -69,6 +83,7 @@ func (r *LLamaCPP) Run(ctx *cliContext.Context) error {
|
|||||||
args := strings.Split(r.ExtraLLamaCPPArgs, " ")
|
args := strings.Split(r.ExtraLLamaCPPArgs, " ")
|
||||||
|
|
||||||
args = append([]string{grpcProcess}, args...)
|
args = append([]string{grpcProcess}, args...)
|
||||||
|
|
||||||
return syscall.Exec(
|
return syscall.Exec(
|
||||||
grpcProcess,
|
grpcProcess,
|
||||||
args,
|
args,
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
cliContext "github.com/mudler/LocalAI/core/cli/context"
|
||||||
"github.com/mudler/LocalAI/core/p2p"
|
"github.com/mudler/LocalAI/core/p2p"
|
||||||
|
"github.com/mudler/LocalAI/pkg/signals"
|
||||||
"github.com/mudler/LocalAI/pkg/system"
|
"github.com/mudler/LocalAI/pkg/system"
|
||||||
"github.com/phayes/freeport"
|
"github.com/phayes/freeport"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@@ -47,6 +48,9 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
|||||||
|
|
||||||
address := "127.0.0.1"
|
address := "127.0.0.1"
|
||||||
|
|
||||||
|
c, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
if r.NoRunner {
|
if r.NoRunner {
|
||||||
// Let override which port and address to bind if the user
|
// Let override which port and address to bind if the user
|
||||||
// configure the llama-cpp service on its own
|
// configure the llama-cpp service on its own
|
||||||
@@ -58,7 +62,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
|||||||
p = r.RunnerPort
|
p = r.RunnerPort
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = p2p.ExposeService(context.Background(), address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
_, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -69,7 +73,7 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
|||||||
for {
|
for {
|
||||||
log.Info().Msgf("Starting llama-cpp-rpc-server on '%s:%d'", address, port)
|
log.Info().Msgf("Starting llama-cpp-rpc-server on '%s:%d'", address, port)
|
||||||
|
|
||||||
grpcProcess, err := findLLamaCPPBackend(systemState)
|
grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to find llama-cpp-rpc-server")
|
log.Error().Err(err).Msg("Failed to find llama-cpp-rpc-server")
|
||||||
return
|
return
|
||||||
@@ -100,12 +104,16 @@ func (r *P2P) Run(ctx *cliContext.Context) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err = p2p.ExposeService(context.Background(), address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
_, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
signals.RegisterGracefulTerminationHandler(func() {
|
||||||
|
cancel()
|
||||||
|
})
|
||||||
|
|
||||||
for {
|
for {
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user