Mirror of https://github.com/mudler/LocalAI.git (synced 2026-02-03 11:13:31 -05:00)

Compare commits: feat/stats...v3.8.0 (168 commits)

SHA1
c0d1d0211f
f617bec686
7a94d237c4
304ac94d01
f9f9b9d444
70d78b9fd4
91248da09e
745c31e013
7e01aa8faa
aceebf81d6
71ed03102f
f6d2a52cd5
05a00b2399
3a232446e0
bdfe8431fa
55607a5aac
ec492a4c56
2defe98df8
b406b088a7
6261c87b1b
fa00aa0085
0e53ce60b4
8aba078439
e88db7d142
b7b8a0a748
dd2828241c
b8011f49f2
16e5689162
2dd42292dc
53d51671d7
daf39e1efd
382474e4a1
5fed9c6596
bfa07df7cd
fbaa21b0e5
95b6c9bb5a
2cc4809b0d
77bbeed57e
3152611184
30f992f241
2709220b84
4278506876
1dd1d12da1
3a5b3bb0a6
94d9fc923f
6fcf2c50b6
7cbd4a2f18
18d11396cd
93cd688f40
721c3f962b
fb834805db
839aa7b42b
e963a45d66
c313b2c671
137f16336e
d7f9f3ac93
cd7d384500
d1a0dd10e6
be8cf838c2
3276d1cdaf
5e5f01badd
6d0f646c37
99d31667f8
47b546afdc
a09d49da43
1cdcaf0152
03e9f4b140
7129409bf6
d9e9ec6825
b82645d28d
735ca757fa
b1d1f2a37d
3728552e94
87d0020c10
a8eb537071
04fe0b0da8
fae93e5ba2
b606034243
5f4663252d
80bb7c5f67
f6881ea023
5651a19aa1
c834cdb826
fa2caef63d
31abc799f9
2368395a0c
bf77c11b65
8876073f5c
8432915cb8
9ddb94b507
e42f0f7e79
34bc1bda1e
01cd58a739
679d43c2f5
4730b52461
f678c6b0a9
2f2f9beee7
8ac7e28c12
c5c3538115
5ef16b5693
02cc8cbcaa
e5e86d0acb
edd35d2b33
e8cc29e364
8f7c499f17
ea446fde08
122e4c7094
2573102317
41b60fcfd3
cb81869140
db9957b94e
98158881c2
79247a5d17
46b7a4c5f2
436e2d91d0
a86fdc4087
c7ac6ca687
7088327e8d
e2cb44ef37
3a40b4129c
4ca8055f21
704786cc6d
e5ce1fd9cc
ea2037f141
567fa62330
d424a27fa2
3ce9cb566d
ee7638a9b0
e57e50e441
81880e7975
2cad2c8591
b87b41ee45
424acd66ad
3cd8234550
c70a0f05b8
f85e2dd1b8
e485bdf9ab
495c4ee694
161d1a0344
b6d1def96f
9ecfdc5938
c332ef5cce
6e7a8c6041
43e707ec4f
fed3663a74
5b72798db3
d24d6d4e93
50ee1fbe06
19f3425ce0
a6ef245534
88cb379c2d
0ddb2e8dcf
91b9301bec
fad5868f7b
1e5b9135df
36d19e23e0
cba9d1aac0
dd21a0d2f9
302a43b3ae
2955061b42
84644ab693
b8f40dde1e
a6c9789a54
a48d9ce27c
fb825a2708
5558dce449
cf74a11e65
86b5deec81

.air.toml (new file, +8)
@@ -0,0 +1,8 @@
+# .air.toml
+[build]
+cmd = "make build"
+bin = "./local-ai"
+args_bin = [ "--debug" ]
+include_ext = ["go", "html", "yaml", "toml", "json", "txt", "md"]
+exclude_dir = ["pkg/grpc/proto"]
+delay = 1000

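This new config is what the `make build-dev` live-reload target added to the Makefile later in this comparison consumes: air watches files with the listed extensions, rebuilds via `make build`, and restarts `./local-ai --debug` on each change.
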
.github/gallery-agent/agent.go (vendored, 4 changes)
@@ -7,8 +7,8 @@ import (
 	"slices"
 	"strings"

-	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
-	"github.com/mudler/cogito"
+	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+	cogito "github.com/mudler/cogito"

 	"github.com/mudler/cogito/structures"
 	"github.com/sashabaranov/go-openai/jsonschema"

.github/gallery-agent/go.mod (vendored, deleted file, -39)
@@ -1,39 +0,0 @@
-module github.com/go-skynet/LocalAI/.github/gallery-agent
-
-go 1.24.1
-
-require (
-	github.com/mudler/cogito v0.3.0
-	github.com/onsi/ginkgo/v2 v2.25.3
-	github.com/onsi/gomega v1.38.2
-	github.com/sashabaranov/go-openai v1.41.2
-	github.com/tmc/langchaingo v0.1.13
-	gopkg.in/yaml.v3 v3.0.1
-)
-
-require (
-	dario.cat/mergo v1.0.1 // indirect
-	github.com/Masterminds/goutils v1.1.1 // indirect
-	github.com/Masterminds/semver/v3 v3.4.0 // indirect
-	github.com/Masterminds/sprig/v3 v3.3.0 // indirect
-	github.com/go-logr/logr v1.4.3 // indirect
-	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
-	github.com/google/go-cmp v0.7.0 // indirect
-	github.com/google/jsonschema-go v0.3.0 // indirect
-	github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
-	github.com/google/uuid v1.6.0 // indirect
-	github.com/huandu/xstrings v1.5.0 // indirect
-	github.com/mitchellh/copystructure v1.2.0 // indirect
-	github.com/mitchellh/reflectwalk v1.0.2 // indirect
-	github.com/modelcontextprotocol/go-sdk v1.0.0 // indirect
-	github.com/shopspring/decimal v1.4.0 // indirect
-	github.com/spf13/cast v1.7.0 // indirect
-	github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
-	go.uber.org/automaxprocs v1.6.0 // indirect
-	go.yaml.in/yaml/v3 v3.0.4 // indirect
-	golang.org/x/crypto v0.41.0 // indirect
-	golang.org/x/net v0.43.0 // indirect
-	golang.org/x/sys v0.35.0 // indirect
-	golang.org/x/text v0.28.0 // indirect
-	golang.org/x/tools v0.36.0 // indirect
-)

.github/gallery-agent/go.sum (vendored, deleted file, -168)
@@ -1,168 +0,0 @@
-dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
-dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
-github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
-github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
-github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
-github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
-github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
-github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
-github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs=
-github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
-github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
-github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
-github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
-github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
-github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
-github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
-github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
-github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
-github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
-github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
-github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
-github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
-github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
-github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
-github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
-github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
-github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
-github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
-github.com/docker/docker v28.2.2+incompatible h1:CjwRSksz8Yo4+RmQ339Dp/D2tGO5JxwYeqtMOEe0LDw=
-github.com/docker/docker v28.2.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
-github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
-github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
-github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
-github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
-github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
-github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
-github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
-github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
-github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
-github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
-github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
-github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
-github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
-github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
-github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
-github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
-github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
-github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
-github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
-github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
-github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
-github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
-github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
-github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
-github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8=
-github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
-github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
-github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
-github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
-github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
-github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
-github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
-github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
-github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
-github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
-github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
-github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
-github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
-github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
-github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
-github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
-github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ=
-github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
-github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
-github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
-github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
-github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
-github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
-github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
-github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
-github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
-github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
-github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
-github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
-github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
-github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
-github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
-github.com/modelcontextprotocol/go-sdk v1.0.0 h1:Z4MSjLi38bTgLrd/LjSmofqRqyBiVKRyQSJgw8q8V74=
-github.com/modelcontextprotocol/go-sdk v1.0.0/go.mod h1:nYtYQroQ2KQiM0/SbyEPUWQ6xs4B95gJjEalc9AQyOs=
-github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
-github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
-github.com/mudler/cogito v0.3.0 h1:NbVAO3bLkK5oGSY0xq87jlz8C9OIsLW55s+8Hfzeu9s=
-github.com/mudler/cogito v0.3.0/go.mod h1:abMwl+CUjCp87IufA2quZdZt0bbLaHHN79o17HbUKxU=
-github.com/onsi/ginkgo/v2 v2.25.3 h1:Ty8+Yi/ayDAGtk4XxmmfUy4GabvM+MegeB4cDLRi6nw=
-github.com/onsi/ginkgo/v2 v2.25.3/go.mod h1:43uiyQC4Ed2tkOzLsEYm7hnrb7UJTWHYNsuy3bG/snE=
-github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A=
-github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k=
-github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
-github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
-github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
-github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
-github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
-github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
-github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
-github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
-github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
-github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
-github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
-github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
-github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
-github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM=
-github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
-github.com/shirou/gopsutil/v4 v4.25.5 h1:rtd9piuSMGeU8g1RMXjZs9y9luK5BwtnG7dZaQUJAsc=
-github.com/shirou/gopsutil/v4 v4.25.5/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
-github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
-github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
-github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
-github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
-github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
-github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
-github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
-github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
-github.com/testcontainers/testcontainers-go v0.38.0 h1:d7uEapLcv2P8AvH8ahLqDMMxda2W9gQN1nRbHS28HBw=
-github.com/testcontainers/testcontainers-go v0.38.0/go.mod h1:C52c9MoHpWO+C4aqmgSU+hxlR5jlEayWtgYrb8Pzz1w=
-github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
-github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
-github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
-github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
-github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA=
-github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg=
-github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
-github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
-github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
-github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
-go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
-go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
-go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
-go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
-go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
-go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
-go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
-go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
-go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
-go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
-go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
-go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
-golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
-golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
-golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
-golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
-golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
-golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
-golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
-golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
-golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
-golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
-google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
-google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
-gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

.github/gallery-agent/main.go (vendored, 2 changes)
@@ -9,7 +9,7 @@ import (
 	"strings"
 	"time"

-	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
+	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
 )

 // ProcessedModelFile represents a processed model file with additional metadata

.github/gallery-agent/tools.go (vendored, 8 changes)
@@ -3,9 +3,9 @@ package main
 import (
 	"fmt"

-	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
-	"github.com/sashabaranov/go-openai"
-	"github.com/tmc/langchaingo/jsonschema"
+	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+	openai "github.com/sashabaranov/go-openai"
+	jsonschema "github.com/sashabaranov/go-openai/jsonschema"
 )

 // Get repository README from HF
@@ -13,7 +13,7 @@ type HFReadmeTool struct {
 	client *hfapi.Client
 }

-func (s *HFReadmeTool) Run(args map[string]any) (string, error) {
+func (s *HFReadmeTool) Execute(args map[string]any) (string, error) {
 	q, ok := args["repository"].(string)
 	if !ok {
 		return "", fmt.Errorf("no query")

.github/workflows/backend.yml (vendored, 4 changes)
@@ -1090,7 +1090,7 @@ jobs:
       go-version: ['1.21.x']
     steps:
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - name: Setup Go ${{ matrix.go-version }}
@@ -1176,7 +1176,7 @@ jobs:
       go-version: ['1.21.x']
     steps:
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - name: Setup Go ${{ matrix.go-version }}

.github/workflows/backend_build.yml (vendored, 2 changes)
@@ -97,7 +97,7 @@ jobs:
           && sudo apt-get install -y git

       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6

       - name: Release space from worker
         if: inputs.runs-on == 'ubuntu-latest'

.github/workflows/backend_build_darwin.yml (vendored, 2 changes)
@@ -50,7 +50,7 @@ jobs:
       go-version: ['${{ inputs.go-version }}']
     steps:
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true

.github/workflows/backend_pr.yml (vendored, 2 changes)
@@ -17,7 +17,7 @@ jobs:
       has-backends-darwin: ${{ steps.set-matrix.outputs.has-backends-darwin }}
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6

       - name: Setup Bun
         uses: oven-sh/setup-bun@v2

.github/workflows/build-test.yaml (vendored, 6 changes)
@@ -11,7 +11,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
       - name: Set up Go
@@ -25,7 +25,7 @@ jobs:
     runs-on: macos-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
       - name: Set up Go
@@ -47,7 +47,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
       - name: Set up Go

.github/workflows/bump_deps.yaml (vendored, 6 changes)
@@ -1,10 +1,10 @@
-name: Bump dependencies
+name: Bump Backend dependencies
 on:
   schedule:
     - cron: 0 20 * * *
   workflow_dispatch:
 jobs:
-  bump:
+  bump-backends:
     strategy:
       fail-fast: false
       matrix:
@@ -31,7 +31,7 @@ jobs:
           file: "backend/go/piper/Makefile"
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
      - name: Bump dependencies 🔧
        id: bump
        run: |

.github/workflows/bump_docs.yaml (vendored, 6 changes)
@@ -1,10 +1,10 @@
-name: Bump dependencies
+name: Bump Documentation
 on:
   schedule:
     - cron: 0 20 * * *
   workflow_dispatch:
 jobs:
-  bump:
+  bump-docs:
     strategy:
       fail-fast: false
       matrix:
@@ -12,7 +12,7 @@ jobs:
         - repository: "mudler/LocalAI"
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
      - name: Bump dependencies 🔧
        run: |
          bash .github/bump_docs.sh ${{ matrix.repository }}

.github/workflows/checksum_checker.yaml (vendored, 2 changes)
@@ -15,7 +15,7 @@ jobs:
           && sudo add-apt-repository -y ppa:git-core/ppa \
           && sudo apt-get update \
           && sudo apt-get install -y git
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
      - name: Install dependencies
        run: |
          sudo apt-get update

.github/workflows/dependabot_auto.yml (vendored, 2 changes)
@@ -20,7 +20,7 @@ jobs:
           skip-commit-verification: true

       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6

       - name: Approve a PR if not already approved
         run: |

.github/workflows/deploy-explorer.yaml (vendored, 6 changes)
@@ -15,7 +15,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - uses: actions/setup-go@v5
@@ -33,7 +33,7 @@ jobs:
         run: |
           CGO_ENABLED=0 make build
       - name: rm
-        uses: appleboy/ssh-action@v1.2.2
+        uses: appleboy/ssh-action@v1.2.3
         with:
           host: ${{ secrets.EXPLORER_SSH_HOST }}
           username: ${{ secrets.EXPLORER_SSH_USERNAME }}
@@ -53,7 +53,7 @@ jobs:
           rm: true
           target: ./local-ai
       - name: restarting
-        uses: appleboy/ssh-action@v1.2.2
+        uses: appleboy/ssh-action@v1.2.3
         with:
           host: ${{ secrets.EXPLORER_SSH_HOST }}
           username: ${{ secrets.EXPLORER_SSH_USERNAME }}

.github/workflows/gallery-agent.yaml (vendored, 13 changes)
@@ -2,7 +2,7 @@ name: Gallery Agent
 on:

   schedule:
-    - cron: '0 */1 * * *' # Run every 4 hours
+    - cron: '0 */3 * * *' # Run every 4 hours
   workflow_dispatch:
     inputs:
       search_term:
@@ -30,7 +30,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           token: ${{ secrets.GITHUB_TOKEN }}

@@ -39,11 +39,6 @@ jobs:
         with:
           go-version: '1.21'

-      - name: Build gallery agent
-        run: |
-          cd .github/gallery-agent
-          go mod download
-          go build -o gallery-agent .

       - name: Run gallery agent
         env:
@@ -56,9 +51,7 @@ jobs:
           MAX_MODELS: ${{ github.event.inputs.max_models || '1' }}
         run: |
           export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml
-          cd .github/gallery-agent
-          ./gallery-agent
-          rm -rf gallery-agent
+          go run .github/gallery-agent

       - name: Check for changes
         id: check_changes

.github/workflows/generate_grpc_cache.yaml (vendored, 2 changes)
@@ -73,7 +73,7 @@ jobs:
         uses: docker/setup-buildx-action@master

       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6

       - name: Cache GRPC
         uses: docker/build-push-action@v6

.github/workflows/generate_intel_image.yaml (vendored, 4 changes)
@@ -16,7 +16,7 @@ jobs:
       matrix:
         include:
           - base-image: intel/oneapi-basekit:2025.2.0-0-devel-ubuntu22.04
-            runs-on: 'ubuntu-latest'
+            runs-on: 'arc-runner-set'
             platforms: 'linux/amd64'
     runs-on: ${{matrix.runs-on}}
     steps:
@@ -43,7 +43,7 @@ jobs:
         uses: docker/setup-buildx-action@master

       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6

       - name: Cache Intel images
         uses: docker/build-push-action@v6

.github/workflows/image_build.yml (vendored, 2 changes)
@@ -94,7 +94,7 @@ jobs:
           && sudo apt-get update \
           && sudo apt-get install -y git
       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6

       - name: Release space from worker
         if: inputs.runs-on == 'ubuntu-latest'

.github/workflows/localaibot_automerge.yml (vendored, 2 changes)
@@ -14,7 +14,7 @@ jobs:
     if: ${{ github.actor == 'localai-bot' && !contains(github.event.pull_request.title, 'chore(model gallery):') }}
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6

       - name: Approve a PR if not already approved
         run: |

.github/workflows/notify-models.yaml (vendored, 4 changes)
@@ -15,7 +15,7 @@ jobs:
       MODEL_NAME: gemma-3-12b-it-qat
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
        with:
          fetch-depth: 0 # needed to checkout all branches for this Action to work
          ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes
@@ -95,7 +95,7 @@ jobs:
       MODEL_NAME: gemma-3-12b-it-qat
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
        with:
          fetch-depth: 0 # needed to checkout all branches for this Action to work
          ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes

.github/workflows/release.yaml (vendored, 6 changes)
@@ -10,7 +10,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
       - name: Set up Go
@@ -28,7 +28,7 @@ jobs:
     runs-on: macos-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
       - name: Set up Go
@@ -46,7 +46,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Checkout
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           fetch-depth: 0
       - name: Set up Go

.github/workflows/secscan.yaml (vendored, 2 changes)
@@ -14,7 +14,7 @@ jobs:
       GO111MODULE: on
     steps:
       - name: Checkout Source
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         if: ${{ github.actor != 'dependabot[bot]' }}
       - name: Run Gosec Security Scanner
         if: ${{ github.actor != 'dependabot[bot]' }}

.github/workflows/test-extra.yml (vendored, 18 changes)
@@ -19,7 +19,7 @@ jobs:
 #    runs-on: ubuntu-latest
 #    steps:
 #      - name: Clone
-#        uses: actions/checkout@v5
+#        uses: actions/checkout@v6
 #        with:
 #          submodules: true
 #      - name: Dependencies
@@ -40,7 +40,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - name: Dependencies
@@ -61,7 +61,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - name: Dependencies
@@ -83,7 +83,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - name: Dependencies
@@ -104,7 +104,7 @@ jobs:
 #    runs-on: ubuntu-latest
 #    steps:
 #      - name: Clone
-#        uses: actions/checkout@v5
+#        uses: actions/checkout@v6
 #        with:
 #          submodules: true
 #      - name: Dependencies
@@ -124,7 +124,7 @@ jobs:
 #    runs-on: ubuntu-latest
 #    steps:
 #      - name: Clone
-#        uses: actions/checkout@v5
+#        uses: actions/checkout@v6
 #        with:
 #          submodules: true
 #      - name: Dependencies
@@ -186,7 +186,7 @@ jobs:
 #          sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true
 #          df -h
 #      - name: Clone
-#        uses: actions/checkout@v5
+#        uses: actions/checkout@v6
 #        with:
 #          submodules: true
 #      - name: Dependencies
@@ -211,7 +211,7 @@ jobs:
 #    runs-on: ubuntu-latest
 #    steps:
 #      - name: Clone
-#        uses: actions/checkout@v5
+#        uses: actions/checkout@v6
 #        with:
 #          submodules: true
 #      - name: Dependencies
@@ -232,7 +232,7 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - name: Dependencies

.github/workflows/test.yml (vendored, 6 changes)
@@ -70,7 +70,7 @@ jobs:
           sudo rm -rfv build || true
           df -h
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - name: Setup Go ${{ matrix.go-version }}
@@ -166,7 +166,7 @@ jobs:
           sudo rm -rfv build || true
           df -h
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - name: Dependencies
@@ -196,7 +196,7 @@ jobs:
       go-version: ['1.25.x']
     steps:
       - name: Clone
-        uses: actions/checkout@v5
+        uses: actions/checkout@v6
         with:
           submodules: true
       - name: Setup Go ${{ matrix.go-version }}

.github/workflows/update_swagger.yaml (vendored, 2 changes)
@@ -9,7 +9,7 @@ jobs:
       fail-fast: false
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v5
+      - uses: actions/checkout@v6
      - uses: actions/setup-go@v5
        with:
          go-version: 'stable'

.gitmodules (vendored, 3 changes)
@@ -1,6 +1,3 @@
-[submodule "docs/themes/hugo-theme-relearn"]
-	path = docs/themes/hugo-theme-relearn
-	url = https://github.com/McShelby/hugo-theme-relearn.git
 [submodule "docs/themes/lotusdocs"]
 	path = docs/themes/lotusdocs
 	url = https://github.com/colinwilson/lotusdocs

@@ -30,6 +30,7 @@ Thank you for your interest in contributing to LocalAI! We appreciate your time
 3. Install the required dependencies ( see https://localai.io/basics/build/#build-localai-locally )
 4. Build LocalAI: `make build`
 5. Run LocalAI: `./local-ai`
+6. To Build and live reload: `make build-dev`

 ## Contributing

@@ -76,7 +77,7 @@ LOCALAI_IMAGE_TAG=test LOCALAI_IMAGE=local-ai-aio make run-e2e-aio
 ## Documentation

 We are welcome the contribution of the documents, please open new PR or create a new issue. The documentation is available under `docs/` https://github.com/mudler/LocalAI/tree/master/docs

 ## Community and Communication

 - You can reach out via the Github issue tracker.

@@ -332,6 +332,6 @@ RUN mkdir -p /models /backends
 HEALTHCHECK --interval=1m --timeout=10m --retries=10 \
   CMD curl -f ${HEALTHCHECK_ENDPOINT} || exit 1

-VOLUME /models /backends
+VOLUME /models /backends /configuration
 EXPOSE 8080
 ENTRYPOINT [ "/entrypoint.sh" ]

Makefile (4 changes)
@@ -103,6 +103,10 @@ build-launcher: ## Build the launcher application

 build-all: build build-launcher ## Build both server and launcher

+build-dev: ## Run LocalAI in dev mode with live reload
+	@command -v air >/dev/null 2>&1 || go install github.com/air-verse/air@latest
+	air -c .air.toml
+
 dev-dist:
 	$(GORELEASER) build --snapshot --clean

@@ -43,7 +43,7 @@
 > :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
 >
-> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
+> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
 [](https://t.me/localaiofficial_bot)

 [](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[](https://artifacthub.io/packages/search?repo=localai)
@@ -108,7 +108,7 @@ Run the installer script:
 curl https://localai.io/install.sh | sh
 ```

-For more installation options, see [Installer Options](https://localai.io/docs/advanced/installer/).
+For more installation options, see [Installer Options](https://localai.io/installation/).

 ### macOS Download:

@@ -116,6 +116,8 @@ For more installation options, see [Installer Options](https://localai.io/docs/a
 <img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
 </a>

+> Note: the DMGs are not signed by Apple as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
+
 Or run with docker:

 > **💡 Docker Run vs Docker Start**
@@ -200,10 +202,11 @@ local-ai run oci://localai/phi-2:latest

 > ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).

-For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html)
+For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html), if you are interested in our roadmap items and future enhancements, you can see the [Issues labeled as Roadmap here](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)

 ## 📰 Latest project news

+- November 2025: Major improvements to the UX. Among these: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245) and [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325)
 - October 2025: 🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support added for agentic capabilities with external tools
 - September 2025: New Launcher application for MacOS and Linux, extended support to many backends for Mac and Nvidia L4T devices. Models: Added MLX-Audio, WAN 2.2. WebUI improvements and Python-based backends now ships portable python environments.
 - August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips ( with `development` suffix in the gallery ): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060

@@ -154,6 +154,10 @@ message PredictOptions {
   repeated string Videos = 45;
   repeated string Audios = 46;
   string CorrelationId = 47;
+  string Tools = 48; // JSON array of available tools/functions for tool calling
+  string ToolChoice = 49; // JSON string or object specifying tool choice behavior
+  int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
+  int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
 }

 // The response message containing the result
@@ -164,6 +168,7 @@ message Reply {
   double timing_prompt_processing = 4;
   double timing_token_generation = 5;
   bytes audio = 6;
+  bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
 }

 message GrammarTrigger {
@@ -382,6 +387,11 @@ message StatusResponse {
 message Message {
   string role = 1;
   string content = 2;
+  // Optional fields for OpenAI-compatible message format
+  string name = 3; // Tool name (for tool messages)
+  string tool_call_id = 4; // Tool call ID (for tool messages)
+  string reasoning_content = 5; // Reasoning content (for thinking models)
+  string tool_calls = 6; // Tool calls as JSON string (for assistant messages with tool calls)
 }

 message DetectOptions {

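The new `logprobs` field on `Reply` carries raw bytes that, per the inline comment, decode to OpenAI-format logprobs JSON. A minimal hedged sketch of that payload shape in Go; the field names follow the public OpenAI chat-completions schema, and these structs are illustrative, not types taken from this diff:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// TokenLogprob mirrors one entry of OpenAI's "content" array: the token,
// its log-probability, and optionally the top alternatives for that position.
type TokenLogprob struct {
	Token       string         `json:"token"`
	Logprob     float64        `json:"logprob"`
	TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}

// LogprobsPayload is the top-level object the Reply.logprobs bytes would decode into.
type LogprobsPayload struct {
	Content []TokenLogprob `json:"content"`
}

func main() {
	// Example bytes as a client might receive them in Reply.logprobs (made-up values).
	raw := []byte(`{"content":[{"token":"Hello","logprob":-0.12,` +
		`"top_logprobs":[{"token":"Hello","logprob":-0.12},{"token":"Hi","logprob":-2.3}]}]}`)
	var p LogprobsPayload
	if err := json.Unmarshal(raw, &p); err != nil {
		panic(err)
	}
	fmt.Println(p.Content[0].Token, p.Content[0].Logprob) // Hello -0.12
}
```
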
@@ -57,7 +57,7 @@ add_library(hw_grpc_proto
 ${hw_proto_srcs}
 ${hw_proto_hdrs} )

-add_executable(${TARGET} grpc-server.cpp utils.hpp json.hpp httplib.h)
+add_executable(${TARGET} grpc-server.cpp json.hpp httplib.h)

 target_include_directories(${TARGET} PRIVATE ../llava)
 target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})

@@ -1,5 +1,5 @@

-LLAMA_VERSION?=5a4ff43e7dd049e35942bc3d12361dab2f155544
+LLAMA_VERSION?=583cb83416467e8abf9b37349dcf1f6a0083745a
 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp

 CMAKE_ARGS?=

(File diff suppressed because it is too large)

@@ -9,10 +9,13 @@ done

 set -e

+for file in $(ls llama.cpp/tools/server/); do
+	cp -rfv llama.cpp/tools/server/$file llama.cpp/tools/grpc-server/
+done
+
 cp -r CMakeLists.txt llama.cpp/tools/grpc-server/
 cp -r grpc-server.cpp llama.cpp/tools/grpc-server/
 cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/tools/grpc-server/
-cp -rfv llama.cpp/tools/server/utils.hpp llama.cpp/tools/grpc-server/
 cp -rfv llama.cpp/vendor/cpp-httplib/httplib.h llama.cpp/tools/grpc-server/

 set +e

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)

 # whisper.cpp version
 WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
-WHISPER_CPP_VERSION?=f16c12f3f55f5bd3d6ac8cf2f31ab90a42c884d5
+WHISPER_CPP_VERSION?=19ceec8eac980403b714d603e5ca31653cd42a3f
 SO_TARGET?=libgowhisper.so

 CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

@@ -2,6 +2,7 @@
 accelerate
 torch
 torchaudio
+numpy>=1.24.0,<1.26.0
 transformers
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
@@ -2,6 +2,7 @@
 torch==2.6.0+cu118
 torchaudio==2.6.0+cu118
 transformers==4.46.3
+numpy>=1.24.0,<1.26.0
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate
@@ -1,6 +1,7 @@
 torch
 torchaudio
 transformers
+numpy>=1.24.0,<1.26.0
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate
@@ -2,6 +2,7 @@
 torch==2.6.0+rocm6.1
 torchaudio==2.6.0+rocm6.1
 transformers
+numpy>=1.24.0,<1.26.0
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate
@@ -3,6 +3,7 @@ intel-extension-for-pytorch==2.3.110+xpu
 torch==2.3.1+cxx11.abi
 torchaudio==2.3.1+cxx11.abi
 transformers
+numpy>=1.24.0,<1.26.0
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate
@@ -2,5 +2,6 @@
 torch
 torchaudio
 transformers
+numpy>=1.24.0,<1.26.0
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate

@@ -61,7 +61,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             if request.PipelineType != "": # Reuse the PipelineType field for language
                 kwargs['lang'] = request.PipelineType
             self.model_name = model_name
-            self.model = Reranker(model_name, **kwargs)
+            self.model = Reranker(model_name, **kwargs)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
@@ -75,12 +75,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             documents.append(doc)
         ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
         # Prepare results to return
+        cropped_results = ranked_results.top_k(request.top_n) if request.top_n > 0 else ranked_results
         results = [
             backend_pb2.DocumentResult(
                 index=res.doc_id,
                 text=res.text,
                 relevance_score=res.score
-            ) for res in ranked_results.results
+            ) for res in (cropped_results)
         ]

         # Calculate the usage and total tokens

@@ -76,7 +76,7 @@ class TestBackendServicer(unittest.TestCase):
             )
             response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
             self.assertTrue(response.success)

             rerank_response = stub.Rerank(request)
             print(rerank_response.results[0])
             self.assertIsNotNone(rerank_response.results)
@@ -87,4 +87,60 @@ class TestBackendServicer(unittest.TestCase):
             print(err)
             self.fail("Reranker service failed")
         finally:
-            self.tearDown()
+            self.tearDown()
+
+    def test_rerank_omit_top_n(self):
+        """
+        This method tests if the embeddings are generated successfully even top_n is omitted
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                request = backend_pb2.RerankRequest(
+                    query="I love you",
+                    documents=["I hate you", "I really like you"],
+                    top_n=0 #
+                )
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+                self.assertTrue(response.success)
+
+                rerank_response = stub.Rerank(request)
+                print(rerank_response.results[0])
+                self.assertIsNotNone(rerank_response.results)
+                self.assertEqual(len(rerank_response.results), 2)
+                self.assertEqual(rerank_response.results[0].text, "I really like you")
+                self.assertEqual(rerank_response.results[1].text, "I hate you")
+        except Exception as err:
+            print(err)
+            self.fail("Reranker service failed")
+        finally:
+            self.tearDown()
+
+    def test_rerank_crop(self):
+        """
+        This method tests top_n cropping
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                request = backend_pb2.RerankRequest(
+                    query="I love you",
+                    documents=["I hate you", "I really like you", "I hate ignoring top_n"],
+                    top_n=2
+                )
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+                self.assertTrue(response.success)
+
+                rerank_response = stub.Rerank(request)
+                print(rerank_response.results[0])
+                self.assertIsNotNone(rerank_response.results)
+                self.assertEqual(len(rerank_response.results), 2)
+                self.assertEqual(rerank_response.results[0].text, "I really like you")
+                self.assertEqual(rerank_response.results[1].text, "I hate you")
+        except Exception as err:
+            print(err)
+            self.fail("Reranker service failed")
+        finally:
+            self.tearDown()

@@ -6,4 +6,4 @@ transformers
 bitsandbytes
 outetts
 sentence-transformers==5.1.0
-protobuf==6.32.0
+protobuf==6.33.1
@@ -7,4 +7,4 @@ transformers
 bitsandbytes
 outetts
 sentence-transformers==5.1.0
-protobuf==6.32.0
+protobuf==6.33.1
@@ -6,4 +6,4 @@ transformers
 bitsandbytes
 outetts
 sentence-transformers==5.1.0
-protobuf==6.32.0
+protobuf==6.33.1
@@ -8,4 +8,4 @@ bitsandbytes
 outetts
 bitsandbytes
 sentence-transformers==5.1.0
-protobuf==6.32.0
+protobuf==6.33.1
@@ -10,4 +10,4 @@ intel-extension-for-transformers
 bitsandbytes
 outetts
 sentence-transformers==5.1.0
-protobuf==6.32.0
+protobuf==6.33.1

@@ -1,5 +1,5 @@
 grpcio==1.76.0
-protobuf==6.32.0
+protobuf==6.33.1
 certifi
 setuptools
 scipy==1.15.1

@@ -3,6 +3,13 @@ set -e

 EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"

+# Avoid to overcommit the CPU during build
+# https://github.com/vllm-project/vllm/issues/20079
+# https://docs.vllm.ai/en/v0.8.3/serving/env_vars.html
+# https://docs.redhat.com/it/documentation/red_hat_ai_inference_server/3.0/html/vllm_server_arguments/environment_variables-server-arguments
+export NVCC_THREADS=2
+export MAX_JOBS=1
+
 backend_dir=$(dirname $0)

 if [ -d $backend_dir/common ]; then

@@ -1 +1 @@
-flash-attn
+https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl

@@ -1,6 +1,9 @@
 package application

 import (
+	"context"
+	"sync"
+
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/core/templates"
@@ -11,8 +14,14 @@ type Application struct {
 	backendLoader *config.ModelConfigLoader
 	modelLoader *model.ModelLoader
 	applicationConfig *config.ApplicationConfig
+	startupConfig *config.ApplicationConfig // Stores original config from env vars (before file loading)
 	templatesEvaluator *templates.Evaluator
 	galleryService *services.GalleryService
+	watchdogMutex sync.Mutex
+	watchdogStop chan bool
+	p2pMutex sync.Mutex
+	p2pCtx context.Context
+	p2pCancel context.CancelFunc
 }

 func newApplication(appConfig *config.ApplicationConfig) *Application {
@@ -44,6 +53,11 @@ func (a *Application) GalleryService() *services.GalleryService {
 	return a.galleryService
 }

+// StartupConfig returns the original startup configuration (from env vars, before file loading)
+func (a *Application) StartupConfig() *config.ApplicationConfig {
+	return a.startupConfig
+}
+
 func (a *Application) start() error {
 	galleryService := services.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
 	err := galleryService.Start(a.ApplicationConfig().Context, a.ModelConfigLoader(), a.ApplicationConfig().SystemState)

@@ -1,180 +1,343 @@
package application

import (
	"encoding/json"
	"fmt"
	"os"
	"path"
	"path/filepath"
	"time"

	"dario.cat/mergo"
	"github.com/fsnotify/fsnotify"
	"github.com/mudler/LocalAI/core/config"
	"github.com/rs/zerolog/log"
)

type fileHandler func(fileContent []byte, appConfig *config.ApplicationConfig) error

type configFileHandler struct {
	handlers map[string]fileHandler

	watcher *fsnotify.Watcher

	appConfig *config.ApplicationConfig
}

// TODO: This should be a singleton eventually so other parts of the code can register config file handlers,
// then we can export it to other packages
func newConfigFileHandler(appConfig *config.ApplicationConfig) configFileHandler {
	c := configFileHandler{
		handlers:  make(map[string]fileHandler),
		appConfig: appConfig,
	}
	err := c.Register("api_keys.json", readApiKeysJson(*appConfig), true)
	if err != nil {
		log.Error().Err(err).Str("file", "api_keys.json").Msg("unable to register config file handler")
	}
	err = c.Register("external_backends.json", readExternalBackendsJson(*appConfig), true)
	if err != nil {
		log.Error().Err(err).Str("file", "external_backends.json").Msg("unable to register config file handler")
	}
	err = c.Register("runtime_settings.json", readRuntimeSettingsJson(*appConfig), true)
	if err != nil {
		log.Error().Err(err).Str("file", "runtime_settings.json").Msg("unable to register config file handler")
	}
	return c
}

func (c *configFileHandler) Register(filename string, handler fileHandler, runNow bool) error {
	_, ok := c.handlers[filename]
	if ok {
		return fmt.Errorf("handler already registered for file %s", filename)
	}
	c.handlers[filename] = handler
	if runNow {
		c.callHandler(filename, handler)
	}
	return nil
}

func (c *configFileHandler) callHandler(filename string, handler fileHandler) {
	rootedFilePath := filepath.Join(c.appConfig.DynamicConfigsDir, filepath.Clean(filename))
	log.Trace().Str("filename", rootedFilePath).Msg("reading file for dynamic config update")
	fileContent, err := os.ReadFile(rootedFilePath)
	if err != nil && !os.IsNotExist(err) {
		log.Error().Err(err).Str("filename", rootedFilePath).Msg("could not read file")
	}

	if err = handler(fileContent, c.appConfig); err != nil {
		log.Error().Err(err).Msg("WatchConfigDirectory goroutine failed to update options")
	}
}

func (c *configFileHandler) Watch() error {
	configWatcher, err := fsnotify.NewWatcher()
	c.watcher = configWatcher
	if err != nil {
		return err
	}

	if c.appConfig.DynamicConfigsDirPollInterval > 0 {
		log.Debug().Msg("Poll interval set, falling back to polling for configuration changes")
		ticker := time.NewTicker(c.appConfig.DynamicConfigsDirPollInterval)
		go func() {
			for {
				<-ticker.C
				for file, handler := range c.handlers {
					log.Debug().Str("file", file).Msg("polling config file")
					c.callHandler(file, handler)
				}
			}
		}()
	}

	// Start listening for events.
	go func() {
		for {
			select {
			case event, ok := <-c.watcher.Events:
				if !ok {
					return
				}
				if event.Has(fsnotify.Write | fsnotify.Create | fsnotify.Remove) {
					handler, ok := c.handlers[path.Base(event.Name)]
					if !ok {
						continue
					}

					c.callHandler(filepath.Base(event.Name), handler)
				}
			case err, ok := <-c.watcher.Errors:
				log.Error().Err(err).Msg("config watcher error received")
				if !ok {
					return
				}
			}
		}
	}()

	// Add a path.
	err = c.watcher.Add(c.appConfig.DynamicConfigsDir)
	if err != nil {
		return fmt.Errorf("unable to create a watcher on the configuration directory: %+v", err)
	}

	return nil
}

// TODO: When we institute graceful shutdown, this should be called
func (c *configFileHandler) Stop() error {
	return c.watcher.Close()
}

func readApiKeysJson(startupAppConfig config.ApplicationConfig) fileHandler {
	handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error {
		log.Debug().Msg("processing api keys runtime update")
		log.Trace().Int("numKeys", len(startupAppConfig.ApiKeys)).Msg("api keys provided at startup")

		if len(fileContent) > 0 {
			// Parse JSON content from the file
			var fileKeys []string
			err := json.Unmarshal(fileContent, &fileKeys)
			if err != nil {
				return err
			}

			log.Trace().Int("numKeys", len(fileKeys)).Msg("discovered API keys from api keys dynamic config file")

			appConfig.ApiKeys = append(startupAppConfig.ApiKeys, fileKeys...)
		} else {
			log.Trace().Msg("no API keys discovered from dynamic config file")
			appConfig.ApiKeys = startupAppConfig.ApiKeys
		}
		log.Trace().Int("numKeys", len(appConfig.ApiKeys)).Msg("total api keys after processing")
		return nil
	}

	return handler
}
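// For reference, api_keys.json is parsed as a bare JSON array of strings, so a minimal file
// (key values hypothetical) would look like:
//
//	["sk-admin-key", "sk-read-only-key"]
//
// Keys read from the file are appended to the keys provided at startup rather than replacing them.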

func readExternalBackendsJson(startupAppConfig config.ApplicationConfig) fileHandler {
	handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error {
		log.Debug().Msg("processing external_backends.json")

		if len(fileContent) > 0 {
			// Parse JSON content from the file
			var fileBackends map[string]string
			err := json.Unmarshal(fileContent, &fileBackends)
			if err != nil {
				return err
			}
			appConfig.ExternalGRPCBackends = startupAppConfig.ExternalGRPCBackends
			err = mergo.Merge(&appConfig.ExternalGRPCBackends, &fileBackends)
			if err != nil {
				return err
			}
		} else {
			appConfig.ExternalGRPCBackends = startupAppConfig.ExternalGRPCBackends
		}
		log.Debug().Msg("external backends loaded from external_backends.json")
		return nil
	}
	return handler
}
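// external_backends.json is parsed as a JSON object mapping backend names to gRPC endpoints and
// merged over the startup set via mergo. A minimal sketch (name and address hypothetical):
//
//	{ "my-backend": "127.0.0.1:50051" }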

type runtimeSettings struct {
	WatchdogEnabled          *bool             `json:"watchdog_enabled,omitempty"`
	WatchdogIdleEnabled      *bool             `json:"watchdog_idle_enabled,omitempty"`
	WatchdogBusyEnabled      *bool             `json:"watchdog_busy_enabled,omitempty"`
	WatchdogIdleTimeout      *string           `json:"watchdog_idle_timeout,omitempty"`
	WatchdogBusyTimeout      *string           `json:"watchdog_busy_timeout,omitempty"`
	SingleBackend            *bool             `json:"single_backend,omitempty"`
	ParallelBackendRequests  *bool             `json:"parallel_backend_requests,omitempty"`
	Threads                  *int              `json:"threads,omitempty"`
	ContextSize              *int              `json:"context_size,omitempty"`
	F16                      *bool             `json:"f16,omitempty"`
	Debug                    *bool             `json:"debug,omitempty"`
	CORS                     *bool             `json:"cors,omitempty"`
	CSRF                     *bool             `json:"csrf,omitempty"`
	CORSAllowOrigins         *string           `json:"cors_allow_origins,omitempty"`
	P2PToken                 *string           `json:"p2p_token,omitempty"`
	P2PNetworkID             *string           `json:"p2p_network_id,omitempty"`
	Federated                *bool             `json:"federated,omitempty"`
	Galleries                *[]config.Gallery `json:"galleries,omitempty"`
	BackendGalleries         *[]config.Gallery `json:"backend_galleries,omitempty"`
	AutoloadGalleries        *bool             `json:"autoload_galleries,omitempty"`
	AutoloadBackendGalleries *bool             `json:"autoload_backend_galleries,omitempty"`
	ApiKeys                  *[]string         `json:"api_keys,omitempty"`
}
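// A minimal runtime_settings.json exercising a few of the fields above (values are examples,
// not defaults):
//
//	{
//		"watchdog_idle_enabled": true,
//		"watchdog_idle_timeout": "15m",
//		"threads": 8,
//		"api_keys": ["sk-runtime-key"]
//	}
//
// Timeout values are Go duration strings parsed with time.ParseDuration.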

func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHandler {
	handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error {
		log.Debug().Msg("processing runtime_settings.json")

		// Determine whether each setting still matches the startup config.
		// startupAppConfig holds the original values applied from env vars at startup;
		// each env* flag below is true when the current value is unchanged from startup.
		envWatchdogIdle := appConfig.WatchDogIdle == startupAppConfig.WatchDogIdle
		envWatchdogBusy := appConfig.WatchDogBusy == startupAppConfig.WatchDogBusy
		envWatchdogIdleTimeout := appConfig.WatchDogIdleTimeout == startupAppConfig.WatchDogIdleTimeout
		envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout
		envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend
		envParallelRequests := appConfig.ParallelBackendRequests == startupAppConfig.ParallelBackendRequests
		envThreads := appConfig.Threads == startupAppConfig.Threads
		envContextSize := appConfig.ContextSize == startupAppConfig.ContextSize
		envF16 := appConfig.F16 == startupAppConfig.F16
		envDebug := appConfig.Debug == startupAppConfig.Debug
		envCORS := appConfig.CORS == startupAppConfig.CORS
		envCSRF := appConfig.CSRF == startupAppConfig.CSRF
		envCORSAllowOrigins := appConfig.CORSAllowOrigins == startupAppConfig.CORSAllowOrigins
		envP2PToken := appConfig.P2PToken == startupAppConfig.P2PToken
		envP2PNetworkID := appConfig.P2PNetworkID == startupAppConfig.P2PNetworkID
		envFederated := appConfig.Federated == startupAppConfig.Federated
		envAutoloadGalleries := appConfig.AutoloadGalleries == startupAppConfig.AutoloadGalleries
		envAutoloadBackendGalleries := appConfig.AutoloadBackendGalleries == startupAppConfig.AutoloadBackendGalleries

		if len(fileContent) > 0 {
			var settings runtimeSettings
			err := json.Unmarshal(fileContent, &settings)
			if err != nil {
				return err
			}

			// Apply file settings only if they don't match startup values (i.e., not from env vars)
			if settings.WatchdogIdleEnabled != nil && !envWatchdogIdle {
				appConfig.WatchDogIdle = *settings.WatchdogIdleEnabled
				if appConfig.WatchDogIdle {
					appConfig.WatchDog = true
				}
			}
			if settings.WatchdogBusyEnabled != nil && !envWatchdogBusy {
				appConfig.WatchDogBusy = *settings.WatchdogBusyEnabled
				if appConfig.WatchDogBusy {
					appConfig.WatchDog = true
				}
			}
			if settings.WatchdogIdleTimeout != nil && !envWatchdogIdleTimeout {
				dur, err := time.ParseDuration(*settings.WatchdogIdleTimeout)
				if err == nil {
					appConfig.WatchDogIdleTimeout = dur
				} else {
					log.Warn().Err(err).Str("timeout", *settings.WatchdogIdleTimeout).Msg("invalid watchdog idle timeout in runtime_settings.json")
				}
			}
			if settings.WatchdogBusyTimeout != nil && !envWatchdogBusyTimeout {
				dur, err := time.ParseDuration(*settings.WatchdogBusyTimeout)
				if err == nil {
					appConfig.WatchDogBusyTimeout = dur
				} else {
					log.Warn().Err(err).Str("timeout", *settings.WatchdogBusyTimeout).Msg("invalid watchdog busy timeout in runtime_settings.json")
				}
			}
			if settings.SingleBackend != nil && !envSingleBackend {
				appConfig.SingleBackend = *settings.SingleBackend
			}
			if settings.ParallelBackendRequests != nil && !envParallelRequests {
				appConfig.ParallelBackendRequests = *settings.ParallelBackendRequests
			}
			if settings.Threads != nil && !envThreads {
				appConfig.Threads = *settings.Threads
			}
			if settings.ContextSize != nil && !envContextSize {
				appConfig.ContextSize = *settings.ContextSize
			}
			if settings.F16 != nil && !envF16 {
				appConfig.F16 = *settings.F16
			}
			if settings.Debug != nil && !envDebug {
				appConfig.Debug = *settings.Debug
			}
			if settings.CORS != nil && !envCORS {
				appConfig.CORS = *settings.CORS
			}
			if settings.CSRF != nil && !envCSRF {
				appConfig.CSRF = *settings.CSRF
			}
			if settings.CORSAllowOrigins != nil && !envCORSAllowOrigins {
				appConfig.CORSAllowOrigins = *settings.CORSAllowOrigins
			}
			if settings.P2PToken != nil && !envP2PToken {
				appConfig.P2PToken = *settings.P2PToken
			}
			if settings.P2PNetworkID != nil && !envP2PNetworkID {
				appConfig.P2PNetworkID = *settings.P2PNetworkID
			}
			if settings.Federated != nil && !envFederated {
				appConfig.Federated = *settings.Federated
			}
			if settings.Galleries != nil {
				appConfig.Galleries = *settings.Galleries
			}
			if settings.BackendGalleries != nil {
				appConfig.BackendGalleries = *settings.BackendGalleries
			}
			if settings.AutoloadGalleries != nil && !envAutoloadGalleries {
				appConfig.AutoloadGalleries = *settings.AutoloadGalleries
			}
			if settings.AutoloadBackendGalleries != nil && !envAutoloadBackendGalleries {
				appConfig.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries
			}
			if settings.ApiKeys != nil {
				// API keys from env vars (startup) are always kept; if runtime_settings.json
				// specifies api_keys (even an empty list), its contents replace all previously
				// applied runtime keys.
				envKeys := startupAppConfig.ApiKeys
				runtimeKeys := *settings.ApiKeys
				appConfig.ApiKeys = append(envKeys, runtimeKeys...)
			}

			// If watchdog is enabled via file but not via env, ensure WatchDog flag is set
			if !envWatchdogIdle && !envWatchdogBusy {
				if settings.WatchdogEnabled != nil && *settings.WatchdogEnabled {
					appConfig.WatchDog = true
				}
			}
		}
		log.Debug().Msg("runtime settings loaded from runtime_settings.json")
		return nil
	}
	return handler
}

240
core/application/p2p.go
Normal file
@@ -0,0 +1,240 @@
package application

import (
	"context"
	"fmt"
	"net"
	"slices"
	"time"

	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/p2p"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/services"

	"github.com/mudler/edgevpn/pkg/node"
	"github.com/rs/zerolog/log"
	zlog "github.com/rs/zerolog/log"
)

func (a *Application) StopP2P() error {
	if a.p2pCancel != nil {
		a.p2pCancel()
		a.p2pCancel = nil
		a.p2pCtx = nil
		// Wait a bit for shutdown to complete
		time.Sleep(200 * time.Millisecond)
	}
	return nil
}

func (a *Application) StartP2P() error {
	// we need a p2p token
	if a.applicationConfig.P2PToken == "" {
		return fmt.Errorf("P2P token is not set")
	}

	networkID := a.applicationConfig.P2PNetworkID

	ctx, cancel := context.WithCancel(a.ApplicationConfig().Context)
	a.p2pCtx = ctx
	a.p2pCancel = cancel

	var n *node.Node
	// Here we are avoiding creating multiple nodes:
	// - if the federated mode is enabled, we create a federated node and expose a service
	// - exposing a service creates a node with specific options, and we don't want to create another node

	// If the federated mode is enabled, we expose a service to the local instance running
	// at the configured API address
	if a.applicationConfig.Federated {
		_, port, err := net.SplitHostPort(a.applicationConfig.APIAddress)
		if err != nil {
			return err
		}

		// Here a new node is created and started
		// and a service is exposed by the node
		node, err := p2p.ExposeService(ctx, "localhost", port, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.FederatedID))
		if err != nil {
			return err
		}

		if err := p2p.ServiceDiscoverer(ctx, node, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.FederatedID), nil, false); err != nil {
			return err
		}

		n = node
		// start node sync in the background
		if err := a.p2pSync(ctx, node); err != nil {
			return err
		}
	}

	// If a node wasn't created previously, create it
	if n == nil {
		node, err := p2p.NewNode(a.applicationConfig.P2PToken)
		if err != nil {
			return err
		}
		err = node.Start(ctx)
		if err != nil {
			return fmt.Errorf("starting new node: %w", err)
		}
		n = node
	}

	// Attach a ServiceDiscoverer to the p2p node
	log.Info().Msg("Starting P2P server discovery...")
	if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.WorkerID), func(serviceID string, node schema.NodeData) {
		var tunnelAddresses []string
		for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.WorkerID)) {
			if v.IsOnline() {
				tunnelAddresses = append(tunnelAddresses, v.TunnelAddress)
			} else {
				log.Info().Msgf("Node %s is offline", v.ID)
			}
		}
		if a.applicationConfig.TunnelCallback != nil {
			a.applicationConfig.TunnelCallback(tunnelAddresses)
		}
	}, true); err != nil {
		return err
	}

	return nil
}

// RestartP2P restarts the P2P stack with current ApplicationConfig settings:
// it tears down any running P2P context and starts a fresh stack in the background
func (a *Application) RestartP2P() error {
	a.p2pMutex.Lock()
	defer a.p2pMutex.Unlock()

	// Stop existing P2P if running
	if a.p2pCancel != nil {
		a.p2pCancel()
		a.p2pCancel = nil
		a.p2pCtx = nil
		// Wait a bit for shutdown to complete
		time.Sleep(200 * time.Millisecond)
	}

	appConfig := a.ApplicationConfig()

	// Start P2P if token is set
	if appConfig.P2PToken == "" {
		return fmt.Errorf("P2P token is not set")
	}

	// Create new context for P2P
	ctx, cancel := context.WithCancel(appConfig.Context)
	a.p2pCtx = ctx
	a.p2pCancel = cancel

	// Get API address from config
	address := appConfig.APIAddress
	if address == "" {
		address = "127.0.0.1:8080" // default
	}

	// Start P2P stack in a goroutine
	go func() {
		if err := a.StartP2P(); err != nil {
			log.Error().Err(err).Msg("Failed to start P2P stack")
			cancel() // Cancel context on error
		}
	}()
	log.Info().Msg("P2P stack restarted with new settings")

	return nil
}

func syncState(ctx context.Context, n *node.Node, app *Application) error {
	zlog.Debug().Msg("[p2p-sync] Syncing state")

	whatWeHave := []string{}
	for _, model := range app.ModelConfigLoader().GetAllModelsConfigs() {
		whatWeHave = append(whatWeHave, model.Name)
	}

	ledger, _ := n.Ledger()
	currentData := ledger.CurrentData()
	zlog.Debug().Msgf("[p2p-sync] Current data: %v", currentData)
	data, exists := ledger.GetKey("shared_state", "models")
	if !exists {
		ledger.AnnounceUpdate(ctx, time.Minute, "shared_state", "models", whatWeHave)
		zlog.Debug().Msgf("No models found in the ledger, announced our models: %v", whatWeHave)
	}

	models := []string{}
	if err := data.Unmarshal(&models); err != nil {
		zlog.Warn().Err(err).Msg("error unmarshalling models")
		return nil
	}

	zlog.Debug().Msgf("[p2p-sync] Models that are present in this instance: %v\nModels that are in the ledger: %v", whatWeHave, models)

	// Sync with our state
	whatIsNotThere := []string{}
	for _, model := range whatWeHave {
		if !slices.Contains(models, model) {
			whatIsNotThere = append(whatIsNotThere, model)
		}
	}
	if len(whatIsNotThere) > 0 {
		zlog.Debug().Msgf("[p2p-sync] Announcing our models: %v", append(models, whatIsNotThere...))
		ledger.AnnounceUpdate(
			ctx,
			1*time.Minute,
			"shared_state",
			"models",
			append(models, whatIsNotThere...),
		)
	}

	// Install any model from the ledger that is not yet present in this instance
	for _, model := range models {
		if slices.Contains(whatWeHave, model) {
			zlog.Debug().Msgf("[p2p-sync] Model %s is already present in this instance", model)
			continue
		}

		// we install model
		zlog.Info().Msgf("[p2p-sync] Installing model which is not present in this instance: %s", model)

		uuid, err := uuid.NewUUID()
		if err != nil {
			zlog.Error().Err(err).Msg("error generating UUID")
			continue
		}

		app.GalleryService().ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
			ID:                 uuid.String(),
			GalleryElementName: model,
			Galleries:          app.ApplicationConfig().Galleries,
			BackendGalleries:   app.ApplicationConfig().BackendGalleries,
		}
	}

	return nil
}
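// For context, a hedged sketch of the ledger payload: the "shared_state"/"models" entry is a
// plain string slice, e.g. (model names hypothetical):
//
//	["llama-3.2-1b-instruct", "whisper-base"]
//
// so two instances converge by unioning their lists and installing whatever they are missing.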

func (a *Application) p2pSync(ctx context.Context, n *node.Node) error {
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case <-time.After(1 * time.Minute):
				if err := syncState(ctx, n, a); err != nil {
					zlog.Error().Err(err).Msg("error syncing state")
				}
			}

		}
	}()
	return nil
}
@@ -1,8 +1,11 @@
package application

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"time"

	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
@@ -18,13 +21,24 @@ import (

func New(opts ...config.AppOption) (*Application, error) {
	options := config.NewApplicationConfig(opts...)

	// Store a copy of the startup config (from env vars, before file loading)
	// This is used to determine if settings came from env vars vs file
	startupConfigCopy := *options
	application := newApplication(options)
	application.startupConfig = &startupConfigCopy

	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.SystemState.Model.ModelsPath)
	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())

	if err := application.start(); err != nil {
		return nil, err
	}

	caps, err := xsysinfo.CPUCapabilities()
	if err == nil {
		log.Debug().Msgf("CPU capabilities: %v", caps)

	}
	gpus, err := xsysinfo.GPUs()
	if err == nil {
@@ -56,12 +70,12 @@ func New(opts ...config.AppOption) (*Application, error) {
		}
	}

	if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
	if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
		log.Error().Err(err).Msg("error installing models")
	}

	for _, backend := range options.ExternalBackends {
		if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
		if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
			log.Error().Err(err).Msg("error installing external backend")
		}
	}
@@ -104,6 +118,13 @@ func New(opts ...config.AppOption) (*Application, error) {
		}
	}

	// Load runtime settings from file if DynamicConfigsDir is set
	// This applies file settings with env var precedence (env vars take priority)
	// Note: startupConfigCopy was already created above, so it has the original env var values
	if options.DynamicConfigsDir != "" {
		loadRuntimeSettingsFromFile(options)
	}

	// turn off any process that was started by GRPC if the context is canceled
	go func() {
		<-options.Context.Done()
@@ -114,21 +135,8 @@ func New(opts ...config.AppOption) (*Application, error) {
		}
	}()

	if options.WatchDog {
		wd := model.NewWatchDog(
			application.ModelLoader(),
			options.WatchDogBusyTimeout,
			options.WatchDogIdleTimeout,
			options.WatchDogBusy,
			options.WatchDogIdle)
		application.ModelLoader().SetWatchDog(wd)
		go wd.Run()
		go func() {
			<-options.Context.Done()
			log.Debug().Msgf("Context canceled, shutting down")
			wd.Shutdown()
		}()
	}
	// Initialize watchdog with current settings (after loading from file)
	initializeWatchdog(application, options)

	if options.LoadToMemory != nil && !options.SingleBackend {
		for _, m := range options.LoadToMemory {
@@ -152,10 +160,6 @@ func New(opts ...config.AppOption) (*Application, error) {
	// Watch the configuration directory
	startWatcher(options)

	if err := application.start(); err != nil {
		return nil, err
	}

	log.Info().Msg("core/startup process completed!")
	return application, nil
}
@@ -184,3 +188,131 @@ func startWatcher(options *config.ApplicationConfig) {
		log.Error().Err(err).Msg("failed creating watcher")
	}
}

// loadRuntimeSettingsFromFile loads settings from runtime_settings.json with env var precedence.
// It runs in New() after the AppOptions (and therefore any env vars) have already been applied
// to options, and the pre-option defaults are not recorded anywhere, so it cannot reliably tell
// "set from env var" apart from "left at default". As a heuristic it only applies a file setting
// when the current value is still at its zero value, which preserves explicit env var settings at
// the cost of not distinguishing an env var that explicitly set false/0 from an unset one.
// Runtime changes after startup are handled by the file watcher instead, which compares against
// startupAppConfig.
func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
	settingsFile := filepath.Join(options.DynamicConfigsDir, "runtime_settings.json")
	fileContent, err := os.ReadFile(settingsFile)
	if err != nil {
		if os.IsNotExist(err) {
			log.Debug().Msg("runtime_settings.json not found, using defaults")
			return
		}
		log.Warn().Err(err).Msg("failed to read runtime_settings.json")
		return
	}

	var settings struct {
		WatchdogEnabled         *bool   `json:"watchdog_enabled,omitempty"`
		WatchdogIdleEnabled     *bool   `json:"watchdog_idle_enabled,omitempty"`
		WatchdogBusyEnabled     *bool   `json:"watchdog_busy_enabled,omitempty"`
		WatchdogIdleTimeout     *string `json:"watchdog_idle_timeout,omitempty"`
		WatchdogBusyTimeout     *string `json:"watchdog_busy_timeout,omitempty"`
		SingleBackend           *bool   `json:"single_backend,omitempty"`
		ParallelBackendRequests *bool   `json:"parallel_backend_requests,omitempty"`
	}

	if err := json.Unmarshal(fileContent, &settings); err != nil {
		log.Warn().Err(err).Msg("failed to parse runtime_settings.json")
		return
	}

	// At this point, options already has values from env vars (via AppOptions in run.go).
	// A setting is applied from the file only when the current value is at its default
	// (false for bools, 0 for durations), which suggests it was not set from an env var.
	// Env vars explicitly setting false/0 are indistinguishable from defaults; that is an
	// accepted limitation to avoid duplicating env var handling here.

	if settings.WatchdogIdleEnabled != nil {
		// Only apply if current value is default (false), suggesting it wasn't set from env var
		if !options.WatchDogIdle {
			options.WatchDogIdle = *settings.WatchdogIdleEnabled
			if options.WatchDogIdle {
				options.WatchDog = true
			}
		}
	}
	if settings.WatchdogBusyEnabled != nil {
		if !options.WatchDogBusy {
			options.WatchDogBusy = *settings.WatchdogBusyEnabled
			if options.WatchDogBusy {
				options.WatchDog = true
			}
		}
	}
	if settings.WatchdogIdleTimeout != nil {
		// Only apply if current value is default (0), suggesting it wasn't set from env var
		if options.WatchDogIdleTimeout == 0 {
			dur, err := time.ParseDuration(*settings.WatchdogIdleTimeout)
			if err == nil {
				options.WatchDogIdleTimeout = dur
			} else {
				log.Warn().Err(err).Str("timeout", *settings.WatchdogIdleTimeout).Msg("invalid watchdog idle timeout in runtime_settings.json")
			}
		}
	}
	if settings.WatchdogBusyTimeout != nil {
		if options.WatchDogBusyTimeout == 0 {
			dur, err := time.ParseDuration(*settings.WatchdogBusyTimeout)
			if err == nil {
				options.WatchDogBusyTimeout = dur
			} else {
				log.Warn().Err(err).Str("timeout", *settings.WatchdogBusyTimeout).Msg("invalid watchdog busy timeout in runtime_settings.json")
			}
		}
	}
	if settings.SingleBackend != nil {
		if !options.SingleBackend {
			options.SingleBackend = *settings.SingleBackend
		}
	}
	if settings.ParallelBackendRequests != nil {
		if !options.ParallelBackendRequests {
			options.ParallelBackendRequests = *settings.ParallelBackendRequests
		}
	}
	if !options.WatchDogIdle && !options.WatchDogBusy {
		if settings.WatchdogEnabled != nil && *settings.WatchdogEnabled {
			options.WatchDog = true
		}
	}

	log.Debug().Msg("Runtime settings loaded from runtime_settings.json")
}

// initializeWatchdog initializes the watchdog with current ApplicationConfig settings
func initializeWatchdog(application *Application, options *config.ApplicationConfig) {
	if options.WatchDog {
		wd := model.NewWatchDog(
			application.ModelLoader(),
			options.WatchDogBusyTimeout,
			options.WatchDogIdleTimeout,
			options.WatchDogBusy,
			options.WatchDogIdle)
		application.ModelLoader().SetWatchDog(wd)
		go wd.Run()
		go func() {
			<-options.Context.Done()
			log.Debug().Msgf("Context canceled, shutting down")
			wd.Shutdown()
		}()
	}
}

88
core/application/watchdog.go
Normal file
@@ -0,0 +1,88 @@
package application

import (
	"time"

	"github.com/mudler/LocalAI/pkg/model"
	"github.com/rs/zerolog/log"
)

func (a *Application) StopWatchdog() error {
	if a.watchdogStop != nil {
		close(a.watchdogStop)
		a.watchdogStop = nil
	}
	return nil
}

// startWatchdog starts the watchdog with current ApplicationConfig settings
// This is an internal method that assumes the caller holds the watchdogMutex
func (a *Application) startWatchdog() error {
	appConfig := a.ApplicationConfig()

	// Create new watchdog if enabled
	if appConfig.WatchDog {
		wd := model.NewWatchDog(
			a.modelLoader,
			appConfig.WatchDogBusyTimeout,
			appConfig.WatchDogIdleTimeout,
			appConfig.WatchDogBusy,
			appConfig.WatchDogIdle)
		a.modelLoader.SetWatchDog(wd)

		// Create new stop channel
		a.watchdogStop = make(chan bool, 1)

		// Start watchdog goroutine
		go wd.Run()

		// Setup shutdown handler
		go func() {
			select {
			case <-a.watchdogStop:
				log.Debug().Msg("Watchdog stop signal received")
				wd.Shutdown()
			case <-appConfig.Context.Done():
				log.Debug().Msg("Context canceled, shutting down watchdog")
				wd.Shutdown()
			}
		}()

		log.Info().Msg("Watchdog started with new settings")
	} else {
		log.Info().Msg("Watchdog disabled")
	}

	return nil
}

// StartWatchdog starts the watchdog with current ApplicationConfig settings
func (a *Application) StartWatchdog() error {
	a.watchdogMutex.Lock()
	defer a.watchdogMutex.Unlock()

	return a.startWatchdog()
}

// RestartWatchdog restarts the watchdog with current ApplicationConfig settings
func (a *Application) RestartWatchdog() error {
	a.watchdogMutex.Lock()
	defer a.watchdogMutex.Unlock()

	// Signal the shutdown handler and drop the old stop channel
	if a.watchdogStop != nil {
		close(a.watchdogStop)
		a.watchdogStop = nil
	}

	// Also shut down the watchdog instance attached to the model loader, if any
	currentWD := a.modelLoader.GetWatchDog()
	if currentWD != nil {
		currentWD.Shutdown()
		// Wait a bit for shutdown to complete
		time.Sleep(100 * time.Millisecond)
	}

	// Start watchdog with new settings
	return a.startWatchdog()
}
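// A hedged usage sketch (not part of this changeset): a runtime-settings handler that has just
// toggled appConfig.WatchDog could apply the change with:
//
//	if err := app.RestartWatchdog(); err != nil {
//		log.Error().Err(err).Msg("failed to restart watchdog")
//	}
//
// RestartWatchdog is safe to call whether or not a watchdog is currently running.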
@@ -40,3 +40,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat

	return fn, nil
}

// ImageGenerationFunc is a test-friendly indirection to call image generation logic.
// Tests can override this variable to provide a stub implementation.
var ImageGenerationFunc = ImageGeneration
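// The indirection above can be swapped out in tests without touching the real backend. A hedged
// sketch (the stub must match ImageGeneration's signature, elided here):
//
//	orig := backend.ImageGenerationFunc
//	defer func() { backend.ImageGenerationFunc = orig }()
//	backend.ImageGenerationFunc = stub // a test double returning canned results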
@@ -3,7 +3,6 @@ package backend
import (
	"context"
	"encoding/json"
	"fmt"
	"regexp"
	"slices"
	"strings"
@@ -26,6 +25,7 @@ type LLMResponse struct {
	Response string // should this be []byte?
	Usage    TokenUsage
	AudioOutput string
	Logprobs *schema.Logprobs // Logprobs from the backend response
}

type TokenUsage struct {
@@ -35,7 +35,7 @@ type TokenUsage struct {
	TimingTokenGeneration float64
}

func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
	modelFile := c.Model

	// Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -47,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
	if !slices.Contains(modelNames, c.Name) {
		utils.ResetDownloadTimers()
		// if we failed to load the model, we try to download it
		err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
		err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
		if err != nil {
			log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
			//return nil, err
@@ -65,29 +65,8 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
	var protoMessages []*proto.Message
	// if we are using the tokenizer template, we need to convert the messages to proto messages
	// unless the prompt has already been tokenized (non-chat endpoints + functions)
	if c.TemplateConfig.UseTokenizerTemplate && s == "" {
		protoMessages = make([]*proto.Message, len(messages), len(messages))
		for i, message := range messages {
			protoMessages[i] = &proto.Message{
				Role: message.Role,
			}
			switch ct := message.Content.(type) {
			case string:
				protoMessages[i].Content = ct
			case []interface{}:
				// If using the tokenizer template, in case of multimodal we want to keep the multimodal content as is and return only strings here
				data, _ := json.Marshal(ct)
				resultData := []struct {
					Text string `json:"text"`
				}{}
				json.Unmarshal(data, &resultData)
				for _, r := range resultData {
					protoMessages[i].Content += r.Text
				}
			default:
				return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
			}
		}
	if c.TemplateConfig.UseTokenizerTemplate && len(messages) > 0 {
		protoMessages = messages.ToProto()
	}

	// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
@@ -99,6 +78,21 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
	opts.Images = images
	opts.Videos = videos
	opts.Audios = audios
	opts.Tools = tools
	opts.ToolChoice = toolChoice
	if logprobs != nil {
		opts.Logprobs = int32(*logprobs)
	}
	if topLogprobs != nil {
		opts.TopLogprobs = int32(*topLogprobs)
	}
	if len(logitBias) > 0 {
		// Serialize logit_bias map to JSON string for proto
		logitBiasJSON, err := json.Marshal(logitBias)
		if err == nil {
			opts.LogitBias = string(logitBiasJSON)
		}
	}
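	// For illustration (token IDs hypothetical): a logitBias of
	// map[string]float64{"15043": -100, "322": 2} is serialized into the proto field as
	// the JSON string {"15043":-100,"322":2}, keyed by token ID.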

	tokenUsage := TokenUsage{}

@@ -130,6 +124,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
	}

	ss := ""
	var logprobs *schema.Logprobs

	var partialRune []byte
	err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
@@ -141,6 +136,14 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
		tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
		tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

		// Parse logprobs from reply if present (collect from last chunk that has them)
		if len(reply.Logprobs) > 0 {
			var parsedLogprobs schema.Logprobs
			if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
				logprobs = &parsedLogprobs
			}
		}

		// Process complete runes and accumulate them
		var completeRunes []byte
		for len(partialRune) > 0 {
@@ -166,6 +169,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
		return LLMResponse{
			Response: ss,
			Usage:    tokenUsage,
			Logprobs: logprobs,
		}, err
	} else {
		// TODO: Is the chicken bit the only way to get here? is that acceptable?
@@ -188,9 +192,19 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
			response = c.TemplateConfig.ReplyPrefix + response
		}

		// Parse logprobs from reply if present
		var logprobs *schema.Logprobs
		if len(reply.Logprobs) > 0 {
			var parsedLogprobs schema.Logprobs
			if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
				logprobs = &parsedLogprobs
			}
		}

		return LLMResponse{
			Response: response,
			Usage:    tokenUsage,
			Logprobs: logprobs,
		}, err
	}
}

@@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
		}
	}

	return &pb.PredictOptions{
	pbOpts := &pb.PredictOptions{
		Temperature: float32(*c.Temperature),
		TopP:        float32(*c.TopP),
		NDraft:      c.NDraft,
@@ -249,4 +249,6 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
		TailFreeSamplingZ: float32(*c.TFZ),
		TypicalP:          float32(*c.TypicalP),
	}
	// Logprobs and TopLogprobs are set by the caller if provided
	return pbOpts
}

@@ -1,87 +0,0 @@
package cli_api

import (
	"context"
	"fmt"
	"net"
	"os"
	"strings"

	"github.com/mudler/LocalAI/core/application"
	"github.com/mudler/LocalAI/core/p2p"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/edgevpn/pkg/node"

	"github.com/rs/zerolog/log"
)

func StartP2PStack(ctx context.Context, address, token, networkID string, federated bool, app *application.Application) error {
	var n *node.Node
	// Here we are avoiding creating multiple nodes:
	// - if the federated mode is enabled, we create a federated node and expose a service
	// - exposing a service creates a node with specific options, and we don't want to create another node

	// If the federated mode is enabled, we expose a service to the local instance running
	// at r.Address
	if federated {
		_, port, err := net.SplitHostPort(address)
		if err != nil {
			return err
		}

		// Here a new node is created and started
		// and a service is exposed by the node
		node, err := p2p.ExposeService(ctx, "localhost", port, token, p2p.NetworkID(networkID, p2p.FederatedID))
		if err != nil {
			return err
		}

		if err := p2p.ServiceDiscoverer(ctx, node, token, p2p.NetworkID(networkID, p2p.FederatedID), nil, false); err != nil {
			return err
		}

		n = node

		// start node sync in the background
		if err := p2p.Sync(ctx, node, app); err != nil {
			return err
		}
	}

	// If the p2p mode is enabled, we start the service discovery
	if token != "" {
		// If a node wasn't created previously, create it
		if n == nil {
			node, err := p2p.NewNode(token)
			if err != nil {
				return err
			}
			err = node.Start(ctx)
			if err != nil {
				return fmt.Errorf("starting new node: %w", err)
			}
			n = node
		}

		// Attach a ServiceDiscoverer to the p2p node
		log.Info().Msg("Starting P2P server discovery...")
		if err := p2p.ServiceDiscoverer(ctx, n, token, p2p.NetworkID(networkID, p2p.WorkerID), func(serviceID string, node schema.NodeData) {
			var tunnelAddresses []string
			for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.WorkerID)) {
				if v.IsOnline() {
					tunnelAddresses = append(tunnelAddresses, v.TunnelAddress)
				} else {
					log.Info().Msgf("Node %s is offline", v.ID)
				}
			}
			tunnelEnvVar := strings.Join(tunnelAddresses, ",")

			os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
			log.Debug().Msgf("setting LLAMACPP_GRPC_SERVERS to %s", tunnelEnvVar)
		}, true); err != nil {
			return err
		}
	}

	return nil
}
@@ -1,6 +1,7 @@
package cli

import (
	"context"
	"encoding/json"
	"fmt"

@@ -102,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
	}

	modelLoader := model.NewModelLoader(systemState, true)
	err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
	err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
	if err != nil {
		return err
	}

@@ -48,10 +48,12 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
	appHTTP := http.Explorer(db)

	signals.RegisterGracefulTerminationHandler(func() {
		if err := appHTTP.Shutdown(); err != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := appHTTP.Shutdown(ctx); err != nil {
			log.Error().Err(err).Msg("error during shutdown")
		}
	})

	return appHTTP.Listen(e.Address)
	return appHTTP.Start(e.Address)
}

@@ -1,12 +1,14 @@
package cli

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"

	cliContext "github.com/mudler/LocalAI/core/cli/context"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/services"

	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/startup"
@@ -78,6 +80,12 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
		return err
	}

	galleryService := services.NewGalleryService(&config.ApplicationConfig{}, model.NewModelLoader(systemState, true))
	err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)
	if err != nil {
		return err
	}

	var galleries []config.Gallery
	if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil {
		log.Error().Err(err).Msg("unable to load galleries")
@@ -127,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
	}

	modelLoader := model.NewModelLoader(systemState, true)
	err = startup.InstallModels(galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
	err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
	if err != nil {
		return err
	}

@@ -8,7 +8,6 @@ import (
	"time"

	"github.com/mudler/LocalAI/core/application"
	cli_api "github.com/mudler/LocalAI/core/cli/api"
	cliContext "github.com/mudler/LocalAI/core/cli/context"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http"
@@ -52,6 +51,7 @@ type RunCMD struct {
	UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"`
	APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"`
	DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"`
	DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"`
	DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"`
	OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"`
	UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliency against timing attacks." group:"hardening"`
@@ -98,6 +98,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
	}

	opts := []config.AppOption{
		config.WithContext(context.Background()),
		config.WithConfigFile(r.ModelsConfigFile),
		config.WithJSONStringPreload(r.PreloadModels),
		config.WithYAMLConfigPreload(r.PreloadModelsConfig),
@@ -128,12 +129,22 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
		config.WithLoadToMemory(r.LoadToMemory),
		config.WithMachineTag(r.MachineTag),
		config.WithAPIAddress(r.Address),
		config.WithTunnelCallback(func(tunnels []string) {
			tunnelEnvVar := strings.Join(tunnels, ",")
			// TODO: this is very specific to llama.cpp, we should have a more generic way to set the environment variable
			os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar)
			log.Debug().Msgf("setting LLAMACPP_GRPC_SERVERS to %s", tunnelEnvVar)
		}),
	}
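	// With two online workers the callback above would receive, e.g. (addresses hypothetical),
	// []string{"127.0.0.1:34567", "127.0.0.1:34568"} and set
	// LLAMACPP_GRPC_SERVERS=127.0.0.1:34567,127.0.0.1:34568 for the llama.cpp backend.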
|
||||
|
||||
if r.DisableMetricsEndpoint {
|
||||
opts = append(opts, config.DisableMetricsEndpoint)
|
||||
}
|
||||
|
||||
if r.DisableRuntimeSettings {
|
||||
opts = append(opts, config.DisableRuntimeSettings)
|
||||
}
|
||||
|
||||
token := ""
|
||||
if r.Peer2Peer || r.Peer2PeerToken != "" {
|
||||
log.Info().Msg("P2P mode enabled")
|
||||
@@ -152,7 +163,9 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
opts = append(opts, config.WithP2PToken(token))
|
||||
}
|
||||
|
||||
backgroundCtx := context.Background()
|
||||
if r.Federated {
|
||||
opts = append(opts, config.EnableFederated)
|
||||
}
|
||||
|
||||
idleWatchDog := r.EnableWatchdogIdle
|
||||
busyWatchDog := r.EnableWatchdogBusy
|
||||
@@ -222,8 +235,10 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cli_api.StartP2PStack(backgroundCtx, r.Address, token, r.Peer2PeerNetworkID, r.Federated, app); err != nil {
|
||||
return err
|
||||
if token != "" {
|
||||
if err := app.StartP2P(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
signals.RegisterGracefulTerminationHandler(func() {
|
||||
@@ -232,5 +247,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
})
|
||||
|
||||
return appHTTP.Listen(r.Address)
|
||||
return appHTTP.Start(r.Address)
|
||||
}
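
The tunnel callback registered above is what feeds P2P workers to the llama.cpp backend: it joins the tunnel addresses into the LLAMACPP_GRPC_SERVERS environment variable (see the TODO in the diff). A standalone sketch of the same wiring, with hypothetical addresses:

    tunnels := []string{"127.0.0.1:50051", "127.0.0.1:50052"}
    // comma-separated list consumed by the llama.cpp backend for distributed inference
    os.Setenv("LLAMACPP_GRPC_SERVERS", strings.Join(tunnels, ","))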

@@ -1,6 +1,7 @@
package worker

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
@@ -42,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
		log.Error().Err(err).Msg("failed loading galleries")
		return "", err
	}
	err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
	err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
	if err != nil {
		log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
		return "", err

@@ -33,6 +33,7 @@ type ApplicationConfig struct {
	ApiKeys      []string
	P2PToken     string
	P2PNetworkID string
	Federated    bool

	DisableWebUI            bool
	EnforcePredownloadScans bool
@@ -65,6 +66,10 @@ type ApplicationConfig struct {
	MachineTag string

	APIAddress string

	TunnelCallback func(tunnels []string)

	DisableRuntimeSettings bool
}

type AppOption func(*ApplicationConfig)
@@ -73,7 +78,6 @@ func NewApplicationConfig(o ...AppOption) *ApplicationConfig {
	opt := &ApplicationConfig{
		Context:       context.Background(),
		UploadLimitMB: 15,
		ContextSize:   512,
		Debug:         true,
	}
	for _, oo := range o {
@@ -152,6 +156,10 @@ var DisableWebUI = func(o *ApplicationConfig) {
	o.DisableWebUI = true
}

var DisableRuntimeSettings = func(o *ApplicationConfig) {
	o.DisableRuntimeSettings = true
}

func SetWatchDogBusyTimeout(t time.Duration) AppOption {
	return func(o *ApplicationConfig) {
		o.WatchDogBusyTimeout = t
@@ -180,6 +188,10 @@ var EnableBackendGalleriesAutoload = func(o *ApplicationConfig) {
	o.AutoloadBackendGalleries = true
}

var EnableFederated = func(o *ApplicationConfig) {
	o.Federated = true
}

func WithExternalBackend(name string, uri string) AppOption {
	return func(o *ApplicationConfig) {
		if o.ExternalGRPCBackends == nil {
@@ -273,6 +285,12 @@ func WithContextSize(ctxSize int) AppOption {
	}
}

func WithTunnelCallback(callback func(tunnels []string)) AppOption {
	return func(o *ApplicationConfig) {
		o.TunnelCallback = callback
	}
}

func WithF16(f16 bool) AppOption {
	return func(o *ApplicationConfig) {
		o.F16 = f16
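
Two option styles coexist in this file: flag options such as DisableRuntimeSettings and EnableFederated are plain values of type AppOption and are passed without a call, while the With* constructors take a parameter. A composition sketch (values hypothetical):

    cfg := NewApplicationConfig(
        DisableRuntimeSettings,
        EnableFederated,
        WithTunnelCallback(func(tunnels []string) { log.Debug().Msgf("tunnels: %v", tunnels) }),
    )
    _ = cfg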

@@ -1,151 +1,17 @@
package config

import (
	"strings"

	"github.com/mudler/LocalAI/pkg/xsysinfo"
	"github.com/rs/zerolog/log"

	gguf "github.com/gpustack/gguf-parser-go"
)

type familyType uint8

const (
	Unknown familyType = iota
	LLaMa3
	CommandR
	Phi3
	ChatML
	Mistral03
	Gemma
	DeepSeek2
)

const (
	defaultContextSize = 1024
	defaultNGPULayers  = 99999999
)

type settingsConfig struct {
	StopWords      []string
	TemplateConfig TemplateConfig
	RepeatPenalty  float64
}

// default settings to adopt with a given model family
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
	Gemma: {
		RepeatPenalty: 1.0,
		StopWords:     []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
		TemplateConfig: TemplateConfig{
			Chat:        "{{.Input }}\n<start_of_turn>model\n",
			ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
			Completion:  "{{.Input}}",
		},
	},
	DeepSeek2: {
		StopWords: []string{"<|end▁of▁sentence|>"},
		TemplateConfig: TemplateConfig{
			ChatMessage: `{{if eq .RoleName "user" -}}User: {{.Content }}
{{ end -}}
{{if eq .RoleName "assistant" -}}Assistant: {{.Content}}<|end▁of▁sentence|>{{end}}
{{if eq .RoleName "system" -}}{{.Content}}
{{end -}}`,
			Chat: "{{.Input -}}\nAssistant: ",
		},
	},
	LLaMa3: {
		StopWords: []string{"<|eot_id|>"},
		TemplateConfig: TemplateConfig{
			Chat:        "<|begin_of_text|>{{.Input }}\n<|start_header_id|>assistant<|end_header_id|>",
			ChatMessage: "<|start_header_id|>{{ .RoleName }}<|end_header_id|>\n\n{{.Content }}<|eot_id|>",
		},
	},
	CommandR: {
		TemplateConfig: TemplateConfig{
			Chat: "{{.Input -}}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
			Functions: `<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
You are a function calling AI model, you can call the following functions:
## Available Tools
{{range .Functions}}
- {"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }}
{{end}}
When using a tool, reply with JSON, for instance {"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}}
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Input -}}`,
			ChatMessage: `{{if eq .RoleName "user" -}}
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if eq .RoleName "system" -}}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if eq .RoleName "assistant" -}}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if eq .RoleName "tool" -}}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if .FunctionCall -}}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{toJson .FunctionCall}}}<|END_OF_TURN_TOKEN|>
{{- end -}}`,
		},
		StopWords: []string{"<|END_OF_TURN_TOKEN|>"},
	},
	Phi3: {
		TemplateConfig: TemplateConfig{
			Chat:        "{{.Input}}\n<|assistant|>",
			ChatMessage: "<|{{ .RoleName }}|>\n{{.Content}}<|end|>",
			Completion:  "{{.Input}}",
		},
		StopWords: []string{"<|end|>", "<|endoftext|>"},
	},
	ChatML: {
		TemplateConfig: TemplateConfig{
			Chat: "{{.Input -}}\n<|im_start|>assistant",
			Functions: `<|im_start|>system
You are a function calling AI model. You are provided with functions to execute. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:
{{range .Functions}}
{'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }}
{{end}}
For each function call return a json object with function name and arguments
<|im_end|>
{{.Input -}}
<|im_start|>assistant`,
			ChatMessage: `<|im_start|>{{ .RoleName }}
{{ if .FunctionCall -}}
Function call:
{{ else if eq .RoleName "tool" -}}
Function response:
{{ end -}}
{{ if .Content -}}
{{.Content }}
{{ end -}}
{{ if .FunctionCall -}}
{{toJson .FunctionCall}}
{{ end -}}<|im_end|>`,
		},
		StopWords: []string{"<|im_end|>", "<dummy32000>", "</s>"},
	},
	Mistral03: {
		TemplateConfig: TemplateConfig{
			Chat:      "{{.Input -}}",
			Functions: `[AVAILABLE_TOOLS] [{{range .Functions}}{"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }}{{end}} ] [/AVAILABLE_TOOLS]{{.Input }}`,
			ChatMessage: `{{if eq .RoleName "user" -}}
[INST] {{.Content }} [/INST]
{{- else if .FunctionCall -}}
[TOOL_CALLS] {{toJson .FunctionCall}} [/TOOL_CALLS]
{{- else if eq .RoleName "tool" -}}
[TOOL_RESULTS] {{.Content}} [/TOOL_RESULTS]
{{- else -}}
{{ .Content -}}
{{ end -}}`,
		},
		StopWords: []string{"<|im_end|>", "<dummy32000>", "</tool_call>", "<|eot_id|>", "<|end_of_text|>", "</s>", "[/TOOL_CALLS]", "[/ACTIONS]"},
	},
}

// this maps well known template used in HF to model families defined above
var knownTemplates = map[string]familyType{
	`{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\n' + content + '<|im_end|>\n<|im_start|>assistant\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\n' }}{% endif %}{% endfor %}`: ChatML,
	`{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}`: Mistral03,
}

func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {

	if defaultCtx == 0 && cfg.ContextSize == nil {
@@ -216,81 +82,9 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
		cfg.Name = f.Metadata().Name
	}

	family := identifyFamily(f)

	if family == Unknown {
		log.Debug().Msgf("guessDefaultsFromFile: %s", "family not identified")
		return
	}

	// identify template
	settings, ok := defaultsSettings[family]
	if ok {
		cfg.TemplateConfig = settings.TemplateConfig
		log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: guessed template %+v", cfg.TemplateConfig)
		if len(cfg.StopWords) == 0 {
			cfg.StopWords = settings.StopWords
		}
		if cfg.RepeatPenalty == 0.0 {
			cfg.RepeatPenalty = settings.RepeatPenalty
		}
	} else {
		log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
	}

	if cfg.HasTemplate() {
		return
	}

	// identify from well known templates first, otherwise use the raw jinja template
	chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
	if found {
		// try to use the jinja template
		cfg.TemplateConfig.JinjaTemplate = true
		cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
	}

}

func identifyFamily(f *gguf.GGUFFile) familyType {

	// identify from well known templates first
	chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
	if found && chatTemplate.ValueString() != "" {
		if family, ok := knownTemplates[chatTemplate.ValueString()]; ok {
			return family
		}
	}

	// otherwise try to identify from the model properties
	arch := f.Architecture().Architecture
	eosTokenID := f.Tokenizer().EOSTokenID
	bosTokenID := f.Tokenizer().BOSTokenID

	isYI := arch == "llama" && bosTokenID == 1 && eosTokenID == 2
	// WTF! Mistral0.3 and isYi have same bosTokenID and eosTokenID

	llama3 := arch == "llama" && eosTokenID == 128009
	commandR := arch == "command-r" && eosTokenID == 255001
	qwen2 := arch == "qwen2"
	phi3 := arch == "phi-3"
	gemma := strings.HasPrefix(arch, "gemma") || strings.Contains(strings.ToLower(f.Metadata().Name), "gemma")
	deepseek2 := arch == "deepseek2"

	switch {
	case deepseek2:
		return DeepSeek2
	case gemma:
		return Gemma
	case llama3:
		return LLaMa3
	case commandR:
		return CommandR
	case phi3:
		return Phi3
	case qwen2, isYI:
		return ChatML
	default:
		return Unknown
	}
	// Instruct to use template from llama.cpp
	cfg.TemplateConfig.UseTokenizerTemplate = true
	cfg.FunctionsConfig.GrammarConfig.NoGrammar = true
	cfg.Options = append(cfg.Options, "use_jinja:true")
	cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT")
}

@@ -1,6 +1,7 @@
package config

import (
	"fmt"
	"os"
	"regexp"
	"slices"
@@ -9,6 +10,7 @@ import (
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/downloader"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/cogito"
	"gopkg.in/yaml.v3"
)

@@ -16,30 +18,31 @@ const (
	RAND_SEED = -1
)

// @Description TTS configuration
type TTSConfig struct {

	// Voice wav path or id
	Voice string `yaml:"voice" json:"voice"`
	Voice string `yaml:"voice,omitempty" json:"voice,omitempty"`

	AudioPath string `yaml:"audio_path" json:"audio_path"`
	AudioPath string `yaml:"audio_path,omitempty" json:"audio_path,omitempty"`
}

// ModelConfig represents a model configuration
// @Description ModelConfig represents a model configuration
type ModelConfig struct {
	modelConfigFile string `yaml:"-" json:"-"`
	schema.PredictionOptions `yaml:"parameters" json:"parameters"`
	Name string `yaml:"name" json:"name"`
	schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"`
	Name string `yaml:"name,omitempty" json:"name,omitempty"`

	F16 *bool `yaml:"f16" json:"f16"`
	Threads *int `yaml:"threads" json:"threads"`
	Debug *bool `yaml:"debug" json:"debug"`
	Roles map[string]string `yaml:"roles" json:"roles"`
	Embeddings *bool `yaml:"embeddings" json:"embeddings"`
	Backend string `yaml:"backend" json:"backend"`
	TemplateConfig TemplateConfig `yaml:"template" json:"template"`
	KnownUsecaseStrings []string `yaml:"known_usecases" json:"known_usecases"`
	F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"`
	Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"`
	Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"`
	Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"`
	Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"`
	Backend string `yaml:"backend,omitempty" json:"backend,omitempty"`
	TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"`
	KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"`
	KnownUsecases *ModelConfigUsecases `yaml:"-" json:"-"`
	Pipeline Pipeline `yaml:"pipeline" json:"pipeline"`
	Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`

	PromptStrings, InputStrings []string `yaml:"-" json:"-"`
	InputToken [][]int `yaml:"-" json:"-"`
@@ -47,96 +50,101 @@ type ModelConfig struct {
	ResponseFormat string `yaml:"-" json:"-"`
	ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`

	FunctionsConfig functions.FunctionsConfig `yaml:"function" json:"function"`
	FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`

	FeatureFlag FeatureFlag `yaml:"feature_flags" json:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
	FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
	// LLM configs (GPT4ALL, Llama.cpp, ...)
	LLMConfig `yaml:",inline" json:",inline"`

	// Diffusers
	Diffusers Diffusers `yaml:"diffusers" json:"diffusers"`
	Step int `yaml:"step" json:"step"`
	Diffusers Diffusers `yaml:"diffusers,omitempty" json:"diffusers,omitempty"`
	Step int `yaml:"step,omitempty" json:"step,omitempty"`

	// GRPC Options
	GRPC GRPC `yaml:"grpc" json:"grpc"`
	GRPC GRPC `yaml:"grpc,omitempty" json:"grpc,omitempty"`

	// TTS specifics
	TTSConfig `yaml:"tts" json:"tts"`
	TTSConfig `yaml:"tts,omitempty" json:"tts,omitempty"`

	// CUDA
	// Explicitly enable CUDA or not (some backends might need it)
	CUDA bool `yaml:"cuda" json:"cuda"`
	CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`

	DownloadFiles []File `yaml:"download_files" json:"download_files"`
	DownloadFiles []File `yaml:"download_files,omitempty" json:"download_files,omitempty"`

	Description string `yaml:"description" json:"description"`
	Usage string `yaml:"usage" json:"usage"`
	Description string `yaml:"description,omitempty" json:"description,omitempty"`
	Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`

	Options []string `yaml:"options" json:"options"`
	Overrides []string `yaml:"overrides" json:"overrides"`
	Options []string `yaml:"options,omitempty" json:"options,omitempty"`
	Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`

	MCP MCPConfig `yaml:"mcp" json:"mcp"`
	Agent AgentConfig `yaml:"agent" json:"agent"`
	MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
	Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
}

// @Description MCP configuration
type MCPConfig struct {
	Servers string `yaml:"remote" json:"remote"`
	Stdio string `yaml:"stdio" json:"stdio"`
	Servers string `yaml:"remote,omitempty" json:"remote,omitempty"`
	Stdio string `yaml:"stdio,omitempty" json:"stdio,omitempty"`
}

// @Description Agent configuration
type AgentConfig struct {
	MaxAttempts int `yaml:"max_attempts" json:"max_attempts"`
	MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
	EnableReasoning bool `yaml:"enable_reasoning" json:"enable_reasoning"`
	EnablePlanning bool `yaml:"enable_planning" json:"enable_planning"`
	EnableMCPPrompts bool `yaml:"enable_mcp_prompts" json:"enable_mcp_prompts"`
	EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator" json:"enable_plan_re_evaluator"`
	MaxAttempts int `yaml:"max_attempts,omitempty" json:"max_attempts,omitempty"`
	MaxIterations int `yaml:"max_iterations,omitempty" json:"max_iterations,omitempty"`
	EnableReasoning bool `yaml:"enable_reasoning,omitempty" json:"enable_reasoning,omitempty"`
	EnablePlanning bool `yaml:"enable_planning,omitempty" json:"enable_planning,omitempty"`
	EnableMCPPrompts bool `yaml:"enable_mcp_prompts,omitempty" json:"enable_mcp_prompts,omitempty"`
	EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator,omitempty" json:"enable_plan_re_evaluator,omitempty"`
}

func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers]) {
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
	var remote MCPGenericConfig[MCPRemoteServers]
	var stdio MCPGenericConfig[MCPSTDIOServers]

	if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
		return remote, stdio
		return remote, stdio, err
	}

	if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
		return remote, stdio
		return remote, stdio, err
	}

	return remote, stdio
	return remote, stdio, nil
}
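
MCPConfigFromYAML now reports unmarshal failures instead of silently returning partially-filled structs. A usage sketch (variable names illustrative):

    remote, stdio, err := modelConfig.MCP.MCPConfigFromYAML()
    if err != nil {
        return fmt.Errorf("invalid MCP server configuration: %w", err)
    }
    _, _ = remote, stdio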

// @Description MCP generic configuration
type MCPGenericConfig[T any] struct {
	Servers T `yaml:"mcpServers" json:"mcpServers"`
	Servers T `yaml:"mcpServers,omitempty" json:"mcpServers,omitempty"`
}
type MCPRemoteServers map[string]MCPRemoteServer
type MCPSTDIOServers map[string]MCPSTDIOServer

// @Description MCP remote server configuration
type MCPRemoteServer struct {
	URL string `json:"url"`
	Token string `json:"token"`
	URL string `json:"url,omitempty"`
	Token string `json:"token,omitempty"`
}

// @Description MCP STDIO server configuration
type MCPSTDIOServer struct {
	Args []string `json:"args"`
	Env map[string]string `json:"env"`
	Command string `json:"command"`
	Args []string `json:"args,omitempty"`
	Env map[string]string `json:"env,omitempty"`
	Command string `json:"command,omitempty"`
}

// Pipeline defines other models to use for audio-to-audio
// @Description Pipeline defines other models to use for audio-to-audio
type Pipeline struct {
	TTS string `yaml:"tts" json:"tts"`
	LLM string `yaml:"llm" json:"llm"`
	Transcription string `yaml:"transcription" json:"transcription"`
	VAD string `yaml:"vad" json:"vad"`
	TTS string `yaml:"tts,omitempty" json:"tts,omitempty"`
	LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
	Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
	VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
}

// @Description File configuration for model downloads
type File struct {
	Filename string `yaml:"filename" json:"filename"`
	SHA256 string `yaml:"sha256" json:"sha256"`
	URI downloader.URI `yaml:"uri" json:"uri"`
	Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`
	SHA256 string `yaml:"sha256,omitempty" json:"sha256,omitempty"`
	URI downloader.URI `yaml:"uri,omitempty" json:"uri,omitempty"`
}

type FeatureFlag map[string]*bool
@@ -148,126 +156,136 @@ func (ff FeatureFlag) Enabled(s string) bool {
	return false
}

// @Description GRPC configuration
type GRPC struct {
	Attempts int `yaml:"attempts" json:"attempts"`
	AttemptsSleepTime int `yaml:"attempts_sleep_time" json:"attempts_sleep_time"`
	Attempts int `yaml:"attempts,omitempty" json:"attempts,omitempty"`
	AttemptsSleepTime int `yaml:"attempts_sleep_time,omitempty" json:"attempts_sleep_time,omitempty"`
}

// @Description Diffusers configuration
type Diffusers struct {
	CUDA bool `yaml:"cuda" json:"cuda"`
	PipelineType string `yaml:"pipeline_type" json:"pipeline_type"`
	SchedulerType string `yaml:"scheduler_type" json:"scheduler_type"`
	EnableParameters string `yaml:"enable_parameters" json:"enable_parameters"` // A list of comma separated parameters to specify
	IMG2IMG bool `yaml:"img2img" json:"img2img"` // Image to Image Diffuser
	ClipSkip int `yaml:"clip_skip" json:"clip_skip"` // Skip every N frames
	ClipModel string `yaml:"clip_model" json:"clip_model"` // Clip model to use
	ClipSubFolder string `yaml:"clip_subfolder" json:"clip_subfolder"` // Subfolder to use for clip model
	ControlNet string `yaml:"control_net" json:"control_net"`
	CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`
	PipelineType string `yaml:"pipeline_type,omitempty" json:"pipeline_type,omitempty"`
	SchedulerType string `yaml:"scheduler_type,omitempty" json:"scheduler_type,omitempty"`
	EnableParameters string `yaml:"enable_parameters,omitempty" json:"enable_parameters,omitempty"` // A list of comma separated parameters to specify
	IMG2IMG bool `yaml:"img2img,omitempty" json:"img2img,omitempty"` // Image to Image Diffuser
	ClipSkip int `yaml:"clip_skip,omitempty" json:"clip_skip,omitempty"` // Skip every N frames
	ClipModel string `yaml:"clip_model,omitempty" json:"clip_model,omitempty"` // Clip model to use
	ClipSubFolder string `yaml:"clip_subfolder,omitempty" json:"clip_subfolder,omitempty"` // Subfolder to use for clip model
	ControlNet string `yaml:"control_net,omitempty" json:"control_net,omitempty"`
}

// LLMConfig is a struct that holds the configuration that are
// generic for most of the LLM backends.
// @Description LLMConfig is a struct that holds the configuration that are generic for most of the LLM backends.
type LLMConfig struct {
	SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
	TensorSplit string `yaml:"tensor_split" json:"tensor_split"`
	MainGPU string `yaml:"main_gpu" json:"main_gpu"`
	RMSNormEps float32 `yaml:"rms_norm_eps" json:"rms_norm_eps"`
	NGQA int32 `yaml:"ngqa" json:"ngqa"`
	PromptCachePath string `yaml:"prompt_cache_path" json:"prompt_cache_path"`
	PromptCacheAll bool `yaml:"prompt_cache_all" json:"prompt_cache_all"`
	PromptCacheRO bool `yaml:"prompt_cache_ro" json:"prompt_cache_ro"`
	MirostatETA *float64 `yaml:"mirostat_eta" json:"mirostat_eta"`
	MirostatTAU *float64 `yaml:"mirostat_tau" json:"mirostat_tau"`
	Mirostat *int `yaml:"mirostat" json:"mirostat"`
	NGPULayers *int `yaml:"gpu_layers" json:"gpu_layers"`
	MMap *bool `yaml:"mmap" json:"mmap"`
	MMlock *bool `yaml:"mmlock" json:"mmlock"`
	LowVRAM *bool `yaml:"low_vram" json:"low_vram"`
	Reranking *bool `yaml:"reranking" json:"reranking"`
	Grammar string `yaml:"grammar" json:"grammar"`
	StopWords []string `yaml:"stopwords" json:"stopwords"`
	Cutstrings []string `yaml:"cutstrings" json:"cutstrings"`
	ExtractRegex []string `yaml:"extract_regex" json:"extract_regex"`
	TrimSpace []string `yaml:"trimspace" json:"trimspace"`
	TrimSuffix []string `yaml:"trimsuffix" json:"trimsuffix"`
	SystemPrompt string `yaml:"system_prompt,omitempty" json:"system_prompt,omitempty"`
	TensorSplit string `yaml:"tensor_split,omitempty" json:"tensor_split,omitempty"`
	MainGPU string `yaml:"main_gpu,omitempty" json:"main_gpu,omitempty"`
	RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
	NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
	PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
	PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
	PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
	MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
	MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
	Mirostat *int `yaml:"mirostat,omitempty" json:"mirostat,omitempty"`
	NGPULayers *int `yaml:"gpu_layers,omitempty" json:"gpu_layers,omitempty"`
	MMap *bool `yaml:"mmap,omitempty" json:"mmap,omitempty"`
	MMlock *bool `yaml:"mmlock,omitempty" json:"mmlock,omitempty"`
	LowVRAM *bool `yaml:"low_vram,omitempty" json:"low_vram,omitempty"`
	Reranking *bool `yaml:"reranking,omitempty" json:"reranking,omitempty"`
	Grammar string `yaml:"grammar,omitempty" json:"grammar,omitempty"`
	StopWords []string `yaml:"stopwords,omitempty" json:"stopwords,omitempty"`
	Cutstrings []string `yaml:"cutstrings,omitempty" json:"cutstrings,omitempty"`
	ExtractRegex []string `yaml:"extract_regex,omitempty" json:"extract_regex,omitempty"`
	TrimSpace []string `yaml:"trimspace,omitempty" json:"trimspace,omitempty"`
	TrimSuffix []string `yaml:"trimsuffix,omitempty" json:"trimsuffix,omitempty"`

	ContextSize *int `yaml:"context_size" json:"context_size"`
	NUMA bool `yaml:"numa" json:"numa"`
	LoraAdapter string `yaml:"lora_adapter" json:"lora_adapter"`
	LoraBase string `yaml:"lora_base" json:"lora_base"`
	LoraAdapters []string `yaml:"lora_adapters" json:"lora_adapters"`
	LoraScales []float32 `yaml:"lora_scales" json:"lora_scales"`
	LoraScale float32 `yaml:"lora_scale" json:"lora_scale"`
	NoMulMatQ bool `yaml:"no_mulmatq" json:"no_mulmatq"`
	DraftModel string `yaml:"draft_model" json:"draft_model"`
	NDraft int32 `yaml:"n_draft" json:"n_draft"`
	Quantization string `yaml:"quantization" json:"quantization"`
	LoadFormat string `yaml:"load_format" json:"load_format"`
	GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization" json:"gpu_memory_utilization"` // vLLM
	TrustRemoteCode bool `yaml:"trust_remote_code" json:"trust_remote_code"` // vLLM
	EnforceEager bool `yaml:"enforce_eager" json:"enforce_eager"` // vLLM
	SwapSpace int `yaml:"swap_space" json:"swap_space"` // vLLM
	MaxModelLen int `yaml:"max_model_len" json:"max_model_len"` // vLLM
	TensorParallelSize int `yaml:"tensor_parallel_size" json:"tensor_parallel_size"` // vLLM
	DisableLogStatus bool `yaml:"disable_log_stats" json:"disable_log_stats"` // vLLM
	DType string `yaml:"dtype" json:"dtype"` // vLLM
	LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM
	MMProj string `yaml:"mmproj" json:"mmproj"`
	ContextSize *int `yaml:"context_size,omitempty" json:"context_size,omitempty"`
	NUMA bool `yaml:"numa,omitempty" json:"numa,omitempty"`
	LoraAdapter string `yaml:"lora_adapter,omitempty" json:"lora_adapter,omitempty"`
	LoraBase string `yaml:"lora_base,omitempty" json:"lora_base,omitempty"`
	LoraAdapters []string `yaml:"lora_adapters,omitempty" json:"lora_adapters,omitempty"`
	LoraScales []float32 `yaml:"lora_scales,omitempty" json:"lora_scales,omitempty"`
	LoraScale float32 `yaml:"lora_scale,omitempty" json:"lora_scale,omitempty"`
	NoMulMatQ bool `yaml:"no_mulmatq,omitempty" json:"no_mulmatq,omitempty"`
	DraftModel string `yaml:"draft_model,omitempty" json:"draft_model,omitempty"`
	NDraft int32 `yaml:"n_draft,omitempty" json:"n_draft,omitempty"`
	Quantization string `yaml:"quantization,omitempty" json:"quantization,omitempty"`
	LoadFormat string `yaml:"load_format,omitempty" json:"load_format,omitempty"`
	GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization,omitempty" json:"gpu_memory_utilization,omitempty"` // vLLM
	TrustRemoteCode bool `yaml:"trust_remote_code,omitempty" json:"trust_remote_code,omitempty"` // vLLM
	EnforceEager bool `yaml:"enforce_eager,omitempty" json:"enforce_eager,omitempty"` // vLLM
	SwapSpace int `yaml:"swap_space,omitempty" json:"swap_space,omitempty"` // vLLM
	MaxModelLen int `yaml:"max_model_len,omitempty" json:"max_model_len,omitempty"` // vLLM
	TensorParallelSize int `yaml:"tensor_parallel_size,omitempty" json:"tensor_parallel_size,omitempty"` // vLLM
	DisableLogStatus bool `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM
	DType string `yaml:"dtype,omitempty" json:"dtype,omitempty"` // vLLM
	LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM
	MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`

	FlashAttention *string `yaml:"flash_attention" json:"flash_attention"`
	NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
	CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
	CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
	FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"`
	NoKVOffloading bool `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"`
	CacheTypeK string `yaml:"cache_type_k,omitempty" json:"cache_type_k,omitempty"`
	CacheTypeV string `yaml:"cache_type_v,omitempty" json:"cache_type_v,omitempty"`

	RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"`
	ModelType string `yaml:"type" json:"type"`
	RopeScaling string `yaml:"rope_scaling,omitempty" json:"rope_scaling,omitempty"`
	ModelType string `yaml:"type,omitempty" json:"type,omitempty"`

	YarnExtFactor float32 `yaml:"yarn_ext_factor" json:"yarn_ext_factor"`
	YarnAttnFactor float32 `yaml:"yarn_attn_factor" json:"yarn_attn_factor"`
	YarnBetaFast float32 `yaml:"yarn_beta_fast" json:"yarn_beta_fast"`
	YarnBetaSlow float32 `yaml:"yarn_beta_slow" json:"yarn_beta_slow"`
	YarnExtFactor float32 `yaml:"yarn_ext_factor,omitempty" json:"yarn_ext_factor,omitempty"`
	YarnAttnFactor float32 `yaml:"yarn_attn_factor,omitempty" json:"yarn_attn_factor,omitempty"`
	YarnBetaFast float32 `yaml:"yarn_beta_fast,omitempty" json:"yarn_beta_fast,omitempty"`
	YarnBetaSlow float32 `yaml:"yarn_beta_slow,omitempty" json:"yarn_beta_slow,omitempty"`

	CFGScale float32 `yaml:"cfg_scale" json:"cfg_scale"` // Classifier-Free Guidance Scale
	CFGScale float32 `yaml:"cfg_scale,omitempty" json:"cfg_scale,omitempty"` // Classifier-Free Guidance Scale
}

// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
// @Description LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
type LimitMMPerPrompt struct {
	LimitImagePerPrompt int `yaml:"image" json:"image"`
	LimitVideoPerPrompt int `yaml:"video" json:"video"`
	LimitAudioPerPrompt int `yaml:"audio" json:"audio"`
	LimitImagePerPrompt int `yaml:"image,omitempty" json:"image,omitempty"`
	LimitVideoPerPrompt int `yaml:"video,omitempty" json:"video,omitempty"`
	LimitAudioPerPrompt int `yaml:"audio,omitempty" json:"audio,omitempty"`
}

// TemplateConfig is a struct that holds the configuration of the templating system
// @Description TemplateConfig is a struct that holds the configuration of the templating system
type TemplateConfig struct {
	// Chat is the template used in the chat completion endpoint
	Chat string `yaml:"chat" json:"chat"`
	Chat string `yaml:"chat,omitempty" json:"chat,omitempty"`

	// ChatMessage is the template used for chat messages
	ChatMessage string `yaml:"chat_message" json:"chat_message"`
	ChatMessage string `yaml:"chat_message,omitempty" json:"chat_message,omitempty"`

	// Completion is the template used for completion requests
	Completion string `yaml:"completion" json:"completion"`
	Completion string `yaml:"completion,omitempty" json:"completion,omitempty"`

	// Edit is the template used for edit completion requests
	Edit string `yaml:"edit" json:"edit"`
	Edit string `yaml:"edit,omitempty" json:"edit,omitempty"`

	// Functions is the template used when tools are present in the client requests
	Functions string `yaml:"function" json:"function"`
	Functions string `yaml:"function,omitempty" json:"function,omitempty"`

	// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
	// Note: this is mostly consumed for backends such as vllm and transformers
	// that can use the tokenizers specified in the JSON config files of the models
	UseTokenizerTemplate bool `yaml:"use_tokenizer_template" json:"use_tokenizer_template"`
	UseTokenizerTemplate bool `yaml:"use_tokenizer_template,omitempty" json:"use_tokenizer_template,omitempty"`

	// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
	// It defaults to \n
	JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character" json:"join_chat_messages_by_character"`
	JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character,omitempty" json:"join_chat_messages_by_character,omitempty"`

	Multimodal string `yaml:"multimodal" json:"multimodal"`
	Multimodal string `yaml:"multimodal,omitempty" json:"multimodal,omitempty"`

	JinjaTemplate bool `yaml:"jinja_template" json:"jinja_template"`
	ReplyPrefix string `yaml:"reply_prefix,omitempty" json:"reply_prefix,omitempty"`
}

	ReplyPrefix string `yaml:"reply_prefix" json:"reply_prefix"`
func (c *ModelConfig) syncKnownUsecasesFromString() {
	c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
	// Make sure the usecases are valid, we rewrite with what we identified
	c.KnownUsecaseStrings = []string{}
	for k, usecase := range GetAllModelConfigUsecases() {
		if c.HasUsecases(usecase) {
			c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
		}
	}
}

func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
@@ -278,14 +296,7 @@ func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
	}
	*c = ModelConfig(aux)

	c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
	// Make sure the usecases are valid, we rewrite with what we identified
	c.KnownUsecaseStrings = []string{}
	for k, usecase := range GetAllModelConfigUsecases() {
		if c.HasUsecases(usecase) {
			c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
		}
	}
	c.syncKnownUsecasesFromString()
	return nil
}

@@ -462,9 +473,10 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
	}

	guessDefaultsFromFile(cfg, lo.modelPath, ctx)
	cfg.syncKnownUsecasesFromString()
}

func (c *ModelConfig) Validate() bool {
func (c *ModelConfig) Validate() (bool, error) {
	downloadedFileNames := []string{}
	for _, f := range c.DownloadFiles {
		downloadedFileNames = append(downloadedFileNames, f.Filename)
@@ -478,21 +490,24 @@ func (c *ModelConfig) Validate() bool {
		}
		if strings.HasPrefix(n, string(os.PathSeparator)) ||
			strings.Contains(n, "..") {
			return false
			return false, fmt.Errorf("invalid file path: %s", n)
		}
	}

	if c.Backend != "" {
		// a regex that checks that is a string name with no special characters, except '-' and '_'
		re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`)
		return re.MatchString(c.Backend)
		if !re.MatchString(c.Backend) {
			return false, fmt.Errorf("invalid backend name: %s", c.Backend)
		}
		return true, nil
	}

	return true
	return true, nil
}
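
With Validate returning (bool, error), callers can surface the concrete reason a config was rejected rather than a bare boolean. A usage sketch (names illustrative):

    if valid, err := cfg.Validate(); !valid {
        log.Error().Err(err).Str("Name", cfg.Name).Msg("config is not valid")
    }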

func (c *ModelConfig) HasTemplate() bool {
	return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
	return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate
}

func (c *ModelConfig) GetModelConfigFile() string {
@@ -523,7 +538,8 @@ const (

func GetAllModelConfigUsecases() map[string]ModelConfigUsecases {
	return map[string]ModelConfigUsecases{
		"FLAG_ANY": FLAG_ANY,
		// Note: FLAG_ANY is intentionally excluded from this map
		// because it's 0 and would always match in HasUsecases checks
		"FLAG_CHAT": FLAG_CHAT,
		"FLAG_COMPLETION": FLAG_COMPLETION,
		"FLAG_EDIT": FLAG_EDIT,
@@ -573,7 +589,7 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecases) bool {
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
	if (u & FLAG_CHAT) == FLAG_CHAT {
		if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
		if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
			return false
		}
	}
@@ -625,7 +641,7 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
		}
	}
	if (u & FLAG_TTS) == FLAG_TTS {
		ttsBackends := []string{"bark-cpp", "piper", "transformers-musicgen"}
		ttsBackends := []string{"bark-cpp", "piper", "transformers-musicgen", "kokoro"}
		if !slices.Contains(ttsBackends, c.Backend) {
			return false
		}
@@ -658,3 +674,40 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {

	return true
}

// BuildCogitoOptions generates cogito options from the model configuration,
// applying the defaults first and then any per-agent overrides.
func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
	cogitoOpts := []cogito.Option{
		cogito.WithIterations(3), // default to 3 iterations
		cogito.WithMaxAttempts(3), // default to 3 attempts
		cogito.WithForceReasoning(),
	}

	// Apply agent configuration options
	if c.Agent.EnableReasoning {
		cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
	}

	if c.Agent.EnablePlanning {
		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
	}

	if c.Agent.EnableMCPPrompts {
		cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
	}

	if c.Agent.EnablePlanReEvaluator {
		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
	}

	if c.Agent.MaxIterations != 0 {
		cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
	}

	if c.Agent.MaxAttempts != 0 {
		cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
	}

	return cogitoOpts
}
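
A minimal sketch (field values hypothetical) of how the agent settings map onto the generated options. User overrides are appended after the defaults; this assumes cogito applies options in order, which the duplicate WithIterations/WithMaxAttempts appends above rely on:

    cfg := &ModelConfig{}
    cfg.Agent.EnablePlanning = true
    cfg.Agent.MaxIterations = 5
    opts := cfg.BuildCogitoOptions() // defaults plus cogito.EnableAutoPlan and cogito.WithIterations(5)
    _ = opts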

@@ -169,7 +169,7 @@ func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, op
	}

	for _, cc := range c {
		if cc.Validate() {
		if valid, _ := cc.Validate(); valid {
			bcl.configs[cc.Name] = *cc
		}
	}
@@ -184,7 +184,7 @@ func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderO
		return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err)
	}

	if c.Validate() {
	if valid, _ := c.Validate(); valid {
		bcl.configs[c.Name] = *c
	} else {
		return fmt.Errorf("config is not valid")
@@ -362,7 +362,7 @@ func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...Conf
			log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadModelConfigsFromPath cannot read config file")
			continue
		}
		if c.Validate() {
		if valid, _ := c.Validate(); valid {
			bcl.configs[c.Name] = *c
		} else {
			log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid")

@@ -28,7 +28,9 @@ known_usecases:
	config, err := readModelConfigFromFile(tmp.Name())
	Expect(err).To(BeNil())
	Expect(config).ToNot(BeNil())
	Expect(config.Validate()).To(BeFalse())
	valid, err := config.Validate()
	Expect(err).To(HaveOccurred())
	Expect(valid).To(BeFalse())
	Expect(config.KnownUsecases).ToNot(BeNil())
})
It("Test Validate", func() {
@@ -46,7 +48,9 @@ parameters:
	Expect(config).ToNot(BeNil())
	// two configs in config.yaml
	Expect(config.Name).To(Equal("bar-baz"))
	Expect(config.Validate()).To(BeTrue())
	valid, err := config.Validate()
	Expect(err).To(BeNil())
	Expect(valid).To(BeTrue())

	// download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml
	httpClient := http.Client{}
@@ -63,7 +67,9 @@ parameters:
	Expect(config).ToNot(BeNil())
	// two configs in config.yaml
	Expect(config.Name).To(Equal("hermes-2-pro-mistral"))
	Expect(config.Validate()).To(BeTrue())
	valid, err = config.Validate()
	Expect(err).To(BeNil())
	Expect(valid).To(BeTrue())
})
})
It("Properly handles backend usecase matching", func() {

@@ -3,7 +3,9 @@
package gallery

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"
@@ -68,7 +70,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
}

// InstallBackendFromGallery installs a backend from the gallery.
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
	if !force {
		// check if we already have the backend installed
		backends, err := ListSystemBackends(systemState)
@@ -108,7 +110,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
		log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")

		// Then, let's install the best backend
		if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
		if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
			return err
		}

@@ -133,10 +135,10 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
		return nil
	}

	return InstallBackend(systemState, modelLoader, backend, downloadStatus)
	return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
}

func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
	// Create base path if it doesn't exist
	err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
	if err != nil {
@@ -162,23 +164,40 @@ func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoa
			return fmt.Errorf("failed copying: %w", err)
		}
	} else {
		uri := downloader.URI(config.URI)
		if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
		log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloading backend")
		if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
			success := false
			// Try to download from mirrors
			for _, mirror := range config.Mirrors {
				if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
				// Check for cancellation before trying next mirror
				select {
				case <-ctx.Done():
					return ctx.Err()
				default:
				}
				if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
					success = true
					log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloaded backend")
					break
				}
			}

			if !success {
				log.Error().Str("uri", config.URI).Str("backendPath", backendPath).Err(err).Msg("Failed to download backend")
				return fmt.Errorf("failed to download backend %q: %v", config.URI, err)
			}
		} else {
			log.Debug().Str("uri", config.URI).Str("backendPath", backendPath).Msg("Downloaded backend")
		}
	}

	// sanity check - check if runfile is present
	runFile := filepath.Join(backendPath, runFile)
	if _, err := os.Stat(runFile); os.IsNotExist(err) {
		log.Error().Str("runFile", runFile).Msg("Run file not found")
		return fmt.Errorf("not a valid backend: run file not found %q", runFile)
	}

	// Create metadata for the backend
	metadata := &BackendMetadata{
		Name: name,
@@ -310,8 +329,10 @@ func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error)
			}
		}
	}
	} else {
	} else if !errors.Is(err, os.ErrNotExist) {
		log.Warn().Err(err).Msg("Failed to read system backends, proceeding with user-managed backends")
	} else if errors.Is(err, os.ErrNotExist) {
		log.Debug().Msg("No system backends found")
	}

	// User-managed backends and alias collection
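
The mirror loop above pairs DownloadFileWithContext with an explicit non-blocking select, so a cancelled context stops the retry chain before a new mirror attempt starts. A condensed sketch of the same idiom:

    for _, mirror := range config.Mirrors {
        select {
        case <-ctx.Done():
            return ctx.Err() // caller cancelled: do not try further mirrors
        default:
        }
        if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
            break // first mirror that succeeds wins
        }
    }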

@@ -1,6 +1,7 @@
package gallery

import (
	"context"
	"encoding/json"
	"os"
	"path/filepath"
@@ -55,7 +56,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
	)
	must(err)
	sysDefault.GPUVendor = "" // force default selection
	backs, err := ListSystemBackends(sysDefault)
	backs, err := ListSystemBackends(sysDefault)
	must(err)
	aliasBack, ok := backs.Get("llama-cpp")
	Expect(ok).To(BeTrue())
@@ -77,7 +78,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
	must(err)
	sysNvidia.GPUVendor = "nvidia"
	sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
	backs, err = ListSystemBackends(sysNvidia)
	backs, err = ListSystemBackends(sysNvidia)
	must(err)
	aliasBack, ok = backs.Get("llama-cpp")
	Expect(ok).To(BeTrue())
@@ -116,13 +117,13 @@ var _ = Describe("Gallery Backends", func() {

	Describe("InstallBackendFromGallery", func() {
		It("should return error when backend is not found", func() {
			err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
			err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
		})

		It("should install backend from gallery", func() {
			err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
			err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
			Expect(err).ToNot(HaveOccurred())
			Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
		})
@@ -298,7 +299,7 @@ var _ = Describe("Gallery Backends", func() {
		VRAM: 1000000000000,
		Backend: system.Backend{BackendsPath: tempDir},
	}
	err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
	err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
	Expect(err).NotTo(HaveOccurred())

	metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -378,7 +379,7 @@ var _ = Describe("Gallery Backends", func() {
		VRAM: 1000000000000,
		Backend: system.Backend{BackendsPath: tempDir},
	}
	err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
	err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
	Expect(err).NotTo(HaveOccurred())

	metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -462,7 +463,7 @@ var _ = Describe("Gallery Backends", func() {
		VRAM: 1000000000000,
		Backend: system.Backend{BackendsPath: tempDir},
	}
	err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
	err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
	Expect(err).NotTo(HaveOccurred())

	metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -561,9 +562,9 @@ var _ = Describe("Gallery Backends", func() {
		system.WithBackendPath(newPath),
	)
	Expect(err).NotTo(HaveOccurred())
	err = InstallBackend(systemState, ml, &backend, nil)
	Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
	err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
	Expect(newPath).To(BeADirectory())
	Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
})

It("should overwrite existing backend", func() {
@@ -593,7 +594,7 @@ var _ = Describe("Gallery Backends", func() {
		system.WithBackendPath(tempDir),
	)
	Expect(err).NotTo(HaveOccurred())
	err = InstallBackend(systemState, ml, &backend, nil)
	err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
	Expect(err).ToNot(HaveOccurred())
	Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
	dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
@@ -626,7 +627,7 @@ var _ = Describe("Gallery Backends", func() {

	Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())

	err = InstallBackend(systemState, ml, &backend, nil)
	err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
	Expect(err).ToNot(HaveOccurred())
	Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
})
@@ -647,7 +648,7 @@ var _ = Describe("Gallery Backends", func() {
		system.WithBackendPath(tempDir),
	)
	Expect(err).NotTo(HaveOccurred())
	err = InstallBackend(systemState, ml, &backend, nil)
	err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
	Expect(err).ToNot(HaveOccurred())
	Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())

@@ -1,15 +1,18 @@
|
||||
package gallery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lithammer/fuzzysearch/fuzzy"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/mudler/LocalAI/pkg/xsync"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
@@ -18,7 +21,20 @@ import (
 func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
 	var config T
 	uri := downloader.URI(url)
-	err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
+	err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
 		return yaml.Unmarshal(d, &config)
 	})
 	if err != nil {
 		log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
 		return config, err
 	}
 	return config, nil
 }
+
+func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
+	var config T
+	uri := downloader.URI(url)
+	err := uri.ReadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
+		return yaml.Unmarshal(d, &config)
+	})
+	if err != nil {
@@ -61,12 +77,15 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] {
 	term = strings.ToLower(term)
 	for _, m := range gm {
 		if fuzzy.Match(term, strings.ToLower(m.GetName())) ||
 			fuzzy.Match(term, strings.ToLower(m.GetDescription())) ||
 			fuzzy.Match(term, strings.ToLower(m.GetGallery().Name)) ||
 			strings.Contains(strings.ToLower(m.GetName()), term) ||
 			strings.Contains(strings.ToLower(m.GetDescription()), term) ||
 			strings.Contains(strings.ToLower(m.GetGallery().Name), term) ||
 			strings.Contains(strings.ToLower(strings.Join(m.GetTags(), ",")), term) {
 			filteredModels = append(filteredModels, m)
 		}
 	}

 	return filteredModels
 }

@@ -124,7 +143,7 @@ func AvailableGalleryModels(galleries []config.Gallery, systemState *system.Syst

 	// Get models from galleries
 	for _, gallery := range galleries {
-		galleryModels, err := getGalleryElements[*GalleryModel](gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool {
+		galleryModels, err := getGalleryElements(gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool {
 			if _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil {
 				return true
 			}
@@ -165,7 +184,7 @@ func AvailableBackends(galleries []config.Gallery, systemState *system.SystemSta
 func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
 	var refFile string
 	uri := downloader.URI(url)
-	err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
+	err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
 		refFile = string(d)
 		if len(refFile) == 0 {
 			return fmt.Errorf("invalid reference file at url %s: %s", url, d)
@@ -177,6 +196,17 @@ func findGalleryURLFromReferenceURL(url string, basePath string) (string, error)
 	return refFile, err
 }

+type galleryCacheEntry struct {
+	yamlEntry   []byte
+	lastUpdated time.Time
+}
+
+func (entry galleryCacheEntry) hasExpired() bool {
+	return entry.lastUpdated.Before(time.Now().Add(-1 * time.Hour))
+}
+
+var galleryCache = xsync.NewSyncedMap[string, galleryCacheEntry]()
+
 func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string, isInstalledCallback func(T) bool) ([]T, error) {
 	var models []T = []T{}

@@ -187,16 +217,37 @@ func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath strin
 			return models, err
 		}
 	}

+	cacheKey := fmt.Sprintf("%s-%s", gallery.Name, gallery.URL)
+	if galleryCache.Exists(cacheKey) {
+		entry := galleryCache.Get(cacheKey)
+		// refresh if last updated is more than 1 hour ago
+		if !entry.hasExpired() {
+			err := yaml.Unmarshal(entry.yamlEntry, &models)
+			if err != nil {
+				return models, err
+			}
+		} else {
+			galleryCache.Delete(cacheKey)
+		}
+	}
+
 	uri := downloader.URI(gallery.URL)

-	err := uri.DownloadWithCallback(basePath, func(url string, d []byte) error {
-		return yaml.Unmarshal(d, &models)
-	})
-	if err != nil {
-		if yamlErr, ok := err.(*yaml.TypeError); ok {
-			log.Debug().Msgf("YAML errors: %s\n\nwreckage of models: %+v", strings.Join(yamlErr.Errors, "\n"), models)
-			if len(models) == 0 {
-				return models, err
-			}
-		}
+	err := uri.ReadWithCallback(basePath, func(url string, d []byte) error {
+		galleryCache.Set(cacheKey, galleryCacheEntry{
+			yamlEntry:   d,
+			lastUpdated: time.Now(),
+		})
+		return yaml.Unmarshal(d, &models)
+	})
+	if err != nil {
+		if yamlErr, ok := err.(*yaml.TypeError); ok {
+			log.Debug().Msgf("YAML errors: %s\n\nwreckage of models: %+v", strings.Join(yamlErr.Errors, "\n"), models)
+		}
+		return models, fmt.Errorf("failed to read gallery elements: %w", err)
+	}

 	// Add gallery to models
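Note: the galleryCache introduced above memoizes each gallery's raw YAML for one hour, keyed by "name-URL", so repeated listings do not re-fetch the index. Below is a minimal sketch of the same TTL pattern in isolation, assuming nothing beyond the standard library; the repo itself uses the xsync.NewSyncedMap helper, and the names here are illustrative, not LocalAI API.

package main

import (
	"sync"
	"time"
)

type entry struct {
	data        []byte
	lastUpdated time.Time
}

// ttlCache mirrors galleryCacheEntry.hasExpired: entries older than ttl
// are treated as absent and evicted on the next read.
type ttlCache struct {
	mu  sync.Mutex
	ttl time.Duration
	m   map[string]entry
}

func newTTLCache(ttl time.Duration) *ttlCache {
	return &ttlCache{ttl: ttl, m: map[string]entry{}}
}

func (c *ttlCache) Get(key string) ([]byte, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	e, ok := c.m[key]
	if !ok || e.lastUpdated.Before(time.Now().Add(-c.ttl)) {
		delete(c.m, key) // expired entries are dropped on read, as in the hunk above
		return nil, false
	}
	return e.data, true
}

func (c *ttlCache) Set(key string, data []byte) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.m[key] = entry{data: data, lastUpdated: time.Now()}
}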
core/gallery/importers/diffuser.go (121 lines, Normal file)
@@ -0,0 +1,121 @@
package importers

import (
	"encoding/json"
	"path/filepath"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"gopkg.in/yaml.v3"
)

var _ Importer = &DiffuserImporter{}

type DiffuserImporter struct{}

func (i *DiffuserImporter) Match(details Details) bool {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return false
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return false
	}

	b, ok := preferencesMap["backend"].(string)
	if ok && b == "diffusers" {
		return true
	}

	if details.HuggingFace != nil {
		for _, file := range details.HuggingFace.Files {
			if strings.Contains(file.Path, "model_index.json") ||
				strings.Contains(file.Path, "scheduler/scheduler_config.json") {
				return true
			}
		}
	}

	return false
}

func (i *DiffuserImporter) Import(details Details) (gallery.ModelConfig, error) {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return gallery.ModelConfig{}, err
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	name, ok := preferencesMap["name"].(string)
	if !ok {
		name = filepath.Base(details.URI)
	}

	description, ok := preferencesMap["description"].(string)
	if !ok {
		description = "Imported from " + details.URI
	}

	backend := "diffusers"
	b, ok := preferencesMap["backend"].(string)
	if ok {
		backend = b
	}

	pipelineType, ok := preferencesMap["pipeline_type"].(string)
	if !ok {
		pipelineType = "StableDiffusionPipeline"
	}

	schedulerType, ok := preferencesMap["scheduler_type"].(string)
	if !ok {
		schedulerType = ""
	}

	enableParameters, ok := preferencesMap["enable_parameters"].(string)
	if !ok {
		enableParameters = "negative_prompt,num_inference_steps"
	}

	cuda := false
	if cudaVal, ok := preferencesMap["cuda"].(bool); ok {
		cuda = cudaVal
	}

	modelConfig := config.ModelConfig{
		Name:                name,
		Description:         description,
		KnownUsecaseStrings: []string{"image"},
		Backend:             backend,
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: details.URI,
			},
		},
		Diffusers: config.Diffusers{
			PipelineType:     pipelineType,
			SchedulerType:    schedulerType,
			EnableParameters: enableParameters,
			CUDA:             cuda,
		},
	}

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	return gallery.ModelConfig{
		Name:        name,
		Description: description,
		ConfigFile:  string(data),
	}, nil
}
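For orientation, a sketch of how an importer such as DiffuserImporter is driven. The Details struct and preference keys come from importers.go later in this diff; the repository URI is a placeholder, and the actual call sites in LocalAI may differ.

package main

import (
	"encoding/json"
	"fmt"

	"github.com/mudler/LocalAI/core/gallery/importers"
)

func main() {
	// Unknown preference keys fall back to defaults
	// (StableDiffusionPipeline, cuda=false, and so on).
	details := importers.Details{
		URI:         "https://huggingface.co/test/my-diffuser-model",
		Preferences: json.RawMessage(`{"pipeline_type": "StableDiffusion3Pipeline", "cuda": true}`),
	}

	imp := &importers.DiffuserImporter{}
	if imp.Match(details) {
		cfg, err := imp.Import(details)
		if err != nil {
			panic(err)
		}
		fmt.Println(cfg.ConfigFile) // rendered YAML model config
	}
}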
core/gallery/importers/diffuser_test.go (246 lines, Normal file)
@@ -0,0 +1,246 @@
package importers_test

import (
	"encoding/json"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/mudler/LocalAI/core/gallery/importers"
	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("DiffuserImporter", func() {
	var importer *DiffuserImporter

	BeforeEach(func() {
		importer = &DiffuserImporter{}
	})

	Context("Match", func() {
		It("should match when backend preference is diffusers", func() {
			preferences := json.RawMessage(`{"backend": "diffusers"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain model_index.json", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "model_index.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain scheduler config", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "scheduler/scheduler_config.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should not match when URI has no diffuser files and no backend preference", func() {
			details := Details{
				URI: "https://example.com/model.bin",
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should not match when backend preference is different", func() {
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should return false when JSON preferences are invalid", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})
	})

	Context("Import", func() {
		It("should import model config with default name and description", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-diffuser-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("my-diffuser-model"))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-diffuser-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-diffuser-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: negative_prompt,num_inference_steps"))
		})

		It("should import model config with custom name and description from preferences", func() {
			preferences := json.RawMessage(`{"name": "custom-diffuser", "description": "Custom diffuser model"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-diffuser"))
			Expect(modelConfig.Description).To(Equal("Custom diffuser model"))
		})

		It("should use custom pipeline_type from preferences", func() {
			preferences := json.RawMessage(`{"pipeline_type": "StableDiffusion3Pipeline"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusion3Pipeline"))
		})

		It("should use default pipeline_type when not specified", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline"))
		})

		It("should use custom scheduler_type from preferences", func() {
			preferences := json.RawMessage(`{"scheduler_type": "k_dpmpp_2m"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("scheduler_type: k_dpmpp_2m"))
		})

		It("should use cuda setting from preferences", func() {
			preferences := json.RawMessage(`{"cuda": true}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("cuda: true"))
		})

		It("should use custom enable_parameters from preferences", func() {
			preferences := json.RawMessage(`{"enable_parameters": "num_inference_steps,guidance_scale"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: num_inference_steps,guidance_scale"))
		})

		It("should use custom backend from preferences", func() {
			preferences := json.RawMessage(`{"backend": "diffusers"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers"))
		})

		It("should handle invalid JSON preferences", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			_, err := importer.Import(details)
			Expect(err).To(HaveOccurred())
		})

		It("should extract filename correctly from URI with path", func() {
			details := importers.Details{
				URI: "https://huggingface.co/test/path/to/model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("model"))
		})

		It("should include known_usecases as image in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("- image"))
		})

		It("should include diffusers configuration in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("diffusers:"))
		})
	})
})
core/gallery/importers/importers.go (121 lines, Normal file)
@@ -0,0 +1,121 @@
package importers

import (
	"encoding/json"
	"fmt"
	"os"
	"strings"

	"github.com/rs/zerolog/log"
	"gopkg.in/yaml.v3"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/pkg/downloader"
	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)

var defaultImporters = []Importer{
	&LlamaCPPImporter{},
	&MLXImporter{},
	&VLLMImporter{},
	&TransformersImporter{},
	&DiffuserImporter{},
}

type Details struct {
	HuggingFace *hfapi.ModelDetails
	URI         string
	Preferences json.RawMessage
}

type Importer interface {
	Match(details Details) bool
	Import(details Details) (gallery.ModelConfig, error)
}

func hasYAMLExtension(uri string) bool {
	return strings.HasSuffix(uri, ".yaml") || strings.HasSuffix(uri, ".yml")
}

func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) {
	var err error
	var modelConfig gallery.ModelConfig

	hf := hfapi.NewClient()

	hfrepoID := strings.ReplaceAll(uri, "huggingface://", "")
	hfrepoID = strings.ReplaceAll(hfrepoID, "hf://", "")
	hfrepoID = strings.ReplaceAll(hfrepoID, "https://huggingface.co/", "")

	hfDetails, err := hf.GetModelDetails(hfrepoID)
	if err != nil {
		// maybe not a HF repository
		// TODO: maybe we can check if the URI is a valid HF repository
		log.Debug().Str("uri", uri).Str("hfrepoID", hfrepoID).Msg("Failed to get model details, maybe not a HF repository")
	} else {
		log.Debug().Str("uri", uri).Msg("Got model details")
		log.Debug().Any("details", hfDetails).Msg("Model details")
	}

	// handle local config files ("/my-model.yaml" or "file://my-model.yaml")
	localURI := uri
	if strings.HasPrefix(uri, downloader.LocalPrefix) {
		localURI = strings.TrimPrefix(uri, downloader.LocalPrefix)
	}

	// if a file exists, or it's a URL that ends with .yaml or .yml, read the config file directly
	if _, e := os.Stat(localURI); hasYAMLExtension(localURI) && (e == nil || downloader.URI(localURI).LooksLikeURL()) {
		var modelYAML []byte
		if downloader.URI(localURI).LooksLikeURL() {
			err := downloader.URI(localURI).ReadWithCallback(localURI, func(url string, i []byte) error {
				modelYAML = i
				return nil
			})
			if err != nil {
				log.Error().Err(err).Str("filepath", localURI).Msg("error reading model definition")
				return gallery.ModelConfig{}, err
			}
		} else {
			modelYAML, err = os.ReadFile(localURI)
			if err != nil {
				log.Error().Err(err).Str("filepath", localURI).Msg("error reading model definition")
				return gallery.ModelConfig{}, err
			}
		}

		var modelConfig config.ModelConfig
		if e := yaml.Unmarshal(modelYAML, &modelConfig); e != nil {
			return gallery.ModelConfig{}, e
		}

		configFile, err := yaml.Marshal(modelConfig)
		return gallery.ModelConfig{
			Description: modelConfig.Description,
			Name:        modelConfig.Name,
			ConfigFile:  string(configFile),
		}, err
	}

	details := Details{
		HuggingFace: hfDetails,
		URI:         uri,
		Preferences: preferences,
	}

	importerMatched := false
	for _, importer := range defaultImporters {
		if importer.Match(details) {
			importerMatched = true
			modelConfig, err = importer.Import(details)
			if err != nil {
				continue
			}
			break
		}
	}
	if !importerMatched {
		return gallery.ModelConfig{}, fmt.Errorf("no importer matched for %s", uri)
	}
	return modelConfig, nil
}
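A usage sketch for DiscoverModelConfig, mirroring the tests that follow; the Qwen repository and preference keys are taken from those tests, and network access to Hugging Face is assumed:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/mudler/LocalAI/core/gallery/importers"
)

func main() {
	// quantizations / mmproj_quantizations are comma-separated substrings
	// matched case-insensitively against repository file names.
	prefs := json.RawMessage(`{"quantizations": "Q8_0", "mmproj_quantizations": "f16"}`)

	cfg, err := importers.DiscoverModelConfig("hf://Qwen/Qwen3-VL-2B-Instruct-GGUF", prefs)
	if err != nil {
		panic(err)
	}
	fmt.Println(cfg.Name)       // Qwen3-VL-2B-Instruct-GGUF
	fmt.Println(len(cfg.Files)) // model weights plus the mmproj file
}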
core/gallery/importers/importers_suite_test.go (13 lines, Normal file)
@@ -0,0 +1,13 @@
package importers_test

import (
	"testing"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

func TestImporters(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Importers test suite")
}
core/gallery/importers/importers_test.go (352 lines, Normal file)
@@ -0,0 +1,352 @@
package importers_test

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("DiscoverModelConfig", func() {

	Context("With only a repository URI", func() {
		It("should discover and import using LlamaCPPImporter", func() {
			uri := "https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
			Expect(modelConfig.Name).To(Equal("LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/resolve/main/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].SHA256).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"), fmt.Sprintf("Model config: %+v", modelConfig))
		})

		It("should discover and import using LlamaCPPImporter", func() {
			uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
			Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].Filename).To(Equal("mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
		})

		It("should discover and import using LlamaCPPImporter", func() {
			uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
			preferences := json.RawMessage(`{ "quantizations": "Q8_0", "mmproj_quantizations": "f16" }`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
			Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].Filename).To(Equal("mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
		})
	})

	Context("with .gguf URI", func() {
		It("should discover and import using LlamaCPPImporter", func() {
			uri := "https://example.com/my-model.gguf"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("my-model.gguf"))
			Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
		})

		It("should use custom preferences when provided", func() {
			uri := "https://example.com/my-model.gguf"
			preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-name"))
			Expect(modelConfig.Description).To(Equal("Custom description"))
		})
	})

	Context("with mlx-community URI", func() {
		It("should discover and import using MLXImporter", func() {
			uri := "https://huggingface.co/mlx-community/test-model"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
		})

		It("should use custom preferences when provided", func() {
			uri := "https://huggingface.co/mlx-community/test-model"
			preferences := json.RawMessage(`{"name": "custom-mlx", "description": "Custom MLX description"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-mlx"))
			Expect(modelConfig.Description).To(Equal("Custom MLX description"))
		})
	})

	Context("with backend preference", func() {
		It("should use llama-cpp backend when specified", func() {
			uri := "https://example.com/model"
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
		})

		It("should use mlx backend when specified", func() {
			uri := "https://example.com/model"
			preferences := json.RawMessage(`{"backend": "mlx"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
		})

		It("should use mlx-vlm backend when specified", func() {
			uri := "https://example.com/model"
			preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
		})
	})

	Context("with HuggingFace URI formats", func() {
		It("should handle huggingface:// prefix", func() {
			uri := "huggingface://mlx-community/test-model"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
		})

		It("should handle hf:// prefix", func() {
			uri := "hf://mlx-community/test-model"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
		})

		It("should handle https://huggingface.co/ prefix", func() {
			uri := "https://huggingface.co/mlx-community/test-model"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
		})
	})

	Context("with invalid or non-matching URI", func() {
		It("should return error when no importer matches", func() {
			uri := "https://example.com/unknown-model.bin"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			// When no importer matches, the function returns empty config and error
			// The exact behavior depends on implementation, but typically an error is returned
			Expect(modelConfig.Name).To(BeEmpty())
			Expect(err).To(HaveOccurred())
		})
	})

	Context("with invalid JSON preferences", func() {
		It("should return error when JSON is invalid even if URI matches", func() {
			uri := "https://example.com/model.gguf"
			preferences := json.RawMessage(`invalid json`)

			// Even though Match() returns true for .gguf extension,
			// Import() will fail when trying to unmarshal invalid JSON preferences
			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).To(HaveOccurred())
			Expect(modelConfig.Name).To(BeEmpty())
		})
	})

	Context("with local YAML config files", func() {
		var tempDir string

		BeforeEach(func() {
			var err error
			tempDir, err = os.MkdirTemp("", "importers-test-*")
			Expect(err).ToNot(HaveOccurred())
		})

		AfterEach(func() {
			os.RemoveAll(tempDir)
		})

		It("should read local YAML file with file:// prefix", func() {
			yamlContent := `name: test-model
backend: llama-cpp
description: Test model from local YAML
parameters:
  model: /path/to/model.gguf
  temperature: 0.7
`
			yamlFile := filepath.Join(tempDir, "test-model.yaml")
			err := os.WriteFile(yamlFile, []byte(yamlContent), 0644)
			Expect(err).ToNot(HaveOccurred())

			uri := "file://" + yamlFile
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
			Expect(modelConfig.Description).To(Equal("Test model from local YAML"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("name: test-model"))
		})

		It("should read local YAML file without file:// prefix (direct path)", func() {
			yamlContent := `name: direct-path-model
backend: mlx
description: Test model from direct path
parameters:
  model: /path/to/model.safetensors
`
			yamlFile := filepath.Join(tempDir, "direct-model.yaml")
			err := os.WriteFile(yamlFile, []byte(yamlContent), 0644)
			Expect(err).ToNot(HaveOccurred())

			uri := yamlFile
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("direct-path-model"))
			Expect(modelConfig.Description).To(Equal("Test model from direct path"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
		})

		It("should read local YAML file with .yml extension", func() {
			yamlContent := `name: yml-extension-model
backend: transformers
description: Test model with .yml extension
parameters:
  model: /path/to/model
`
			yamlFile := filepath.Join(tempDir, "test-model.yml")
			err := os.WriteFile(yamlFile, []byte(yamlContent), 0644)
			Expect(err).ToNot(HaveOccurred())

			uri := "file://" + yamlFile
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("yml-extension-model"))
			Expect(modelConfig.Description).To(Equal("Test model with .yml extension"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
		})

		It("should ignore preferences when reading YAML files directly", func() {
			yamlContent := `name: yaml-model
backend: llama-cpp
description: Original description
parameters:
  model: /path/to/model.gguf
`
			yamlFile := filepath.Join(tempDir, "prefs-test.yaml")
			err := os.WriteFile(yamlFile, []byte(yamlContent), 0644)
			Expect(err).ToNot(HaveOccurred())

			uri := "file://" + yamlFile
			// Preferences should be ignored when reading YAML directly
			preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description", "backend": "mlx"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			// Should use values from YAML file, not preferences
			Expect(modelConfig.Name).To(Equal("yaml-model"))
			Expect(modelConfig.Description).To(Equal("Original description"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
		})

		It("should return error when local YAML file doesn't exist", func() {
			nonExistentFile := filepath.Join(tempDir, "nonexistent.yaml")
			uri := "file://" + nonExistentFile
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).To(HaveOccurred())
			Expect(modelConfig.Name).To(BeEmpty())
		})

		It("should return error when YAML file is invalid/malformed", func() {
			invalidYaml := `name: invalid-model
backend: llama-cpp
invalid: yaml: content: [unclosed bracket
`
			yamlFile := filepath.Join(tempDir, "invalid.yaml")
			err := os.WriteFile(yamlFile, []byte(invalidYaml), 0644)
			Expect(err).ToNot(HaveOccurred())

			uri := "file://" + yamlFile
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).To(HaveOccurred())
			Expect(modelConfig.Name).To(BeEmpty())
		})
	})
})
core/gallery/importers/llama-cpp.go (260 lines, Normal file)
@@ -0,0 +1,260 @@
package importers

import (
	"encoding/json"
	"path/filepath"
	"slices"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/downloader"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/rs/zerolog/log"
	"go.yaml.in/yaml/v2"
)

var _ Importer = &LlamaCPPImporter{}

type LlamaCPPImporter struct{}

func (i *LlamaCPPImporter) Match(details Details) bool {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		log.Error().Err(err).Msg("failed to marshal preferences")
		return false
	}

	preferencesMap := make(map[string]any)

	if len(preferences) > 0 {
		err = json.Unmarshal(preferences, &preferencesMap)
		if err != nil {
			log.Error().Err(err).Msg("failed to unmarshal preferences")
			return false
		}
	}

	uri := downloader.URI(details.URI)

	if preferencesMap["backend"] == "llama-cpp" {
		return true
	}

	if strings.HasSuffix(details.URI, ".gguf") {
		return true
	}

	if uri.LooksLikeOCI() {
		return true
	}

	if details.HuggingFace != nil {
		for _, file := range details.HuggingFace.Files {
			if strings.HasSuffix(file.Path, ".gguf") {
				return true
			}
		}
	}

	return false
}

func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) {

	log.Debug().Str("uri", details.URI).Msg("llama.cpp importer matched")

	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return gallery.ModelConfig{}, err
	}
	preferencesMap := make(map[string]any)
	if len(preferences) > 0 {
		err = json.Unmarshal(preferences, &preferencesMap)
		if err != nil {
			return gallery.ModelConfig{}, err
		}
	}

	name, ok := preferencesMap["name"].(string)
	if !ok {
		name = filepath.Base(details.URI)
	}

	description, ok := preferencesMap["description"].(string)
	if !ok {
		description = "Imported from " + details.URI
	}

	preferedQuantizations, _ := preferencesMap["quantizations"].(string)
	quants := []string{"q4_k_m"}
	if preferedQuantizations != "" {
		quants = strings.Split(preferedQuantizations, ",")
	}

	mmprojQuants, _ := preferencesMap["mmproj_quantizations"].(string)
	mmprojQuantsList := []string{"fp16"}
	if mmprojQuants != "" {
		mmprojQuantsList = strings.Split(mmprojQuants, ",")
	}

	embeddings, _ := preferencesMap["embeddings"].(string)

	modelConfig := config.ModelConfig{
		Name:                name,
		Description:         description,
		KnownUsecaseStrings: []string{"chat"},
		Options:             []string{"use_jinja:true"},
		Backend:             "llama-cpp",
		TemplateConfig: config.TemplateConfig{
			UseTokenizerTemplate: true,
		},
		FunctionsConfig: functions.FunctionsConfig{
			GrammarConfig: functions.GrammarConfig{
				NoGrammar: true,
			},
		},
	}

	if embeddings != "" && strings.ToLower(embeddings) == "true" || strings.ToLower(embeddings) == "yes" {
		trueV := true
		modelConfig.Embeddings = &trueV
	}

	cfg := gallery.ModelConfig{
		Name:        name,
		Description: description,
	}

	uri := downloader.URI(details.URI)

	switch {
	case uri.LooksLikeOCI():
		ociName := strings.TrimPrefix(string(uri), downloader.OCIPrefix)
		ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix)
		ociName = strings.ReplaceAll(ociName, "/", "__")
		ociName = strings.ReplaceAll(ociName, ":", "__")
		cfg.Files = append(cfg.Files, gallery.File{
			URI:      details.URI,
			Filename: ociName,
		})
		modelConfig.PredictionOptions = schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: ociName,
			},
		}
	case uri.LooksLikeURL() && strings.HasSuffix(details.URI, ".gguf"):
		// Extract filename from URL
		fileName, e := uri.FilenameFromUrl()
		if e != nil {
			return gallery.ModelConfig{}, e
		}

		cfg.Files = append(cfg.Files, gallery.File{
			URI:      details.URI,
			Filename: fileName,
		})
		modelConfig.PredictionOptions = schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: fileName,
			},
		}
	case strings.HasSuffix(details.URI, ".gguf"):
		cfg.Files = append(cfg.Files, gallery.File{
			URI:      details.URI,
			Filename: filepath.Base(details.URI),
		})
		modelConfig.PredictionOptions = schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: filepath.Base(details.URI),
			},
		}
	case details.HuggingFace != nil:
		// We want to:
		// Get first the chosen quants that match filenames
		// OR the first mmproj/gguf file found
		var lastMMProjFile *gallery.File
		var lastGGUFFile *gallery.File
		foundPreferedQuant := false
		foundPreferedMMprojQuant := false

		for _, file := range details.HuggingFace.Files {
			// Get the mmproj preferred quants
			if strings.Contains(strings.ToLower(file.Path), "mmproj") {
				lastMMProjFile = &gallery.File{
					URI:      file.URL,
					Filename: filepath.Join("mmproj", filepath.Base(file.Path)),
					SHA256:   file.SHA256,
				}
				if slices.ContainsFunc(mmprojQuantsList, func(quant string) bool {
					return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
				}) {
					cfg.Files = append(cfg.Files, *lastMMProjFile)
					foundPreferedMMprojQuant = true
				}
			} else if strings.HasSuffix(strings.ToLower(file.Path), "gguf") {
				lastGGUFFile = &gallery.File{
					URI:      file.URL,
					Filename: filepath.Base(file.Path),
					SHA256:   file.SHA256,
				}
				// get the files of the preferred quants
				if slices.ContainsFunc(quants, func(quant string) bool {
					return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
				}) {
					foundPreferedQuant = true
					cfg.Files = append(cfg.Files, *lastGGUFFile)
				}
			}
		}

		// Make sure to add at least one file if not already present (which is the latest one)
		if lastMMProjFile != nil && !foundPreferedMMprojQuant {
			if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
				return f.Filename == lastMMProjFile.Filename
			}) {
				cfg.Files = append(cfg.Files, *lastMMProjFile)
			}
		}

		if lastGGUFFile != nil && !foundPreferedQuant {
			if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
				return f.Filename == lastGGUFFile.Filename
			}) {
				cfg.Files = append(cfg.Files, *lastGGUFFile)
			}
		}

		// Find first mmproj file and configure it in the config file
		for _, file := range cfg.Files {
			if !strings.Contains(strings.ToLower(file.Filename), "mmproj") {
				continue
			}
			modelConfig.MMProj = file.Filename
			break
		}

		// Find first non-mmproj file and configure it in the config file
		for _, file := range cfg.Files {
			if strings.Contains(strings.ToLower(file.Filename), "mmproj") {
				continue
			}
			modelConfig.PredictionOptions = schema.PredictionOptions{
				BasicModelRequest: schema.BasicModelRequest{
					Model: file.Filename,
				},
			}
			break
		}
	}

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	cfg.ConfigFile = string(data)

	return cfg, nil
}
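The quantization choice above is substring-based: a file is kept when its lowercased path contains any preferred quant, and if nothing matches, the last .gguf (or mmproj) file seen is used as a fallback. A condensed, self-contained approximation of that rule follows; the real importer collects every matching file rather than returning only the first one.

package main

import (
	"fmt"
	"slices"
	"strings"
)

// pickQuant returns the first file containing a preferred quant,
// falling back to the last candidate, like the HuggingFace branch above.
func pickQuant(files []string, preferred []string) string {
	last := ""
	for _, f := range files {
		if !strings.HasSuffix(strings.ToLower(f), "gguf") {
			continue
		}
		last = f
		if slices.ContainsFunc(preferred, func(q string) bool {
			return strings.Contains(strings.ToLower(f), strings.ToLower(q))
		}) {
			return f
		}
	}
	return last
}

func main() {
	files := []string{"model-Q4_K_M.gguf", "model-Q8_0.gguf"}
	fmt.Println(pickQuant(files, []string{"q8_0"})) // model-Q8_0.gguf
}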
core/gallery/importers/llama-cpp_test.go (132 lines, Normal file)
@@ -0,0 +1,132 @@
package importers_test

import (
	"encoding/json"
	"fmt"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("LlamaCPPImporter", func() {
	var importer *LlamaCPPImporter

	BeforeEach(func() {
		importer = &LlamaCPPImporter{}
	})

	Context("Match", func() {
		It("should match when URI ends with .gguf", func() {
			details := Details{
				URI: "https://example.com/model.gguf",
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when backend preference is llama-cpp", func() {
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should not match when URI does not end with .gguf and no backend preference", func() {
			details := Details{
				URI: "https://example.com/model.bin",
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should not match when backend preference is different", func() {
			preferences := json.RawMessage(`{"backend": "mlx"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should return false when JSON preferences are invalid", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://example.com/model.gguf",
				Preferences: preferences,
			}

			// Invalid JSON causes Match to return false early
			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})
	})

	Context("Import", func() {
		It("should import model config with default name and description", func() {
			details := Details{
				URI: "https://example.com/my-model.gguf",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("my-model.gguf"))
			Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
			Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
		})

		It("should import model config with custom name and description from preferences", func() {
			preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
			details := Details{
				URI:         "https://example.com/my-model.gguf",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-model"))
			Expect(modelConfig.Description).To(Equal("Custom description"))
			Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
		})

		It("should handle invalid JSON preferences", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://example.com/my-model.gguf",
				Preferences: preferences,
			}

			_, err := importer.Import(details)
			Expect(err).To(HaveOccurred())
		})

		It("should extract filename correctly from URI with path", func() {
			details := importers.Details{
				URI: "https://example.com/path/to/model.gguf",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/path/to/model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
		})
	})
})
core/gallery/importers/mlx.go (94 lines, Normal file)
@@ -0,0 +1,94 @@
package importers

import (
	"encoding/json"
	"path/filepath"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"go.yaml.in/yaml/v2"
)

var _ Importer = &MLXImporter{}

type MLXImporter struct{}

func (i *MLXImporter) Match(details Details) bool {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return false
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return false
	}

	b, ok := preferencesMap["backend"].(string)
	if ok && b == "mlx" || b == "mlx-vlm" {
		return true
	}

	// All https://huggingface.co/mlx-community/*
	if strings.Contains(details.URI, "mlx-community/") {
		return true
	}

	return false
}

func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return gallery.ModelConfig{}, err
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	name, ok := preferencesMap["name"].(string)
	if !ok {
		name = filepath.Base(details.URI)
	}

	description, ok := preferencesMap["description"].(string)
	if !ok {
		description = "Imported from " + details.URI
	}

	backend := "mlx"
	b, ok := preferencesMap["backend"].(string)
	if ok {
		backend = b
	}

	modelConfig := config.ModelConfig{
		Name:                name,
		Description:         description,
		KnownUsecaseStrings: []string{"chat"},
		Backend:             backend,
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: details.URI,
			},
		},
		TemplateConfig: config.TemplateConfig{
			UseTokenizerTemplate: true,
		},
	}

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	return gallery.ModelConfig{
		Name:        name,
		Description: description,
		ConfigFile:  string(data),
	}, nil
}
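One subtlety in MLXImporter.Match above: Go's && binds tighter than ||, so the backend check parses as (ok && b == "mlx") || (b == "mlx-vlm"). When the type assertion fails, b is the empty string and both arms are false, so the behavior still comes out correct; a tiny demonstration:

package main

import "fmt"

func main() {
	b, ok := "", false // no "backend" key in the preferences

	// Parsed as (ok && b == "mlx") || (b == "mlx-vlm").
	matched := ok && b == "mlx" || b == "mlx-vlm"
	fmt.Println(matched) // false
}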
core/gallery/importers/mlx_test.go (147 lines, Normal file)
@@ -0,0 +1,147 @@
package importers_test

import (
	"encoding/json"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("MLXImporter", func() {
	var importer *importers.MLXImporter

	BeforeEach(func() {
		importer = &importers.MLXImporter{}
	})

	Context("Match", func() {
		It("should match when URI contains mlx-community/", func() {
			details := importers.Details{
				URI: "https://huggingface.co/mlx-community/test-model",
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when backend preference is mlx", func() {
			preferences := json.RawMessage(`{"backend": "mlx"}`)
			details := importers.Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when backend preference is mlx-vlm", func() {
			preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
			details := importers.Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should not match when URI does not contain mlx-community/ and no backend preference", func() {
			details := importers.Details{
				URI: "https://huggingface.co/other-org/test-model",
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should not match when backend preference is different", func() {
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
			details := importers.Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should return false when JSON preferences are invalid", func() {
			preferences := json.RawMessage(`invalid json`)
			details := importers.Details{
				URI:         "https://huggingface.co/mlx-community/test-model",
				Preferences: preferences,
			}

			// Invalid JSON causes Match to return false early
			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})
	})

	Context("Import", func() {
		It("should import model config with default name and description", func() {
			details := importers.Details{
				URI: "https://huggingface.co/mlx-community/test-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/mlx-community/test-model"))
		})

		It("should import model config with custom name and description from preferences", func() {
			preferences := json.RawMessage(`{"name": "custom-mlx-model", "description": "Custom MLX description"}`)
			details := importers.Details{
				URI:         "https://huggingface.co/mlx-community/test-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-mlx-model"))
			Expect(modelConfig.Description).To(Equal("Custom MLX description"))
		})

		It("should use custom backend from preferences", func() {
			preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
			details := importers.Details{
				URI:         "https://huggingface.co/mlx-community/test-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
		})

		It("should handle invalid JSON preferences", func() {
			preferences := json.RawMessage(`invalid json`)
			details := importers.Details{
				URI:         "https://huggingface.co/mlx-community/test-model",
				Preferences: preferences,
			}

			_, err := importer.Import(details)
			Expect(err).To(HaveOccurred())
		})

		It("should extract filename correctly from URI with path", func() {
			details := importers.Details{
				URI: "https://huggingface.co/mlx-community/path/to/model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("model"))
		})
	})
})
110 core/gallery/importers/transformers.go Normal file
@@ -0,0 +1,110 @@
package importers

import (
	"encoding/json"
	"path/filepath"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"go.yaml.in/yaml/v2"
)

var _ Importer = &TransformersImporter{}

type TransformersImporter struct{}

func (i *TransformersImporter) Match(details Details) bool {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return false
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return false
	}

	b, ok := preferencesMap["backend"].(string)
	if ok && b == "transformers" {
		return true
	}

	if details.HuggingFace != nil {
		for _, file := range details.HuggingFace.Files {
			if strings.Contains(file.Path, "tokenizer.json") ||
				strings.Contains(file.Path, "tokenizer_config.json") {
				return true
			}
		}
	}

	return false
}

func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, error) {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return gallery.ModelConfig{}, err
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	name, ok := preferencesMap["name"].(string)
	if !ok {
		name = filepath.Base(details.URI)
	}

	description, ok := preferencesMap["description"].(string)
	if !ok {
		description = "Imported from " + details.URI
	}

	backend := "transformers"
	b, ok := preferencesMap["backend"].(string)
	if ok {
		backend = b
	}

	modelType, ok := preferencesMap["type"].(string)
	if !ok {
		modelType = "AutoModelForCausalLM"
	}

	quantization, ok := preferencesMap["quantization"].(string)
	if !ok {
		quantization = ""
	}

	modelConfig := config.ModelConfig{
		Name:                name,
		Description:         description,
		KnownUsecaseStrings: []string{"chat"},
		Backend:             backend,
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: details.URI,
			},
		},
		TemplateConfig: config.TemplateConfig{
			UseTokenizerTemplate: true,
		},
	}
	modelConfig.ModelType = modelType
	modelConfig.Quantization = quantization

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	return gallery.ModelConfig{
		Name:        name,
		Description: description,
		ConfigFile:  string(data),
	}, nil
}
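Given the fields populated above, and the assertions in the test file that follows, the YAML stored in ConfigFile for a default transformers import can be inspected directly. A minimal, self-contained sketch (the printed keys are taken from the test expectations below; exact key ordering depends on the YAML marshaller):

	package main

	import (
		"fmt"

		"github.com/mudler/LocalAI/core/gallery/importers"
	)

	func main() {
		imp := &importers.TransformersImporter{}
		cfg, err := imp.Import(importers.Details{
			URI: "https://huggingface.co/test/my-model",
		})
		if err != nil {
			panic(err)
		}
		// Per the tests below, the YAML is expected to contain, among others:
		//   backend: transformers
		//   type: AutoModelForCausalLM
		//   use_tokenizer_template: true
		//   known_usecases:
		//   - chat
		fmt.Println(cfg.ConfigFile)
	}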
219 core/gallery/importers/transformers_test.go Normal file
@@ -0,0 +1,219 @@
package importers_test

import (
	"encoding/json"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/mudler/LocalAI/core/gallery/importers"
	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("TransformersImporter", func() {
	var importer *TransformersImporter

	BeforeEach(func() {
		importer = &TransformersImporter{}
	})

	Context("Match", func() {
		It("should match when backend preference is transformers", func() {
			preferences := json.RawMessage(`{"backend": "transformers"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain tokenizer.json", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "tokenizer.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain tokenizer_config.json", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "tokenizer_config.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should not match when URI has no tokenizer files and no backend preference", func() {
			details := Details{
				URI: "https://example.com/model.bin",
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should not match when backend preference is different", func() {
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should return false when JSON preferences are invalid", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})
	})

	Context("Import", func() {
		It("should import model config with default name and description", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("my-model"))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
		})

		It("should import model config with custom name and description from preferences", func() {
			preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-model"))
			Expect(modelConfig.Description).To(Equal("Custom description"))
		})

		It("should use custom model type from preferences", func() {
			preferences := json.RawMessage(`{"type": "SentenceTransformer"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("type: SentenceTransformer"))
		})

		It("should use default model type when not specified", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
		})

		It("should use custom backend from preferences", func() {
			preferences := json.RawMessage(`{"backend": "transformers"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
		})

		It("should use quantization from preferences", func() {
			preferences := json.RawMessage(`{"quantization": "int8"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("quantization: int8"))
		})

		It("should handle invalid JSON preferences", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			_, err := importer.Import(details)
			Expect(err).To(HaveOccurred())
		})

		It("should extract filename correctly from URI with path", func() {
			details := importers.Details{
				URI: "https://huggingface.co/test/path/to/model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("model"))
		})

		It("should include use_tokenizer_template in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
		})

		It("should include known_usecases in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
		})
	})
})
98 core/gallery/importers/vllm.go Normal file
@@ -0,0 +1,98 @@
package importers

import (
	"encoding/json"
	"path/filepath"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"go.yaml.in/yaml/v2"
)

var _ Importer = &VLLMImporter{}

type VLLMImporter struct{}

func (i *VLLMImporter) Match(details Details) bool {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return false
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return false
	}

	b, ok := preferencesMap["backend"].(string)
	if ok && b == "vllm" {
		return true
	}

	if details.HuggingFace != nil {
		for _, file := range details.HuggingFace.Files {
			if strings.Contains(file.Path, "tokenizer.json") ||
				strings.Contains(file.Path, "tokenizer_config.json") {
				return true
			}
		}
	}

	return false
}

func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return gallery.ModelConfig{}, err
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	name, ok := preferencesMap["name"].(string)
	if !ok {
		name = filepath.Base(details.URI)
	}

	description, ok := preferencesMap["description"].(string)
	if !ok {
		description = "Imported from " + details.URI
	}

	backend := "vllm"
	b, ok := preferencesMap["backend"].(string)
	if ok {
		backend = b
	}

	modelConfig := config.ModelConfig{
		Name:                name,
		Description:         description,
		KnownUsecaseStrings: []string{"chat"},
		Backend:             backend,
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: details.URI,
			},
		},
		TemplateConfig: config.TemplateConfig{
			UseTokenizerTemplate: true,
		},
	}

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	return gallery.ModelConfig{
		Name:        name,
		Description: description,
		ConfigFile:  string(data),
	}, nil
}
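Note that VLLMImporter.Match uses the same tokenizer-file heuristic as TransformersImporter.Match, so a plain Hugging Face repository shipping tokenizer_config.json satisfies both. Disambiguation is therefore driven by the backend preference; a caller-side sketch of the JSON keys these importers read (the values here are invented for illustration):

	// Only "backend", "name", "description" (and, for transformers,
	// "type"/"quantization") are consulted by the importers above.
	prefs := json.RawMessage(`{"backend": "vllm", "name": "my-vllm-model"}`)
	details := importers.Details{
		URI:         "https://huggingface.co/test/my-model",
		Preferences: prefs,
	}
	cfg, err := (&importers.VLLMImporter{}).Import(details)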
181 core/gallery/importers/vllm_test.go Normal file
@@ -0,0 +1,181 @@
package importers_test

import (
	"encoding/json"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/mudler/LocalAI/core/gallery/importers"
	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("VLLMImporter", func() {
	var importer *VLLMImporter

	BeforeEach(func() {
		importer = &VLLMImporter{}
	})

	Context("Match", func() {
		It("should match when backend preference is vllm", func() {
			preferences := json.RawMessage(`{"backend": "vllm"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain tokenizer.json", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "tokenizer.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain tokenizer_config.json", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "tokenizer_config.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should not match when URI has no tokenizer files and no backend preference", func() {
			details := Details{
				URI: "https://example.com/model.bin",
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should not match when backend preference is different", func() {
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should return false when JSON preferences are invalid", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})
	})

	Context("Import", func() {
		It("should import model config with default name and description", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("my-model"))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
		})

		It("should import model config with custom name and description from preferences", func() {
			preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-model"))
			Expect(modelConfig.Description).To(Equal("Custom description"))
		})

		It("should use custom backend from preferences", func() {
			preferences := json.RawMessage(`{"backend": "vllm"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
		})

		It("should handle invalid JSON preferences", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			_, err := importer.Import(details)
			Expect(err).To(HaveOccurred())
		})

		It("should extract filename correctly from URI with path", func() {
			details := importers.Details{
				URI: "https://huggingface.co/test/path/to/model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("model"))
		})

		It("should include use_tokenizer_template in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
		})

		It("should include known_usecases in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
		})
	})
})
@@ -1,14 +1,15 @@
 package gallery
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"os"
 	"path/filepath"
 	"slices"
 	"strings"
 
 	"dario.cat/mergo"
-	"github.com/mudler/LocalAI/core/config"
+	lconfig "github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/pkg/downloader"
 	"github.com/mudler/LocalAI/pkg/model"
@@ -16,7 +17,7 @@ import (
 	"github.com/mudler/LocalAI/pkg/utils"
 
 	"github.com/rs/zerolog/log"
-	"gopkg.in/yaml.v2"
+	"gopkg.in/yaml.v3"
 )
 
 /*
@@ -72,7 +73,8 @@ type PromptTemplate struct {
 
 // Installs a model from the gallery
 func InstallModelFromGallery(
-	modelGalleries, backendGalleries []config.Gallery,
+	ctx context.Context,
+	modelGalleries, backendGalleries []lconfig.Gallery,
 	systemState *system.SystemState,
 	modelLoader *model.ModelLoader,
 	name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error {
@@ -84,7 +86,7 @@ func InstallModelFromGallery(
 
 	if len(model.URL) > 0 {
 		var err error
-		config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
+		config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
 		if err != nil {
 			return err
 		}
@@ -125,7 +127,7 @@ func InstallModelFromGallery(
 		return err
 	}
 
-	installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
+	installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
 	if err != nil {
 		return err
 	}
@@ -133,7 +135,7 @@ func InstallModelFromGallery(
 	if automaticallyInstallBackend && installedModel.Backend != "" {
 		log.Debug().Msgf("Installing backend %q", installedModel.Backend)
 
-		if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
+		if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
 			return err
 		}
 	}
@@ -154,7 +156,7 @@ func InstallModelFromGallery(
 	return applyModel(model)
 }
 
-func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
+func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
 	basePath := systemState.Model.ModelsPath
 	// Create base path if it doesn't exist
 	err := os.MkdirAll(basePath, 0750)
@@ -168,6 +170,13 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
 
 	// Download files and verify their SHA
 	for i, file := range config.Files {
+		// Check for cancellation before each file
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		default:
+		}
+
 		log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
 
 		if err := utils.VerifyPath(file.Filename, basePath); err != nil {
@@ -185,7 +194,7 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
 			}
 		}
 		uri := downloader.URI(file.URI)
-		if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
+		if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
 			return nil, err
 		}
 	}
@@ -251,8 +260,8 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
 		return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err)
 	}
 
-	if !modelConfig.Validate() {
-		return nil, fmt.Errorf("failed to validate updated config YAML")
+	if valid, err := modelConfig.Validate(); !valid {
+		return nil, fmt.Errorf("failed to validate updated config YAML: %v", err)
 	}
 
 	err = os.WriteFile(configFilePath, updatedConfigYAML, 0600)
@@ -285,21 +294,32 @@ func GetLocalModelConfiguration(basePath string, name string) (*ModelConfig, err
 	return ReadConfigFile[ModelConfig](galleryFile)
 }
 
-func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
-	additionalFiles := []string{}
+func listModelFiles(systemState *system.SystemState, name string) ([]string, error) {
+
 	configFile := filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", name))
 	if err := utils.VerifyPath(configFile, systemState.Model.ModelsPath); err != nil {
-		return fmt.Errorf("failed to verify path %s: %w", configFile, err)
+		return nil, fmt.Errorf("failed to verify path %s: %w", configFile, err)
 	}
 
+	// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
+	name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
+
+	galleryFile := filepath.Join(systemState.Model.ModelsPath, galleryFileName(name))
+	if err := utils.VerifyPath(galleryFile, systemState.Model.ModelsPath); err != nil {
+		return nil, fmt.Errorf("failed to verify path %s: %w", galleryFile, err)
+	}
+
+	additionalFiles := []string{}
+	allFiles := []string{}
+
 	// Galleryname is the name of the model in this case
 	dat, err := os.ReadFile(configFile)
 	if err == nil {
-		modelConfig := &config.ModelConfig{}
+		modelConfig := &lconfig.ModelConfig{}
 
 		err = yaml.Unmarshal(dat, &modelConfig)
 		if err != nil {
-			return err
+			return nil, err
 		}
 		if modelConfig.Model != "" {
 			additionalFiles = append(additionalFiles, modelConfig.ModelFileName())
@@ -310,26 +330,15 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
 		}
 	}
-
-	// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
-	name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
-
-	galleryFile := filepath.Join(systemState.Model.ModelsPath, galleryFileName(name))
-	if err := utils.VerifyPath(galleryFile, systemState.Model.ModelsPath); err != nil {
-		return fmt.Errorf("failed to verify path %s: %w", galleryFile, err)
-	}
-
-	var filesToRemove []string
-
 	// Delete all the files associated to the model
 	// read the model config
 	galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile)
 	if err == nil && galleryconfig != nil {
 		for _, f := range galleryconfig.Files {
 			fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename)
 			if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
-				return fmt.Errorf("failed to verify path %s: %w", fullPath, err)
+				return allFiles, fmt.Errorf("failed to verify path %s: %w", fullPath, err)
 			}
-			filesToRemove = append(filesToRemove, fullPath)
+			allFiles = append(allFiles, fullPath)
 		}
 	} else {
 		log.Error().Err(err).Msgf("failed to read gallery file %s", configFile)
@@ -338,18 +347,68 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
 	for _, f := range additionalFiles {
 		fullPath := filepath.Join(filepath.Join(systemState.Model.ModelsPath, f))
 		if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil {
-			return fmt.Errorf("failed to verify path %s: %w", fullPath, err)
+			return allFiles, fmt.Errorf("failed to verify path %s: %w", fullPath, err)
 		}
-		filesToRemove = append(filesToRemove, fullPath)
+		allFiles = append(allFiles, fullPath)
 	}
 
-	filesToRemove = append(filesToRemove, galleryFile)
+	allFiles = append(allFiles, galleryFile)
 
 	// skip duplicates
-	filesToRemove = utils.Unique(filesToRemove)
+	allFiles = utils.Unique(allFiles)
+
+	return allFiles, nil
+}
+
+func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
+	configFile := filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", name))
+
+	filesToRemove, err := listModelFiles(systemState, name)
+	if err != nil {
+		return err
+	}
+
+	allOtherFiles := []string{}
+	// Get all files of all other models
+	fi, err := os.ReadDir(systemState.Model.ModelsPath)
+	if err != nil {
+		return err
+	}
+	for _, f := range fi {
+		if f.IsDir() {
+			continue
+		}
+		if strings.HasPrefix(f.Name(), "._gallery_") {
+			continue
+		}
+		if !strings.HasSuffix(f.Name(), ".yaml") && !strings.HasSuffix(f.Name(), ".yml") {
+			continue
+		}
+		if f.Name() == fmt.Sprintf("%s.yaml", name) || f.Name() == fmt.Sprintf("%s.yml", name) {
+			continue
+		}
+
+		name := strings.TrimSuffix(f.Name(), ".yaml")
+		name = strings.TrimSuffix(name, ".yml")
+
+		log.Debug().Msgf("Checking file %s", f.Name())
+		files, err := listModelFiles(systemState, name)
+		if err != nil {
+			log.Debug().Err(err).Msgf("failed to list files for model %s", f.Name())
+			continue
+		}
+		allOtherFiles = append(allOtherFiles, files...)
+	}
+
+	log.Debug().Msgf("Files to remove: %+v", filesToRemove)
+	log.Debug().Msgf("All other files: %+v", allOtherFiles)
+
+	// Removing files
 	for _, f := range filesToRemove {
+		if slices.Contains(allOtherFiles, f) {
+			log.Debug().Msgf("Skipping file %s because it is part of another model", f)
+			continue
+		}
 		if e := os.Remove(f); e != nil {
 			log.Error().Err(e).Msgf("failed to remove file %s", f)
 		}
@@ -360,7 +419,7 @@ func DeleteModelFromSystem(systemState *system.SystemState, name string) error {
 
 // This is ***NEVER*** going to be perfect or finished.
 // This is a BEST EFFORT function to surface known-vulnerable models to users.
-func SafetyScanGalleryModels(galleries []config.Gallery, systemState *system.SystemState) error {
+func SafetyScanGalleryModels(galleries []lconfig.Gallery, systemState *system.SystemState) error {
 	galleryModels, err := AvailableGalleryModels(galleries, systemState)
 	if err != nil {
 		return err

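The context.Context threaded through the install path above means a long model download can now be cancelled between files (the select on ctx.Done() in InstallModel) and, via DownloadFileWithContext, inside a single transfer. A hedged caller-side sketch, with galleries, system state and loader assumed to exist, and the status-callback parameter names invented for illustration:

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	err := gallery.InstallModelFromGallery(
		ctx, modelGalleries, backendGalleries, systemState, modelLoader,
		"test@bert", gallery.GalleryModel{},
		func(fileName, current, total string, percent float64) {
			log.Debug().Msgf("downloading %s: %.1f%%", fileName, percent)
		},
		true, // enforceScan
		true, // automaticallyInstallBackend
	)
	if errors.Is(err, context.DeadlineExceeded) {
		// the per-file cancellation check surfaced the timeout
	}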
@@ -1,6 +1,7 @@
 package gallery_test
 
 import (
+	"context"
 	"errors"
 	"os"
 	"path/filepath"
@@ -34,7 +35,7 @@ var _ = Describe("Model test", func() {
 			system.WithModelPath(tempdir),
 		)
 		Expect(err).ToNot(HaveOccurred())
-		_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+		_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
 		Expect(err).ToNot(HaveOccurred())
 
 		for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -88,7 +89,7 @@ var _ = Describe("Model test", func() {
 			Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
 			Expect(models[0].Installed).To(BeFalse())
 
-			err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
+			err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
 			Expect(err).ToNot(HaveOccurred())
 
 			dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
@@ -129,7 +130,7 @@ var _ = Describe("Model test", func() {
 				system.WithModelPath(tempdir),
 			)
 			Expect(err).ToNot(HaveOccurred())
-			_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+			_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
 			Expect(err).ToNot(HaveOccurred())
 
 			for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -149,7 +150,7 @@ var _ = Describe("Model test", func() {
 				system.WithModelPath(tempdir),
 			)
 			Expect(err).ToNot(HaveOccurred())
-			_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
+			_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
 			Expect(err).ToNot(HaveOccurred())
 
 			for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -179,8 +180,101 @@ var _ = Describe("Model test", func() {
 				system.WithModelPath(tempdir),
 			)
 			Expect(err).ToNot(HaveOccurred())
-			_, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+			_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
 			Expect(err).To(HaveOccurred())
 		})
+
+		It("does not delete shared model files when one config is deleted", func() {
+			tempdir, err := os.MkdirTemp("", "test")
+			Expect(err).ToNot(HaveOccurred())
+			defer os.RemoveAll(tempdir)
+
+			systemState, err := system.GetSystemState(
+				system.WithModelPath(tempdir),
+			)
+			Expect(err).ToNot(HaveOccurred())
+
+			// Create a shared model file
+			sharedModelFile := filepath.Join(tempdir, "shared_model.bin")
+			err = os.WriteFile(sharedModelFile, []byte("fake model content"), 0600)
+			Expect(err).ToNot(HaveOccurred())
+
+			// Create first model configuration
+			config1 := `name: model1
+model: shared_model.bin`
+			err = os.WriteFile(filepath.Join(tempdir, "model1.yaml"), []byte(config1), 0600)
+			Expect(err).ToNot(HaveOccurred())
+
+			// Create first model's gallery file
+			galleryConfig1 := ModelConfig{
+				Name: "model1",
+				Files: []File{
+					{Filename: "shared_model.bin"},
+				},
+			}
+			galleryData1, err := yaml.Marshal(galleryConfig1)
+			Expect(err).ToNot(HaveOccurred())
+			err = os.WriteFile(filepath.Join(tempdir, "._gallery_model1.yaml"), galleryData1, 0600)
+			Expect(err).ToNot(HaveOccurred())
+
+			// Create second model configuration sharing the same model file
+			config2 := `name: model2
+model: shared_model.bin`
+			err = os.WriteFile(filepath.Join(tempdir, "model2.yaml"), []byte(config2), 0600)
+			Expect(err).ToNot(HaveOccurred())
+
+			// Create second model's gallery file
+			galleryConfig2 := ModelConfig{
+				Name: "model2",
+				Files: []File{
+					{Filename: "shared_model.bin"},
+				},
+			}
+			galleryData2, err := yaml.Marshal(galleryConfig2)
+			Expect(err).ToNot(HaveOccurred())
+			err = os.WriteFile(filepath.Join(tempdir, "._gallery_model2.yaml"), galleryData2, 0600)
+			Expect(err).ToNot(HaveOccurred())
+
+			// Verify both configurations exist
+			_, err = os.Stat(filepath.Join(tempdir, "model1.yaml"))
+			Expect(err).ToNot(HaveOccurred())
+			_, err = os.Stat(filepath.Join(tempdir, "model2.yaml"))
+			Expect(err).ToNot(HaveOccurred())
+
+			// Verify the shared model file exists
+			_, err = os.Stat(sharedModelFile)
+			Expect(err).ToNot(HaveOccurred())
+
+			// Delete the first model
+			err = DeleteModelFromSystem(systemState, "model1")
+			Expect(err).ToNot(HaveOccurred())
+
+			// Verify the first configuration is deleted
+			_, err = os.Stat(filepath.Join(tempdir, "model1.yaml"))
+			Expect(err).To(HaveOccurred())
+			Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue())
+
+			// Verify the shared model file still exists (not deleted because model2 still uses it)
+			_, err = os.Stat(sharedModelFile)
+			Expect(err).ToNot(HaveOccurred(), "shared model file should not be deleted when used by other configs")
+
+			// Verify the second configuration still exists
+			_, err = os.Stat(filepath.Join(tempdir, "model2.yaml"))
+			Expect(err).ToNot(HaveOccurred())
+
+			// Now delete the second model
+			err = DeleteModelFromSystem(systemState, "model2")
+			Expect(err).ToNot(HaveOccurred())
+
+			// Verify the second configuration is deleted
+			_, err = os.Stat(filepath.Join(tempdir, "model2.yaml"))
+			Expect(err).To(HaveOccurred())
+			Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue())
+
+			// Verify the shared model file is now deleted (no more references)
+			_, err = os.Stat(sharedModelFile)
+			Expect(err).To(HaveOccurred(), "shared model file should be deleted when no configs reference it")
+			Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue())
+		})
 	})
 })

238 core/http/app.go
@@ -4,30 +4,23 @@ import (
 	"embed"
 	"errors"
 	"fmt"
+	"io/fs"
 	"net/http"
 	"os"
 	"path/filepath"
 	"strings"
 
-	"github.com/dave-gray101/v2keyauth"
-	"github.com/gofiber/websocket/v2"
+	"github.com/labstack/echo/v4"
+	"github.com/labstack/echo/v4/middleware"
 
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
-	"github.com/mudler/LocalAI/core/http/middleware"
+	httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/http/routes"
 
 	"github.com/mudler/LocalAI/core/application"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
 
-	"github.com/gofiber/contrib/fiberzerolog"
-	"github.com/gofiber/fiber/v2"
-	"github.com/gofiber/fiber/v2/middleware/cors"
-	"github.com/gofiber/fiber/v2/middleware/csrf"
-	"github.com/gofiber/fiber/v2/middleware/favicon"
-	"github.com/gofiber/fiber/v2/middleware/filesystem"
-	"github.com/gofiber/fiber/v2/middleware/recover"
-
 	// swagger handler
 	"github.com/rs/zerolog/log"
 )
@@ -49,86 +42,85 @@ var embedDirStatic embed.FS
 // @in header
 // @name Authorization
 
-func API(application *application.Application) (*fiber.App, error) {
+func API(application *application.Application) (*echo.Echo, error) {
+	e := echo.New()
 
-	fiberCfg := fiber.Config{
-		Views:     renderEngine(),
-		BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
-		// We disable the Fiber startup message as it does not conform to structured logging.
-		// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
-		DisableStartupMessage: true,
-		// Override default error handler
+	// Set body limit
+	if application.ApplicationConfig().UploadLimitMB > 0 {
+		e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB)))
 	}
 
+	// Set error handler
 	if !application.ApplicationConfig().OpaqueErrors {
 		// Normally, return errors as JSON responses
-		fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
-			// Status code defaults to 500
-			code := fiber.StatusInternalServerError
+		e.HTTPErrorHandler = func(err error, c echo.Context) {
+			code := http.StatusInternalServerError
+			var he *echo.HTTPError
+			if errors.As(err, &he) {
+				code = he.Code
+			}
 
-			// Retrieve the custom status code if it's a *fiber.Error
-			var e *fiber.Error
-			if errors.As(err, &e) {
-				code = e.Code
+			// Handle 404 errors with HTML rendering when appropriate
+			if code == http.StatusNotFound {
+				notFoundHandler(c)
+				return
 			}
 
-			// Send custom error page
-			return ctx.Status(code).JSON(
-				schema.ErrorResponse{
-					Error: &schema.APIError{Message: err.Error(), Code: code},
-				},
-			)
+			c.JSON(code, schema.ErrorResponse{
+				Error: &schema.APIError{Message: err.Error(), Code: code},
+			})
 		}
 	} else {
 		// If OpaqueErrors are required, replace everything with a blank 500.
-		fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error {
-			return ctx.Status(500).SendString("")
+		e.HTTPErrorHandler = func(err error, c echo.Context) {
+			code := http.StatusInternalServerError
+			var he *echo.HTTPError
+			if errors.As(err, &he) {
+				code = he.Code
+			}
+			c.NoContent(code)
 		}
 	}
 
-	router := fiber.New(fiberCfg)
+	// Set renderer
+	e.Renderer = renderEngine()
 
-	router.Use(middleware.StripPathPrefix())
+	// Hide banner
+	e.HideBanner = true
+
+	// Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
+	e.Pre(httpMiddleware.StripPathPrefix())
 
 	if application.ApplicationConfig().MachineTag != "" {
-		router.Use(func(c *fiber.Ctx) error {
-			c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
-
-			return c.Next()
+		e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
+			return func(c echo.Context) error {
+				c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
+				return next(c)
+			}
 		})
 	}
 
-	router.Use("/v1/realtime", func(c *fiber.Ctx) error {
-		if websocket.IsWebSocketUpgrade(c) {
-			// Returns true if the client requested upgrade to the WebSocket protocol
-			return c.Next()
+	// Custom logger middleware using zerolog
+	e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
+		return func(c echo.Context) error {
+			req := c.Request()
+			res := c.Response()
+			start := log.Logger.Info()
+			err := next(c)
+			start.
+				Str("method", req.Method).
+				Str("path", req.URL.Path).
+				Int("status", res.Status).
+				Msg("HTTP request")
+			return err
 		}
-
-		return nil
 	})
 
-	router.Hooks().OnListen(func(listenData fiber.ListenData) error {
-		scheme := "http"
-		if listenData.TLS {
-			scheme = "https"
-		}
-		log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
-		return nil
-	})
-
-	// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
-	logger := log.Logger
-	router.Use(fiberzerolog.New(fiberzerolog.Config{
-		Logger: &logger,
-	}))
-
-	// Default middleware config
-
+	// Recover middleware
 	if !application.ApplicationConfig().Debug {
-		router.Use(recover.New())
+		e.Use(middleware.Recover())
 	}
 
-	// OpenTelemetry metrics for Prometheus export
+	// Metrics middleware
 	if !application.ApplicationConfig().DisableMetrics {
 		metricsService, err := services.NewLocalAIMetricsService()
 		if err != nil {
@@ -136,35 +128,40 @@ func API(application *application.Application) (*fiber.App, error) {
 		}
 
 		if metricsService != nil {
-			router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
-			router.Hooks().OnShutdown(func() error {
-				return metricsService.Shutdown()
+			e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
+			e.Server.RegisterOnShutdown(func() {
+				metricsService.Shutdown()
			})
 		}
 	}
 
 	// Health Checks should always be exempt from auth, so register these first
-	routes.HealthRoutes(router)
+	routes.HealthRoutes(e)
 
-	kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
-	if err != nil || kaConfig == nil {
+	// Get key auth middleware
+	keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
+	if err != nil {
 		return nil, fmt.Errorf("failed to create key auth config: %w", err)
 	}
 
-	httpFS := http.FS(embedDirStatic)
+	// Favicon handler
+	e.GET("/favicon.svg", func(c echo.Context) error {
+		data, err := embedDirStatic.ReadFile("static/favicon.svg")
+		if err != nil {
+			return c.NoContent(http.StatusNotFound)
+		}
+		c.Response().Header().Set("Content-Type", "image/svg+xml")
+		return c.Blob(http.StatusOK, "image/svg+xml", data)
+	})
 
-	router.Use(favicon.New(favicon.Config{
-		URL:        "/favicon.svg",
-		FileSystem: httpFS,
-		File:       "static/favicon.svg",
-	}))
-
-	router.Use("/static", filesystem.New(filesystem.Config{
-		Root:       httpFS,
-		PathPrefix: "static",
-		Browse:     true,
-	}))
+	// Static files - use fs.Sub to create a filesystem rooted at "static"
+	staticFS, err := fs.Sub(embedDirStatic, "static")
+	if err != nil {
+		return nil, fmt.Errorf("failed to create static filesystem: %w", err)
+	}
+	e.StaticFS("/static", staticFS)
 
 	// Generated content directories
 	if application.ApplicationConfig().GeneratedContentDir != "" {
 		os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
 		audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
@@ -175,62 +172,53 @@ func API(application *application.Application) (*fiber.App, error) {
 		os.MkdirAll(imagePath, 0750)
 		os.MkdirAll(videoPath, 0750)
 
-		router.Static("/generated-audio", audioPath)
-		router.Static("/generated-images", imagePath)
-		router.Static("/generated-videos", videoPath)
+		e.Static("/generated-audio", audioPath)
+		e.Static("/generated-images", imagePath)
+		e.Static("/generated-videos", videoPath)
 	}
 
-	// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
-	router.Use(v2keyauth.New(*kaConfig))
+	// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
+	e.Use(keyAuthMiddleware)
 
 	// CORS middleware
 	if application.ApplicationConfig().CORS {
-		var c func(ctx *fiber.Ctx) error
-		if application.ApplicationConfig().CORSAllowOrigins == "" {
-			c = cors.New()
-		} else {
-			c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
+		corsConfig := middleware.CORSConfig{}
+		if application.ApplicationConfig().CORSAllowOrigins != "" {
+			corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
 		}
 
-		router.Use(c)
+		e.Use(middleware.CORSWithConfig(corsConfig))
 	}
 
 	// CSRF middleware
 	if application.ApplicationConfig().CSRF {
 		log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
-		router.Use(csrf.New())
+		e.Use(middleware.CSRF())
 	}
 
-	requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+	requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
 
-	routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
-	routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
-	routes.RegisterOpenAIRoutes(router, requestExtractor, application)
+	routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
 
+	// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
+	var opcache *services.OpCache
 	if !application.ApplicationConfig().DisableWebUI {
-
-		// Create metrics store for tracking usage (before API routes registration)
-		metricsStore := services.NewInMemoryMetricsStore()
-
-		// Add metrics middleware BEFORE API routes so it can intercept them
-		router.Use(middleware.MetricsMiddleware(metricsStore))
-
-		// Register cleanup on shutdown
-		router.Hooks().OnShutdown(func() error {
-			metricsStore.Stop()
-			log.Info().Msg("Metrics store stopped")
-			return nil
-		})
-
-		// Create opcache for tracking UI operations
-		opcache := services.NewOpCache(application.GalleryService())
-		routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, metricsStore)
-		routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
+		opcache = services.NewOpCache(application.GalleryService())
 	}
 
-	routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
+	routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator())
+	routes.RegisterOpenAIRoutes(e, requestExtractor, application)
+	if !application.ApplicationConfig().DisableWebUI {
+		routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application)
+		routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
+	}
+	routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
 
-	// Define a custom 404 handler
-	// Note: keep this at the bottom!
-	router.Use(notFoundHandler)
+	// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route
 
-	return router, nil
+	// Log startup message
+	e.Server.RegisterOnShutdown(func() {
+		log.Info().Msg("LocalAI API server shutting down")
+	})
+
+	return e, nil
 }

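The Fiber-to-Echo migration above centralizes error shaping in e.HTTPErrorHandler instead of a fiber.Config.ErrorHandler. A self-contained sketch of the same non-opaque pattern, runnable on its own (the /boom route and the map-based payload are illustrative, not LocalAI's actual routes or schema types):

	package main

	import (
		"errors"
		"net/http"

		"github.com/labstack/echo/v4"
	)

	func main() {
		e := echo.New()
		e.HideBanner = true
		e.HTTPErrorHandler = func(err error, c echo.Context) {
			// Default to 500, then unwrap *echo.HTTPError for the real status,
			// mirroring the handler in app.go above
			code := http.StatusInternalServerError
			var he *echo.HTTPError
			if errors.As(err, &he) {
				code = he.Code
			}
			_ = c.JSON(code, map[string]any{
				"error": map[string]any{"message": err.Error(), "code": code},
			})
		}
		e.GET("/boom", func(c echo.Context) error {
			return echo.NewHTTPError(http.StatusTeapot, "kaboom")
		})
		e.Logger.Fatal(e.Start("127.0.0.1:9090"))
	}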
@@ -10,13 +10,14 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
@@ -25,6 +26,7 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
openaigo "github.com/otiai10/openaigo"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/jsonschema"
|
||||
)
|
||||
@@ -85,7 +87,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
|
||||
response := []gallery.GalleryModel{}
|
||||
uri := downloader.URI(url)
|
||||
// TODO: No tests currently seem to exercise file:// urls. Fix?
|
||||
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
|
||||
err := uri.ReadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
|
||||
// Unmarshal YAML data into a struct
|
||||
return json.Unmarshal(i, &response)
|
||||
})
|
||||
@@ -266,7 +268,7 @@ const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b8
|
||||
|
||||
var _ = Describe("API test", func() {
|
||||
|
||||
var app *fiber.App
|
||||
var app *echo.Echo
|
||||
var client *openai.Client
|
||||
var client2 *openaigo.Client
|
||||
var c context.Context
|
||||
@@ -339,7 +341,11 @@ var _ = Describe("API test", func() {
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig(apiKey)
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -358,7 +364,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func(sc SpecContext) {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := os.RemoveAll(tmpdir)
|
||||
@@ -505,6 +513,124 @@ var _ = Describe("API test", func() {
})

})

Context("Importing models from URI", func() {
	var testYamlFile string

	BeforeEach(func() {
		// Create a test YAML config file
		yamlContent := `name: test-import-model
backend: llama-cpp
description: Test model imported from file URI
parameters:
  model: path/to/model.gguf
  temperature: 0.7
`
		testYamlFile = filepath.Join(tmpdir, "test-import.yaml")
		err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644)
		Expect(err).ToNot(HaveOccurred())
	})

	AfterEach(func() {
		err := os.Remove(testYamlFile)
		Expect(err).ToNot(HaveOccurred())
	})

	It("should import model from file:// URI pointing to local YAML config", func() {
		importReq := schema.ImportModelRequest{
			URI:         "file://" + testYamlFile,
			Preferences: json.RawMessage(`{}`),
		}

		var response schema.GalleryResponse
		err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
		Expect(err).ToNot(HaveOccurred())
		Expect(response.ID).ToNot(BeEmpty())

		uuid := response.ID
		resp := map[string]interface{}{}
		Eventually(func() bool {
			response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
			resp = response
			return response["processed"].(bool)
		}, "360s", "10s").Should(Equal(true))

		// Check that the model was imported successfully
		Expect(resp["message"]).ToNot(ContainSubstring("error"))
		Expect(resp["error"]).To(BeNil())

		// Verify the model config file was created
		dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml"))
		Expect(err).ToNot(HaveOccurred())

		content := map[string]interface{}{}
		err = yaml.Unmarshal(dat, &content)
		Expect(err).ToNot(HaveOccurred())
		Expect(content["name"]).To(Equal("test-import-model"))
		Expect(content["backend"]).To(Equal("llama-cpp"))
	})

	It("should return error when file:// URI points to non-existent file", func() {
		nonExistentFile := filepath.Join(tmpdir, "nonexistent.yaml")
		importReq := schema.ImportModelRequest{
			URI:         "file://" + nonExistentFile,
			Preferences: json.RawMessage(`{}`),
		}

		var response schema.GalleryResponse
		err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
		// The endpoint should return an error immediately
		Expect(err).To(HaveOccurred())
		Expect(err.Error()).To(ContainSubstring("failed to discover model config"))
	})
})
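Outside of Ginkgo, the import flow is one POST to /models/import-uri followed by polling /models/jobs/{id} until processed flips to true. A rough equivalent of the Eventually loop above (a sketch; it assumes only the status keys the test reads, plus the encoding/json, fmt, net/http, and time imports):

// waitForJob polls the job endpoint until it reports processed == true.
func waitForJob(baseURL, id string, interval, timeout time.Duration) (map[string]interface{}, error) {
	deadline := time.Now().Add(timeout)
	for time.Now().Before(deadline) {
		resp, err := http.Get(baseURL + "/models/jobs/" + id)
		if err != nil {
			return nil, err
		}
		status := map[string]interface{}{}
		err = json.NewDecoder(resp.Body).Decode(&status)
		resp.Body.Close()
		if err != nil {
			return nil, err
		}
		if processed, _ := status["processed"].(bool); processed {
			return status, nil
		}
		time.Sleep(interval)
	}
	return nil, fmt.Errorf("job %s still pending after %s", id, timeout)
}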
Context("Importing models from URI can't point to absolute paths", func() {
|
||||
var testYamlFile string
|
||||
|
||||
BeforeEach(func() {
|
||||
// Create a test YAML config file
|
||||
yamlContent := `name: test-import-model
|
||||
backend: llama-cpp
|
||||
description: Test model imported from file URI
|
||||
parameters:
|
||||
model: /path/to/model.gguf
|
||||
temperature: 0.7
|
||||
`
|
||||
testYamlFile = filepath.Join(tmpdir, "test-import.yaml")
|
||||
err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
err := os.Remove(testYamlFile)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("should fail to import model from file:// URI pointing to local YAML config", func() {
|
||||
importReq := schema.ImportModelRequest{
|
||||
URI: "file://" + testYamlFile,
|
||||
Preferences: json.RawMessage(`{}`),
|
||||
}
|
||||
|
||||
var response schema.GalleryResponse
|
||||
err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(response.ID).ToNot(BeEmpty())
|
||||
|
||||
uuid := response.ID
|
||||
resp := map[string]interface{}{}
|
||||
Eventually(func() bool {
|
||||
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
|
||||
resp = response
|
||||
return response["processed"].(bool)
|
||||
}, "360s", "10s").Should(Equal(true))
|
||||
|
||||
// Check that the model was imported successfully
|
||||
Expect(resp["message"]).To(ContainSubstring("error"))
|
||||
Expect(resp["error"]).ToNot(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
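This second context pins down a safety rule: an imported config may only reference model files relative to the models directory, so /path/to/model.gguf makes the job fail. The kind of check involved, sketched standalone (a hypothetical helper, not LocalAI's actual validation code):

package main

import (
	"fmt"
	"path/filepath"
)

// validateModelPath rejects absolute paths so an imported config cannot
// point outside the models directory.
func validateModelPath(p string) error {
	if filepath.IsAbs(p) {
		return fmt.Errorf("model path %q must be relative to the models directory", p)
	}
	return nil
}

func main() {
	fmt.Println(validateModelPath("path/to/model.gguf"))  // <nil>
	fmt.Println(validateModelPath("/path/to/model.gguf")) // error
}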
Context("Model gallery", func() {
|
||||
@@ -547,7 +673,11 @@ var _ = Describe("API test", func() {
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -566,7 +696,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
err := os.RemoveAll(tmpdir)
|
||||
@@ -755,7 +887,11 @@ var _ = Describe("API test", func() {
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -773,7 +909,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
@@ -796,6 +934,83 @@ var _ = Describe("API test", func() {
|
||||
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("returns logprobs in chat completions when requested", func() {
|
||||
topLogprobsVal := 3
|
||||
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
|
||||
Model: "testmodel.ggml",
|
||||
LogProbs: true,
|
||||
TopLogProbs: topLogprobsVal,
|
||||
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(len(response.Choices)).To(Equal(1))
|
||||
Expect(response.Choices[0].Message).ToNot(BeNil())
|
||||
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
|
||||
// Verify logprobs are present and have correct structure
|
||||
Expect(response.Choices[0].LogProbs).ToNot(BeNil())
|
||||
Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())
|
||||
|
||||
Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))
|
||||
|
||||
foundatLeastToken := ""
|
||||
foundAtLeastBytes := []byte{}
|
||||
foundAtLeastTopLogprobBytes := []byte{}
|
||||
foundatLeastTopLogprob := ""
|
||||
// Verify logprobs content structure matches OpenAI format
|
||||
for _, logprobContent := range response.Choices[0].LogProbs.Content {
|
||||
// Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
|
||||
if len(logprobContent.Bytes) > 0 {
|
||||
foundAtLeastBytes = logprobContent.Bytes
|
||||
}
|
||||
if len(logprobContent.Token) > 0 {
|
||||
foundatLeastToken = logprobContent.Token
|
||||
}
|
||||
Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
|
||||
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))
|
||||
|
||||
// If top_logprobs is requested, verify top_logprobs array respects the limit
|
||||
if len(logprobContent.TopLogProbs) > 0 {
|
||||
// Should respect top_logprobs limit (3 in this test)
|
||||
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
|
||||
for _, topLogprob := range logprobContent.TopLogProbs {
|
||||
if len(topLogprob.Bytes) > 0 {
|
||||
foundAtLeastTopLogprobBytes = topLogprob.Bytes
|
||||
}
|
||||
if len(topLogprob.Token) > 0 {
|
||||
foundatLeastTopLogprob = topLogprob.Token
|
||||
}
|
||||
Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Expect(foundAtLeastBytes).ToNot(BeEmpty())
|
||||
Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
|
||||
Expect(foundatLeastToken).ToNot(BeEmpty())
|
||||
Expect(foundatLeastTopLogprob).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
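For reference on the assertions above: a token logprob is the natural log of its probability, hence always at most 0, and exp recovers the probability:

package main

import (
	"fmt"
	"math"
)

func main() {
	logprob := -0.105
	fmt.Printf("p = %.2f\n", math.Exp(logprob)) // p = 0.90
}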
It("applies logit_bias to chat completions when requested", func() {
|
||||
// logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
|
||||
// According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
|
||||
logitBias := map[string]int{
|
||||
"15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
|
||||
}
|
||||
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
|
||||
Model: "testmodel.ggml",
|
||||
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
|
||||
LogitBias: logitBias,
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(response.Choices)).To(Equal(1))
|
||||
Expect(response.Choices[0].Message).ToNot(BeNil())
|
||||
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
|
||||
// If logit_bias is applied, the response should be generated successfully
|
||||
// We can't easily verify the bias effect without knowing the actual token IDs for the model,
|
||||
// but the fact that the request succeeds confirms the API accepts and processes logit_bias
|
||||
})
|
||||
|
||||
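Mechanically, logit_bias is an additive offset applied to the raw logits before sampling, with values near ±100 effectively banning or forcing a token. A toy sketch of that step (illustrative only, not the backend's implementation):

// applyLogitBias nudges selected token logits before softmax/sampling.
func applyLogitBias(logits []float64, bias map[int]float64) {
	for tok, b := range bias {
		if tok >= 0 && tok < len(logits) {
			logits[tok] += b
		}
	}
}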
It("returns errors", func() {
|
||||
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
|
||||
Expect(err).To(HaveOccurred())
|
||||
@@ -984,6 +1199,9 @@ var _ = Describe("API test", func() {
|
||||
|
||||
Context("Config file", func() {
|
||||
BeforeEach(func() {
|
||||
if runtime.GOOS != "linux" {
|
||||
Skip("run this test only on linux")
|
||||
}
|
||||
modelPath := os.Getenv("MODELS_PATH")
|
||||
backendPath := os.Getenv("BACKENDS_PATH")
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
@@ -1006,7 +1224,11 @@ var _ = Describe("API test", func() {
|
||||
app, err = API(application)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
go app.Listen("127.0.0.1:9090")
|
||||
go func() {
|
||||
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
|
||||
log.Error().Err(err).Msg("server error")
|
||||
}
|
||||
}()
|
||||
|
||||
defaultConfig := openai.DefaultConfig("")
|
||||
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
|
||||
@@ -1022,7 +1244,9 @@ var _ = Describe("API test", func() {
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
if app != nil {
|
||||
err := app.Shutdown()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
err := app.Shutdown(ctx)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,7 +1,9 @@
package elevenlabs

import (
	"github.com/gofiber/fiber/v2"
	"path/filepath"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
@@ -15,17 +17,17 @@ import (
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {

		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
		if !ok || input.ModelID == "" {
			return fiber.ErrBadRequest
			return echo.ErrBadRequest
		}

		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || cfg == nil {
			return fiber.ErrBadRequest
			return echo.ErrBadRequest
		}

		log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
@@ -35,7 +37,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
		if err != nil {
			return err
		}
		return c.Download(filePath)
		return c.Attachment(filePath, filepath.Base(filePath))

	}
}
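The handler migration is mechanical: fiber.Ctx becomes echo.Context, Locals becomes Get, and Download becomes Attachment. Side by side on a generic file-serving handler (a sketch, not one of LocalAI's real endpoints; the "filePath" key is invented for illustration):

// Fiber (before): request-scoped values via Locals, file download via Download.
func serveFileFiber() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		path, ok := c.Locals("filePath").(string)
		if !ok {
			return fiber.ErrBadRequest
		}
		return c.Download(path)
	}
}

// Echo (after): the same handler with Get and Attachment.
func serveFileEcho() echo.HandlerFunc {
	return func(c echo.Context) error {
		path, ok := c.Get("filePath").(string)
		if !ok {
			return echo.ErrBadRequest
		}
		return c.Attachment(path, filepath.Base(path))
	}
}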
@@ -1,13 +1,14 @@
package elevenlabs

import (
	"path/filepath"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/pkg/model"

	"github.com/gofiber/fiber/v2"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/rs/zerolog/log"
)

@@ -17,19 +18,19 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {

		voiceID := c.Params("voice-id")
		voiceID := c.Param("voice-id")

		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
		if !ok || input.ModelID == "" {
			return fiber.ErrBadRequest
			return echo.ErrBadRequest
		}

		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || cfg == nil {
			return fiber.ErrBadRequest
			return echo.ErrBadRequest
		}

		log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received")
@@ -38,6 +39,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
		if err != nil {
			return err
		}
		return c.Download(filePath)
		return c.Attachment(filePath, filepath.Base(filePath))
	}
}
@@ -2,28 +2,32 @@ package explorer

import (
	"encoding/base64"
	"net/http"
	"sort"
	"strings"

	"github.com/gofiber/fiber/v2"
	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/explorer"
	"github.com/mudler/LocalAI/core/http/utils"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/internal"
)

func Dashboard() func(*fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		summary := fiber.Map{
func Dashboard() echo.HandlerFunc {
	return func(c echo.Context) error {
		summary := map[string]interface{}{
			"Title":   "LocalAI API - " + internal.PrintableVersion(),
			"Version": internal.PrintableVersion(),
			"BaseURL": utils.BaseURL(c),
			"BaseURL": middleware.BaseURL(c),
		}

		if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
		contentType := c.Request().Header.Get("Content-Type")
		accept := c.Request().Header.Get("Accept")
		if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) {
			// The client expects a JSON response
			return c.Status(fiber.StatusOK).JSON(summary)
			return c.JSON(http.StatusOK, summary)
		} else {
			// Render index
			return c.Render("views/explorer", summary)
			return c.Render(http.StatusOK, "views/explorer", summary)
		}
	}
}
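Echo has no direct analogue of Fiber's Accepts helper, so the port inspects the Accept and Content-Type headers by hand. The same negotiation in isolation (a sketch; the header parsing is deliberately naive and ignores q-values, which a production version would parse properly):

// wantsJSON decides between a JSON and an HTML response from raw headers.
func wantsJSON(r *http.Request) bool {
	contentType := r.Header.Get("Content-Type")
	accept := r.Header.Get("Accept")
	return strings.Contains(contentType, "application/json") ||
		(accept != "" && !strings.Contains(accept, "html"))
}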
@@ -39,8 +43,8 @@ type Network struct {
	Token string `json:"token"`
}

func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
	return func(c echo.Context) error {
		results := []Network{}
		for _, token := range db.TokenList() {
			networkData, exists := db.Get(token) // get the token data
@@ -61,44 +65,44 @@ func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
			return len(results[i].Clusters) > len(results[j].Clusters)
		})

		return c.JSON(results)
		return c.JSON(http.StatusOK, results)
	}
}

func AddNetwork(db *explorer.Database) func(*fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func AddNetwork(db *explorer.Database) echo.HandlerFunc {
	return func(c echo.Context) error {
		request := new(AddNetworkRequest)
		if err := c.BodyParser(request); err != nil {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
		if err := c.Bind(request); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
		}

		if request.Token == "" {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token is required"})
			return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
		}

		if request.Name == "" {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Name is required"})
			return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
		}

		if request.Description == "" {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Description is required"})
			return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
		}

		// TODO: check if token is valid, otherwise reject
		// try to decode the token from base64
		_, err := base64.StdEncoding.DecodeString(request.Token)
		if err != nil {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid token"})
			return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
		}

		if _, exists := db.Get(request.Token); exists {
			return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token already exists"})
			return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
		}
		err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
		if err != nil {
			return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Cannot add token"})
			return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
		}

		return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Token added"})
		return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
	}
}
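Fiber's BodyParser plus fiber.Map maps onto Echo's Bind plus a plain map (or echo.Map) with an explicit status code. The validate-then-store shape above, condensed into a sketch with a hypothetical request type:

type addRequest struct {
	Token string `json:"token"`
	Name  string `json:"name"`
}

func addHandler(c echo.Context) error {
	req := new(addRequest)
	if err := c.Bind(req); err != nil {
		return c.JSON(http.StatusBadRequest, echo.Map{"error": "Cannot parse JSON"})
	}
	if req.Token == "" || req.Name == "" {
		return c.JSON(http.StatusBadRequest, echo.Map{"error": "token and name are required"})
	}
	return c.JSON(http.StatusOK, echo.Map{"message": "Token added"})
}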
@@ -1,11 +1,12 @@
package jina

import (
	"net/http"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"

	"github.com/gofiber/fiber/v2"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/mudler/LocalAI/pkg/model"
@@ -17,24 +18,36 @@ import (
// @Param request body schema.JINARerankRequest true "query params"
// @Success 200 {object} schema.JINARerankResponse "Response"
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {

		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
		if !ok || input.Model == "" {
			return fiber.ErrBadRequest
			return echo.ErrBadRequest
		}

		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || cfg == nil {
			return fiber.ErrBadRequest
			return echo.ErrBadRequest
		}

		log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")

		var requestTopN int32
		docs := int32(len(input.Documents))
		if input.TopN == nil { // omit top_n to get all
			requestTopN = docs
		} else {
			requestTopN = int32(*input.TopN)
			if requestTopN < 1 {
				return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
			}
			if requestTopN > docs { // make it more obvious for backends
				requestTopN = docs
			}
		}
		request := &proto.RerankRequest{
			Query:     input.Query,
			TopN:      int32(input.TopN),
			TopN:      requestTopN,
			Documents: input.Documents,
		}
@@ -58,6 +71,6 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app

		response.Usage.TotalTokens = int(results.Usage.TotalTokens)
		response.Usage.PromptTokens = int(results.Usage.PromptTokens)

		return c.Status(fiber.StatusOK).JSON(response)
		return c.JSON(http.StatusOK, response)
	}
}
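The new top_n handling has three cases: omitted means all documents, below 1 is rejected, and anything above the document count is clamped. The same branching as a pure function (a sketch mirroring the endpoint logic above, assuming the errors import):

// resolveTopN mirrors the rerank endpoint's branching.
func resolveTopN(topN *int, docs int32) (int32, error) {
	if topN == nil {
		return docs, nil // omit top_n to get all documents
	}
	n := int32(*topN)
	if n < 1 {
		return 0, errors.New("top_n should be greater than or equal to 1")
	}
	if n > docs {
		n = docs // make it more obvious for backends
	}
	return n, nil
}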
@@ -4,11 +4,11 @@ import (
	"encoding/json"
	"fmt"

	"github.com/gofiber/fiber/v2"
	"github.com/labstack/echo/v4"
	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/http/utils"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/services"
	"github.com/mudler/LocalAI/pkg/system"
@@ -39,13 +39,13 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /backends/jobs/{uuid} [get]
func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		status := mgs.backendApplier.GetStatus(c.Params("uuid"))
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		status := mgs.backendApplier.GetStatus(c.Param("uuid"))
		if status == nil {
			return fmt.Errorf("could not find any status for ID")
		}
		return c.JSON(status)
		return c.JSON(200, status)
	}
}

@@ -53,9 +53,9 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) erro
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /backends/jobs [get]
func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		return c.JSON(mgs.backendApplier.GetAllStatus())
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		return c.JSON(200, mgs.backendApplier.GetAllStatus())
	}
}

@@ -64,11 +64,11 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) err
// @Param request body GalleryBackend true "query params"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/apply [post]
func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		input := new(GalleryBackend)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
		if err := c.Bind(input); err != nil {
			return err
		}

@@ -76,13 +76,13 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
		if err != nil {
			return err
		}
		mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend]{
		mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
			ID:                 uuid.String(),
			GalleryElementName: input.ID,
			Galleries:          mgs.galleries,
		}

		return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
		return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
	}
}

@@ -91,11 +91,11 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
// @Param name path string true "Backend name"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/delete/{name} [post]
func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		backendName := c.Params("name")
func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		backendName := c.Param("name")

		mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend]{
		mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
			Delete:             true,
			GalleryElementName: backendName,
			Galleries:          mgs.galleries,
@@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
			return err
		}

		return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
		return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
	}
}

@@ -114,13 +114,13 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
// @Summary List all Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends [get]
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
	return func(c echo.Context) error {
		backends, err := gallery.ListSystemBackends(systemState)
		if err != nil {
			return err
		}
		return c.JSON(backends.GetAll())
		return c.JSON(200, backends.GetAll())
	}
}

@@ -129,14 +129,14 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
// @Success 200 {object} []config.Gallery "Response"
// @Router /backends/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries)
		dat, err := json.Marshal(mgs.galleries)
		if err != nil {
			return err
		}
		return c.Send(dat)
		return c.Blob(200, "application/json", dat)
	}
}

@@ -144,12 +144,12 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
// @Summary List all available Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends/available [get]
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
	return func(c echo.Context) error {
		backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
		if err != nil {
			return err
		}
		return c.JSON(backends)
		return c.JSON(200, backends)
	}
}
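services.GalleryOp gains a second type parameter in this refactor, instantiated as any at these call sites. Its shape, reduced to the fields visible in this diff (the Payload field and the local Gallery type are guesses added only to make the sketch self-contained):

type Gallery struct { // stand-in for config.Gallery
	Name string
	URL  string
}

type GalleryOp[T any, P any] struct {
	ID                 string
	Delete             bool
	GalleryElementName string
	Galleries          []Gallery
	Payload            P // hypothetical use of the second type parameter
}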
@@ -1,45 +1,45 @@
package localai

import (
	"github.com/gofiber/fiber/v2"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/services"
)

// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {

		input := new(schema.BackendMonitorRequest)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		resp, err := bm.CheckAndSample(input.Model)
		if err != nil {
			return err
		}
		return c.JSON(resp)
	}
}

// BackendShutdownEndpoint shuts down the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		input := new(schema.BackendMonitorRequest)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		return bm.ShutdownModel(input.Model)
	}
}
package localai

import (
	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/services"
)

// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
	return func(c echo.Context) error {

		input := new(schema.BackendMonitorRequest)
		// Get input data from the request body
		if err := c.Bind(input); err != nil {
			return err
		}

		resp, err := bm.CheckAndSample(input.Model)
		if err != nil {
			return err
		}
		return c.JSON(200, resp)
	}
}

// BackendShutdownEndpoint shuts down the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
	return func(c echo.Context) error {
		input := new(schema.BackendMonitorRequest)
		// Get input data from the request body
		if err := c.Bind(input); err != nil {
			return err
		}

		return bm.ShutdownModel(input.Model)
	}
}

@@ -1,7 +1,7 @@
package localai

import (
	"github.com/gofiber/fiber/v2"
	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/middleware"
@@ -16,17 +16,17 @@
// @Param request body schema.DetectionRequest true "query params"
// @Success 200 {object} schema.DetectionResponse "Response"
// @Router /v1/detection [post]
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {

		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
		if !ok || input.Model == "" {
			return fiber.ErrBadRequest
			return echo.ErrBadRequest
		}

		cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || cfg == nil {
			return fiber.ErrBadRequest
			return echo.ErrBadRequest
		}

		log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
@@ -54,6 +54,6 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
			}
		}

		return c.JSON(response)
		return c.JSON(200, response)
	}
}