Mirror of https://github.com/mudler/LocalAI.git (synced 2026-02-03 11:13:31 -05:00)

Compare commits: feat/stats...docs/impro — 122 commits

Commits in this comparison (SHA1; the author, date, and message columns were empty in the source table):

3e8a54f4b6, 18d11396cd, 93cd688f40, 721c3f962b, fb834805db, 839aa7b42b,
e963a45d66, c313b2c671, 137f16336e, d7f9f3ac93, cd7d384500, d1a0dd10e6,
be8cf838c2, 3276d1cdaf, 5e5f01badd, 6d0f646c37, 99d31667f8, 47b546afdc,
a09d49da43, 1cdcaf0152, 03e9f4b140, 7129409bf6, d9e9ec6825, b82645d28d,
735ca757fa, b1d1f2a37d, 3728552e94, 87d0020c10, a8eb537071, 04fe0b0da8,
fae93e5ba2, b606034243, 5f4663252d, 80bb7c5f67, f6881ea023, 5651a19aa1,
c834cdb826, fa2caef63d, 31abc799f9, 2368395a0c, bf77c11b65, 8876073f5c,
8432915cb8, 9ddb94b507, e42f0f7e79, 34bc1bda1e, 01cd58a739, 679d43c2f5,
4730b52461, f678c6b0a9, 2f2f9beee7, 8ac7e28c12, c5c3538115, 5ef16b5693,
02cc8cbcaa, e5e86d0acb, edd35d2b33, e8cc29e364, 8f7c499f17, ea446fde08,
122e4c7094, 2573102317, 41b60fcfd3, cb81869140, db9957b94e, 98158881c2,
79247a5d17, 46b7a4c5f2, 436e2d91d0, a86fdc4087, c7ac6ca687, 7088327e8d,
e2cb44ef37, 3a40b4129c, 4ca8055f21, 704786cc6d, e5ce1fd9cc, ea2037f141,
567fa62330, d424a27fa2, 3ce9cb566d, ee7638a9b0, e57e50e441, 81880e7975,
2cad2c8591, b87b41ee45, 424acd66ad, 3cd8234550, c70a0f05b8, f85e2dd1b8,
e485bdf9ab, 495c4ee694, 161d1a0344, b6d1def96f, 9ecfdc5938, c332ef5cce,
6e7a8c6041, 43e707ec4f, fed3663a74, 5b72798db3, d24d6d4e93, 50ee1fbe06,
19f3425ce0, a6ef245534, 88cb379c2d, 0ddb2e8dcf, 91b9301bec, fad5868f7b,
1e5b9135df, 36d19e23e0, cba9d1aac0, dd21a0d2f9, 302a43b3ae, 2955061b42,
84644ab693, b8f40dde1e, a6c9789a54, a48d9ce27c, fb825a2708, 5558dce449,
cf74a11e65, 86b5deec81

.air.toml (new file; 8 lines)
@@ -0,0 +1,8 @@
+# .air.toml
+[build]
+cmd = "make build"
+bin = "./local-ai"
+args_bin = [ "--debug" ]
+include_ext = ["go", "html", "yaml", "toml", "json", "txt", "md"]
+exclude_dir = ["pkg/grpc/proto"]
+delay = 1000
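
This config is consumed by the new `make build-dev` target added to the Makefile later in this diff: air rebuilds via `make build` and restarts `./local-ai --debug` whenever a watched file changes, skipping the generated `pkg/grpc/proto` directory.
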
.github/gallery-agent/agent.go (vendored; 4 changed lines)
@@ -7,8 +7,8 @@ import (
 	"slices"
 	"strings"

-	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
-	"github.com/mudler/cogito"
+	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+	cogito "github.com/mudler/cogito"

 	"github.com/mudler/cogito/structures"
 	"github.com/sashabaranov/go-openai/jsonschema"
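
The import swap reflects the `hfapi` helper being promoted out of the workflow-local `.github/gallery-agent/hfapi` package into the main module as `pkg/huggingface-api`; that is also why the agent's own `go.mod` and `go.sum` are deleted below.
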
.github/gallery-agent/go.mod (vendored; deleted, 39 lines)
@@ -1,39 +0,0 @@
-module github.com/go-skynet/LocalAI/.github/gallery-agent
-
-go 1.24.1
-
-require (
-	github.com/mudler/cogito v0.3.0
-	github.com/onsi/ginkgo/v2 v2.25.3
-	github.com/onsi/gomega v1.38.2
-	github.com/sashabaranov/go-openai v1.41.2
-	github.com/tmc/langchaingo v0.1.13
-	gopkg.in/yaml.v3 v3.0.1
-)
-
-require (
-	dario.cat/mergo v1.0.1 // indirect
-	github.com/Masterminds/goutils v1.1.1 // indirect
-	github.com/Masterminds/semver/v3 v3.4.0 // indirect
-	github.com/Masterminds/sprig/v3 v3.3.0 // indirect
-	github.com/go-logr/logr v1.4.3 // indirect
-	github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
-	github.com/google/go-cmp v0.7.0 // indirect
-	github.com/google/jsonschema-go v0.3.0 // indirect
-	github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
-	github.com/google/uuid v1.6.0 // indirect
-	github.com/huandu/xstrings v1.5.0 // indirect
-	github.com/mitchellh/copystructure v1.2.0 // indirect
-	github.com/mitchellh/reflectwalk v1.0.2 // indirect
-	github.com/modelcontextprotocol/go-sdk v1.0.0 // indirect
-	github.com/shopspring/decimal v1.4.0 // indirect
-	github.com/spf13/cast v1.7.0 // indirect
-	github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
-	go.uber.org/automaxprocs v1.6.0 // indirect
-	go.yaml.in/yaml/v3 v3.0.4 // indirect
-	golang.org/x/crypto v0.41.0 // indirect
-	golang.org/x/net v0.43.0 // indirect
-	golang.org/x/sys v0.35.0 // indirect
-	golang.org/x/text v0.28.0 // indirect
-	golang.org/x/tools v0.36.0 // indirect
-)

.github/gallery-agent/go.sum (vendored; deleted, 168 lines)
@@ -1,168 +0,0 @@
(168 go.sum checksum lines removed together with the module above, covering the dependency graph of the deleted go.mod — dario.cat/mergo through gopkg.in/yaml.v3 — plus transitive test dependencies such as docker, testcontainers-go, and opentelemetry.)

.github/gallery-agent/main.go (vendored; 2 changed lines)
@@ -9,7 +9,7 @@ import (
 	"strings"
 	"time"

-	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
+	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
 )

 // ProcessedModelFile represents a processed model file with additional metadata

.github/gallery-agent/tools.go (vendored; 8 changed lines)
@@ -3,9 +3,9 @@ package main
 import (
 	"fmt"

-	"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
-	"github.com/sashabaranov/go-openai"
-	"github.com/tmc/langchaingo/jsonschema"
+	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
+	openai "github.com/sashabaranov/go-openai"
+	jsonschema "github.com/sashabaranov/go-openai/jsonschema"
 )

 // Get repository README from HF
@@ -13,7 +13,7 @@ type HFReadmeTool struct {
 	client *hfapi.Client
 }

-func (s *HFReadmeTool) Run(args map[string]any) (string, error) {
+func (s *HFReadmeTool) Execute(args map[string]any) (string, error) {
 	q, ok := args["repository"].(string)
 	if !ok {
 		return "", fmt.Errorf("no query")
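
The `Run` → `Execute` rename suggests the tool now satisfies an `Execute`-based tool contract, presumably cogito's given the import changes above. A minimal sketch of such a contract — everything here except the `Execute` signature is an assumption for illustration:

```go
package main

import "fmt"

// Tool is a hypothetical stand-in for the interface the agent framework
// dispatches on; only the Execute signature is taken from the diff.
type Tool interface {
	Execute(args map[string]any) (string, error)
}

type echoTool struct{}

func (echoTool) Execute(args map[string]any) (string, error) {
	repo, ok := args["repository"].(string)
	if !ok {
		return "", fmt.Errorf("no query")
	}
	return "README for " + repo, nil
}

func main() {
	var t Tool = echoTool{}
	out, _ := t.Execute(map[string]any{"repository": "mudler/LocalAI"})
	fmt.Println(out)
}
```
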
.github/workflows/bump_deps.yaml (vendored; 4 changed lines)
@@ -1,10 +1,10 @@
-name: Bump dependencies
+name: Bump Backend dependencies
 on:
   schedule:
     - cron: 0 20 * * *
   workflow_dispatch:
 jobs:
-  bump:
+  bump-backends:
     strategy:
       fail-fast: false
       matrix:

.github/workflows/bump_docs.yaml (vendored; 4 changed lines)
@@ -1,10 +1,10 @@
-name: Bump dependencies
+name: Bump Documentation
 on:
   schedule:
     - cron: 0 20 * * *
   workflow_dispatch:
 jobs:
-  bump:
+  bump-docs:
     strategy:
       fail-fast: false
       matrix:

.github/workflows/deploy-explorer.yaml (vendored; 4 changed lines)
@@ -33,7 +33,7 @@ jobs:
        run: |
          CGO_ENABLED=0 make build
      - name: rm
-       uses: appleboy/ssh-action@v1.2.2
+       uses: appleboy/ssh-action@v1.2.3
        with:
          host: ${{ secrets.EXPLORER_SSH_HOST }}
          username: ${{ secrets.EXPLORER_SSH_USERNAME }}
@@ -53,7 +53,7 @@ jobs:
          rm: true
          target: ./local-ai
      - name: restarting
-       uses: appleboy/ssh-action@v1.2.2
+       uses: appleboy/ssh-action@v1.2.3
        with:
          host: ${{ secrets.EXPLORER_SSH_HOST }}
          username: ${{ secrets.EXPLORER_SSH_USERNAME }}

.github/workflows/gallery-agent.yaml (vendored; 11 changed lines)
@@ -2,7 +2,7 @@ name: Gallery Agent
 on:

   schedule:
-    - cron: '0 */1 * * *' # Run every 4 hours
+    - cron: '0 */3 * * *' # Run every 4 hours
   workflow_dispatch:
     inputs:
       search_term:
@@ -39,11 +39,6 @@ jobs:
        with:
          go-version: '1.21'

-      - name: Build gallery agent
-        run: |
-          cd .github/gallery-agent
-          go mod download
-          go build -o gallery-agent .

      - name: Run gallery agent
        env:
@@ -56,9 +51,7 @@ jobs:
          MAX_MODELS: ${{ github.event.inputs.max_models || '1' }}
        run: |
          export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml
-          cd .github/gallery-agent
-          ./gallery-agent
-          rm -rf gallery-agent
+          go run .github/gallery-agent

      - name: Check for changes
        id: check_changes
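
Net effect of the workflow change: the agent is throttled from hourly to every three hours (note the inherited "Run every 4 hours" comment now matches neither schedule), and the separate download/build/cleanup steps collapse into a single `go run .github/gallery-agent`, which is possible now that the agent no longer carries its own go.mod.
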
CONTRIBUTING.md
@@ -30,6 +30,7 @@ Thank you for your interest in contributing to LocalAI! We appreciate your time
 3. Install the required dependencies ( see https://localai.io/basics/build/#build-localai-locally )
 4. Build LocalAI: `make build`
 5. Run LocalAI: `./local-ai`
+6. To Build and live reload: `make build-dev`

 ## Contributing

@@ -76,7 +77,7 @@ LOCALAI_IMAGE_TAG=test LOCALAI_IMAGE=local-ai-aio make run-e2e-aio
 ## Documentation

 We are welcome the contribution of the documents, please open new PR or create a new issue. The documentation is available under `docs/` https://github.com/mudler/LocalAI/tree/master/docs

 ## Community and Communication

 - You can reach out via the Github issue tracker.

Makefile (4 changed lines)
@@ -103,6 +103,10 @@ build-launcher: ## Build the launcher application

 build-all: build build-launcher ## Build both server and launcher

+build-dev: ## Run LocalAI in dev mode with live reload
+	@command -v air >/dev/null 2>&1 || go install github.com/air-verse/air@latest
+	air -c .air.toml
+
 dev-dist:
 	$(GORELEASER) build --snapshot --clean
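
`make build-dev` pairs with the new `.air.toml` at the top of this diff: it installs air on demand, then rebuilds and restarts the server on file changes, which is what the new step 6 in CONTRIBUTING.md refers to.
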
README.md
@@ -43,7 +43,7 @@

 > :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
 >
-> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
+> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
 [](https://t.me/localaiofficial_bot)

 [](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[](https://artifacthub.io/packages/search?repo=localai)
@@ -116,6 +116,8 @@ For more installation options, see [Installer Options](https://localai.io/docs/a…)
   <img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
 </a>

+> Note: the DMGs are not signed by Apple as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
+
 Or run with docker:

 > **💡 Docker Run vs Docker Start**
@@ -200,7 +202,7 @@ local-ai run oci://localai/phi-2:latest

 > ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).

-For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html)
+For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html), if you are interested in our roadmap items and future enhancements, you can see the [Issues labeled as Roadmap here](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)

 ## 📰 Latest project news

backend.proto (gRPC backend API)
@@ -154,6 +154,10 @@ message PredictOptions {
   repeated string Videos = 45;
   repeated string Audios = 46;
   string CorrelationId = 47;
+  string Tools = 48; // JSON array of available tools/functions for tool calling
+  string ToolChoice = 49; // JSON string or object specifying tool choice behavior
+  int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
+  int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
 }

 // The response message containing the result
@@ -164,6 +168,7 @@ message Reply {
   double timing_prompt_processing = 4;
   double timing_token_generation = 5;
   bytes audio = 6;
+  bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
 }

 message GrammarTrigger {
@@ -382,6 +387,11 @@ message StatusResponse {
 message Message {
   string role = 1;
   string content = 2;
+  // Optional fields for OpenAI-compatible message format
+  string name = 3; // Tool name (for tool messages)
+  string tool_call_id = 4; // Tool call ID (for tool messages)
+  string reasoning_content = 5; // Reasoning content (for thinking models)
+  string tool_calls = 6; // Tool calls as JSON string (for assistant messages with tool calls)
 }

 message DetectOptions {
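
For orientation, a sketch of how these fields travel: the caller packs tool definitions and the tool-choice policy as JSON strings and asks for logprob counts; the backend answers with OpenAI-shaped JSON in `Reply.logprobs`. The struct types below are stand-ins for the generated proto code, and the logprobs payload shape is an assumption based on the field comments:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Stand-ins for the generated proto types; field names match the .proto diff.
type PredictOptions struct {
	Tools       string // JSON array of available tools/functions
	ToolChoice  string // JSON string or object specifying tool choice
	Logprobs    int32
	TopLogprobs int32
}

type Reply struct {
	Logprobs []byte // JSON-encoded logprobs data matching OpenAI format
}

func main() {
	opts := PredictOptions{
		Tools:       `[{"type":"function","function":{"name":"get_weather"}}]`,
		ToolChoice:  `"auto"`,
		Logprobs:    1,
		TopLogprobs: 5,
	}
	fmt.Printf("%+v\n", opts)

	// The payload shape here is an assumption based on "matching OpenAI format".
	reply := Reply{Logprobs: []byte(`{"content":[{"token":"Hi","logprob":-0.12}]}`)}
	var parsed map[string]any
	if err := json.Unmarshal(reply.Logprobs, &parsed); err == nil {
		fmt.Println(parsed["content"])
	}
}
```
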
llama.cpp backend Makefile (pinned commit bump)
@@ -1,5 +1,5 @@

-LLAMA_VERSION?=5a4ff43e7dd049e35942bc3d12361dab2f155544
+LLAMA_VERSION?=80deff3648b93727422461c41c7279ef1dac7452
 LLAMA_REPO?=https://github.com/ggerganov/llama.cpp

 CMAKE_ARGS?=

(One file's diff was suppressed because it is too large.)

whisper.cpp backend Makefile (pinned commit bump)
@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)

 # whisper.cpp version
 WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
-WHISPER_CPP_VERSION?=f16c12f3f55f5bd3d6ac8cf2f31ab90a42c884d5
+WHISPER_CPP_VERSION?=d9b7613b34a343848af572cc14467fc5e82fc788
 SO_TARGET?=libgowhisper.so

 CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

chatterbox backend requirements files (default, CUDA 11.8, CPU, ROCm 6.1, Intel XPU, and minimal variants; each gains a numpy pin)
@@ -2,6 +2,7 @@
 accelerate
 torch
 torchaudio
+numpy>=1.24.0,<1.26.0
 transformers
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
@@ -2,6 +2,7 @@
 torch==2.6.0+cu118
 torchaudio==2.6.0+cu118
 transformers==4.46.3
+numpy>=1.24.0,<1.26.0
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate
@@ -1,6 +1,7 @@
 torch
 torchaudio
 transformers
+numpy>=1.24.0,<1.26.0
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate
@@ -2,6 +2,7 @@
 torch==2.6.0+rocm6.1
 torchaudio==2.6.0+rocm6.1
 transformers
+numpy>=1.24.0,<1.26.0
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate
@@ -3,6 +3,7 @@ intel-extension-for-pytorch==2.3.110+xpu
 torch==2.3.1+cxx11.abi
 torchaudio==2.3.1+cxx11.abi
 transformers
+numpy>=1.24.0,<1.26.0
 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate
@@ -2,5 +2,6 @@
 torch
 torchaudio
 transformers
+numpy>=1.24.0,<1.26.0
 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
 accelerate

reranker backend (backend.py)
@@ -61,7 +61,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             if request.PipelineType != "":  # Reuse the PipelineType field for language
                 kwargs['lang'] = request.PipelineType
             self.model_name = model_name
-            self.model = Reranker(model_name, **kwargs)
+            self.model = Reranker(model_name, **kwargs)
         except Exception as err:
             return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
@@ -75,12 +75,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             documents.append(doc)
         ranked_results = self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
         # Prepare results to return
+        cropped_results = ranked_results.top_k(request.top_n) if request.top_n > 0 else ranked_results
         results = [
             backend_pb2.DocumentResult(
                 index=res.doc_id,
                 text=res.text,
                 relevance_score=res.score
-            ) for res in ranked_results.results
+            ) for res in (cropped_results)
         ]

         # Calculate the usage and total tokens
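
The added `cropped_results` line makes `top_n` behave like the usual rerank parameter: a positive value crops the ranked list with `top_k` before the gRPC response is built, while `top_n=0` (the proto3 default when the field is omitted) returns every document, still ranked.
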
reranker backend tests (test.py)
@@ -76,7 +76,7 @@ class TestBackendServicer(unittest.TestCase):
                 )
                 response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
                 self.assertTrue(response.success)

                 rerank_response = stub.Rerank(request)
                 print(rerank_response.results[0])
                 self.assertIsNotNone(rerank_response.results)
@@ -87,4 +87,60 @@ class TestBackendServicer(unittest.TestCase):
             print(err)
             self.fail("Reranker service failed")
         finally:
-            self.tearDown()
+            self.tearDown()
+
+    def test_rerank_omit_top_n(self):
+        """
+        This method tests if the embeddings are generated successfully even top_n is omitted
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                request = backend_pb2.RerankRequest(
+                    query="I love you",
+                    documents=["I hate you", "I really like you"],
+                    top_n=0  #
+                )
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+                self.assertTrue(response.success)
+
+                rerank_response = stub.Rerank(request)
+                print(rerank_response.results[0])
+                self.assertIsNotNone(rerank_response.results)
+                self.assertEqual(len(rerank_response.results), 2)
+                self.assertEqual(rerank_response.results[0].text, "I really like you")
+                self.assertEqual(rerank_response.results[1].text, "I hate you")
+        except Exception as err:
+            print(err)
+            self.fail("Reranker service failed")
+        finally:
+            self.tearDown()
+
+    def test_rerank_crop(self):
+        """
+        This method tests top_n cropping
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                request = backend_pb2.RerankRequest(
+                    query="I love you",
+                    documents=["I hate you", "I really like you", "I hate ignoring top_n"],
+                    top_n=2
+                )
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
+                self.assertTrue(response.success)
+
+                rerank_response = stub.Rerank(request)
+                print(rerank_response.results[0])
+                self.assertIsNotNone(rerank_response.results)
+                self.assertEqual(len(rerank_response.results), 2)
+                self.assertEqual(rerank_response.results[0].text, "I really like you")
+                self.assertEqual(rerank_response.results[1].text, "I hate you")
+        except Exception as err:
+            print(err)
+            self.fail("Reranker service failed")
+        finally:
+            self.tearDown()
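
The same behaviour seen from a Go client would look roughly like this; the generated client constructor, the `pb` import path, and the address are assumptions mirroring the Python test:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	// pb is assumed to be the generated LocalAI backend proto package.
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

func main() {
	conn, err := grpc.NewClient("localhost:50051",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	stub := pb.NewBackendClient(conn)
	// TopN=2 crops the ranked list to the two best documents;
	// TopN=0 (or omitted) returns every document, ranked.
	resp, err := stub.Rerank(context.Background(), &pb.RerankRequest{
		Query:     "I love you",
		Documents: []string{"I hate you", "I really like you", "I hate ignoring top_n"},
		TopN:      2,
	})
	if err != nil {
		log.Fatal(err)
	}
	for _, r := range resp.Results {
		fmt.Printf("%d %q %f\n", r.Index, r.Text, r.RelevanceScore)
	}
}
```
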
core/application (New)
@@ -22,9 +22,15 @@ func New(opts ...config.AppOption) (*Application, error) {

 	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.SystemState.Model.ModelsPath)
 	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())

+	if err := application.start(); err != nil {
+		return nil, err
+	}
+
 	caps, err := xsysinfo.CPUCapabilities()
 	if err == nil {
 		log.Debug().Msgf("CPU capabilities: %v", caps)
 	}
 	gpus, err := xsysinfo.GPUs()
 	if err == nil {
@@ -56,12 +62,12 @@ func New(opts ...config.AppOption) (*Application, error) {
 		}
 	}

-	if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
+	if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
 		log.Error().Err(err).Msg("error installing models")
 	}

 	for _, backend := range options.ExternalBackends {
-		if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
+		if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
 			log.Error().Err(err).Msg("error installing external backend")
 		}
 	}
@@ -152,10 +158,6 @@ func New(opts ...config.AppOption) (*Application, error) {
 	// Watch the configuration directory
 	startWatcher(options)

-	if err := application.start(); err != nil {
-		return nil, err
-	}
-
 	log.Info().Msg("core/startup process completed!")
 	return application, nil
 }
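
The `application.start()` call moves from the end of `New` to right after the version banner, presumably so the services it brings up (notably the gallery service now passed to `InstallModels`) are running before model and backend installation begins; the install helpers also gain a `context.Context` as their first argument.
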
core/backend (ModelInference)
@@ -3,7 +3,6 @@ package backend
 import (
 	"context"
 	"encoding/json"
 	"fmt"
 	"regexp"
 	"slices"
 	"strings"
@@ -26,6 +25,7 @@ type LLMResponse struct {
 	Response string // should this be []byte?
 	Usage TokenUsage
 	AudioOutput string
+	Logprobs *schema.Logprobs // Logprobs from the backend response
 }

 type TokenUsage struct {
@@ -35,7 +35,7 @@ type TokenUsage struct {
 	TimingTokenGeneration float64
 }

-func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
+func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
 	modelFile := c.Model

 	// Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -47,7 +47,7 @@ func ModelInference(…)
 	if !slices.Contains(modelNames, c.Name) {
 		utils.ResetDownloadTimers()
 		// if we failed to load the model, we try to download it
-		err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
+		err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
 		if err != nil {
 			log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
 			//return nil, err
@@ -65,29 +65,8 @@ func ModelInference(…)
 	var protoMessages []*proto.Message
 	// if we are using the tokenizer template, we need to convert the messages to proto messages
 	// unless the prompt has already been tokenized (non-chat endpoints + functions)
-	if c.TemplateConfig.UseTokenizerTemplate && s == "" {
-		protoMessages = make([]*proto.Message, len(messages), len(messages))
-		for i, message := range messages {
-			protoMessages[i] = &proto.Message{
-				Role: message.Role,
-			}
-			switch ct := message.Content.(type) {
-			case string:
-				protoMessages[i].Content = ct
-			case []interface{}:
-				// If using the tokenizer template, in case of multimodal we want to keep the multimodal content as and return only strings here
-				data, _ := json.Marshal(ct)
-				resultData := []struct {
-					Text string `json:"text"`
-				}{}
-				json.Unmarshal(data, &resultData)
-				for _, r := range resultData {
-					protoMessages[i].Content += r.Text
-				}
-			default:
-				return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
-			}
-		}
+	if c.TemplateConfig.UseTokenizerTemplate && len(messages) > 0 {
+		protoMessages = messages.ToProto()
 	}

 	// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
@@ -99,6 +78,21 @@ func ModelInference(…)
 		opts.Images = images
 		opts.Videos = videos
 		opts.Audios = audios
+		opts.Tools = tools
+		opts.ToolChoice = toolChoice
+		if logprobs != nil {
+			opts.Logprobs = int32(*logprobs)
+		}
+		if topLogprobs != nil {
+			opts.TopLogprobs = int32(*topLogprobs)
+		}
+		if len(logitBias) > 0 {
+			// Serialize logit_bias map to JSON string for proto
+			logitBiasJSON, err := json.Marshal(logitBias)
+			if err == nil {
+				opts.LogitBias = string(logitBiasJSON)
+			}
+		}

 		tokenUsage := TokenUsage{}
@@ -130,6 +124,7 @@ func ModelInference(…)
 			}

 			ss := ""
+			var logprobs *schema.Logprobs

 			var partialRune []byte
 			err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
@@ -141,6 +136,14 @@ func ModelInference(…)
 				tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
 				tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing

+				// Parse logprobs from reply if present (collect from last chunk that has them)
+				if len(reply.Logprobs) > 0 {
+					var parsedLogprobs schema.Logprobs
+					if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
+						logprobs = &parsedLogprobs
+					}
+				}
+
 				// Process complete runes and accumulate them
 				var completeRunes []byte
 				for len(partialRune) > 0 {
@@ -166,6 +169,7 @@ func ModelInference(…)
 			return LLMResponse{
 				Response: ss,
 				Usage:    tokenUsage,
+				Logprobs: logprobs,
 			}, err
 		} else {
 			// TODO: Is the chicken bit the only way to get here? is that acceptable?
@@ -188,9 +192,19 @@ func ModelInference(…)
 				response = c.TemplateConfig.ReplyPrefix + response
 			}

+			// Parse logprobs from reply if present
+			var logprobs *schema.Logprobs
+			if len(reply.Logprobs) > 0 {
+				var parsedLogprobs schema.Logprobs
+				if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
+					logprobs = &parsedLogprobs
+				}
+			}
+
 			return LLMResponse{
 				Response: response,
 				Usage:    tokenUsage,
+				Logprobs: logprobs,
 			}, err
 		}
 	}
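
Two themes here: `ModelInference` now threads tool-calling, logprobs, and logit-bias options through to the backend, and the inline message-conversion loop is replaced by a `schema.Messages.ToProto()` helper. A hypothetical sketch of that helper, mirroring the deleted loop (the real one lives in core/schema; the types here are simplified stand-ins):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type Message struct {
	Role    string
	Content any // either a string or a multimodal []interface{} payload
}

type Messages []Message

type ProtoMessage struct {
	Role    string
	Content string
}

// ToProto mirrors the conversion the diff removed from ModelInference:
// strings pass through, multimodal content keeps only its text parts.
func (ms Messages) ToProto() []*ProtoMessage {
	out := make([]*ProtoMessage, len(ms))
	for i, m := range ms {
		out[i] = &ProtoMessage{Role: m.Role}
		switch ct := m.Content.(type) {
		case string:
			out[i].Content = ct
		case []interface{}:
			data, _ := json.Marshal(ct)
			var parts []struct {
				Text string `json:"text"`
			}
			json.Unmarshal(data, &parts)
			for _, p := range parts {
				out[i].Content += p.Text
			}
		}
	}
	return out
}

func main() {
	msgs := Messages{{Role: "user", Content: "hi"}}
	fmt.Println(msgs.ToProto()[0].Content) // "hi"
}
```
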
core/backend (gRPCPredictOpts)
@@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
 		}
 	}

-	return &pb.PredictOptions{
+	pbOpts := &pb.PredictOptions{
 		Temperature: float32(*c.Temperature),
 		TopP:        float32(*c.TopP),
 		NDraft:      c.NDraft,
@@ -249,4 +249,6 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
 		TailFreeSamplingZ: float32(*c.TFZ),
 		TypicalP:          float32(*c.TypicalP),
 	}
+	// Logprobs and TopLogprobs are set by the caller if provided
+	return pbOpts
 }

core/cli (backends install)
@@ -1,6 +1,7 @@
 package cli

 import (
+	"context"
 	"encoding/json"
 	"fmt"
@@ -102,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
 	}

 	modelLoader := model.NewModelLoader(systemState, true)
-	err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
+	err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
 	if err != nil {
 		return err
 	}

core/cli (explorer)
@@ -48,10 +48,12 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
 	appHTTP := http.Explorer(db)

 	signals.RegisterGracefulTerminationHandler(func() {
-		if err := appHTTP.Shutdown(); err != nil {
+		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+		defer cancel()
+		if err := appHTTP.Shutdown(ctx); err != nil {
 			log.Error().Err(err).Msg("error during shutdown")
 		}
 	})

-	return appHTTP.Listen(e.Address)
+	return appHTTP.Start(e.Address)
 }

core/cli (models install)
@@ -1,12 +1,14 @@
 package cli

 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"

 	cliContext "github.com/mudler/LocalAI/core/cli/context"
 	"github.com/mudler/LocalAI/core/config"
+	"github.com/mudler/LocalAI/core/services"

 	"github.com/mudler/LocalAI/core/gallery"
 	"github.com/mudler/LocalAI/core/startup"
@@ -78,6 +80,12 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
 		return err
 	}

+	galleryService := services.NewGalleryService(&config.ApplicationConfig{}, model.NewModelLoader(systemState, true))
+	err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)
+	if err != nil {
+		return err
+	}
+
 	var galleries []config.Gallery
 	if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil {
 		log.Error().Err(err).Msg("unable to load galleries")
@@ -127,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
 	}

 	modelLoader := model.NewModelLoader(systemState, true)
-	err = startup.InstallModels(galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
+	err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
 	if err != nil {
 		return err
 	}

core/cli (run)
@@ -232,5 +232,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
 		}
 	})

-	return appHTTP.Listen(r.Address)
+	return appHTTP.Start(r.Address)
 }

core/cli/worker
@@ -1,6 +1,7 @@
 package worker

 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -42,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (string, error) {
 		log.Error().Err(err).Msg("failed loading galleries")
 		return "", err
 	}
-	err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
+	err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
 	if err != nil {
 		log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
 		return "", err
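
Across the CLI changes above the pattern is the same: the gallery installation APIs now take a `context.Context` (a `context.Background()` at CLI entry points), and the HTTP app switches from `Listen`/`Shutdown()` to `Start`/`Shutdown(ctx)` with a 10-second timeout, consistent with a web-framework upgrade where shutdown is context-aware.
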
@@ -1,151 +1,17 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/pkg/xsysinfo"
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
gguf "github.com/gpustack/gguf-parser-go"
|
||||
)
|
||||
|
||||
type familyType uint8
|
||||
|
||||
const (
|
||||
Unknown familyType = iota
|
||||
LLaMa3
|
||||
CommandR
|
||||
Phi3
|
||||
ChatML
|
||||
Mistral03
|
||||
Gemma
|
||||
DeepSeek2
|
||||
)
|
||||
|
||||
const (
|
||||
defaultContextSize = 1024
|
||||
defaultNGPULayers = 99999999
|
||||
)
|
||||
|
||||
type settingsConfig struct {
|
||||
StopWords []string
|
||||
TemplateConfig TemplateConfig
|
||||
RepeatPenalty float64
|
||||
}
|
||||
|
||||
// default settings to adopt with a given model family
|
||||
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
|
||||
Gemma: {
|
||||
RepeatPenalty: 1.0,
|
||||
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "{{.Input }}\n<start_of_turn>model\n",
|
||||
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
|
||||
Completion: "{{.Input}}",
|
||||
},
|
||||
},
|
||||
DeepSeek2: {
|
||||
StopWords: []string{"<|end▁of▁sentence|>"},
|
||||
TemplateConfig: TemplateConfig{
|
||||
ChatMessage: `{{if eq .RoleName "user" -}}User: {{.Content }}
|
||||
{{ end -}}
|
||||
{{if eq .RoleName "assistant" -}}Assistant: {{.Content}}<|end▁of▁sentence|>{{end}}
|
||||
{{if eq .RoleName "system" -}}{{.Content}}
|
||||
{{end -}}`,
|
||||
Chat: "{{.Input -}}\nAssistant: ",
|
||||
},
|
||||
},
|
||||
LLaMa3: {
|
||||
StopWords: []string{"<|eot_id|>"},
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "<|begin_of_text|>{{.Input }}\n<|start_header_id|>assistant<|end_header_id|>",
|
||||
ChatMessage: "<|start_header_id|>{{ .RoleName }}<|end_header_id|>\n\n{{.Content }}<|eot_id|>",
|
||||
},
|
||||
},
|
||||
CommandR: {
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "{{.Input -}}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
|
||||
Functions: `<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
|
||||
You are a function calling AI model, you can call the following functions:
|
||||
## Available Tools
|
||||
{{range .Functions}}
|
||||
- {"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }}
|
||||
{{end}}
|
||||
When using a tool, reply with JSON, for instance {"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}}
|
||||
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Input -}}`,
|
||||
ChatMessage: `{{if eq .RoleName "user" -}}
|
||||
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
|
||||
{{- else if eq .RoleName "system" -}}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
|
||||
{{- else if eq .RoleName "assistant" -}}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
|
||||
{{- else if eq .RoleName "tool" -}}
|
||||
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
|
||||
{{- else if .FunctionCall -}}
|
||||
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{toJson .FunctionCall}}}<|END_OF_TURN_TOKEN|>
|
||||
{{- end -}}`,
|
||||
},
|
||||
StopWords: []string{"<|END_OF_TURN_TOKEN|>"},
|
||||
},
|
||||
Phi3: {
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "{{.Input}}\n<|assistant|>",
|
||||
ChatMessage: "<|{{ .RoleName }}|>\n{{.Content}}<|end|>",
|
||||
Completion: "{{.Input}}",
|
||||
},
|
||||
StopWords: []string{"<|end|>", "<|endoftext|>"},
|
||||
},
|
||||
ChatML: {
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "{{.Input -}}\n<|im_start|>assistant",
|
||||
Functions: `<|im_start|>system
|
||||
You are a function calling AI model. You are provided with functions to execute. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:
|
||||
{{range .Functions}}
|
||||
{'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }}
|
||||
{{end}}
|
||||
For each function call return a json object with function name and arguments
|
||||
<|im_end|>
|
||||
{{.Input -}}
|
||||
<|im_start|>assistant`,
|
||||
ChatMessage: `<|im_start|>{{ .RoleName }}
|
||||
{{ if .FunctionCall -}}
|
||||
Function call:
|
||||
{{ else if eq .RoleName "tool" -}}
|
||||
Function response:
|
||||
{{ end -}}
|
||||
{{ if .Content -}}
|
||||
{{.Content }}
|
||||
{{ end -}}
|
||||
{{ if .FunctionCall -}}
|
||||
{{toJson .FunctionCall}}
|
||||
{{ end -}}<|im_end|>`,
|
||||
},
|
||||
StopWords: []string{"<|im_end|>", "<dummy32000>", "</s>"},
|
||||
},
|
||||
Mistral03: {
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "{{.Input -}}",
|
||||
Functions: `[AVAILABLE_TOOLS] [{{range .Functions}}{"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }}{{end}} ] [/AVAILABLE_TOOLS]{{.Input }}`,
|
||||
ChatMessage: `{{if eq .RoleName "user" -}}
|
||||
[INST] {{.Content }} [/INST]
|
||||
{{- else if .FunctionCall -}}
|
||||
[TOOL_CALLS] {{toJson .FunctionCall}} [/TOOL_CALLS]
|
||||
{{- else if eq .RoleName "tool" -}}
|
||||
[TOOL_RESULTS] {{.Content}} [/TOOL_RESULTS]
|
||||
{{- else -}}
|
||||
{{ .Content -}}
|
||||
{{ end -}}`,
|
||||
},
|
||||
StopWords: []string{"<|im_end|>", "<dummy32000>", "</tool_call>", "<|eot_id|>", "<|end_of_text|>", "</s>", "[/TOOL_CALLS]", "[/ACTIONS]"},
|
||||
},
|
||||
}
|
||||
|
||||
// this maps well known template used in HF to model families defined above
|
||||
var knownTemplates = map[string]familyType{
|
||||
`{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\n' + content + '<|im_end|>\n<|im_start|>assistant\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\n' }}{% endif %}{% endfor %}`: ChatML,
|
||||
`{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}`: Mistral03,
|
||||
}
|
||||
|
||||
func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {

	if defaultCtx == 0 && cfg.ContextSize == nil {
@@ -216,81 +82,9 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
		cfg.Name = f.Metadata().Name
	}

	family := identifyFamily(f)

	if family == Unknown {
		log.Debug().Msgf("guessDefaultsFromFile: %s", "family not identified")
		return
	}

	// identify template
	settings, ok := defaultsSettings[family]
	if ok {
		cfg.TemplateConfig = settings.TemplateConfig
		log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: guessed template %+v", cfg.TemplateConfig)
		if len(cfg.StopWords) == 0 {
			cfg.StopWords = settings.StopWords
		}
		if cfg.RepeatPenalty == 0.0 {
			cfg.RepeatPenalty = settings.RepeatPenalty
		}
	} else {
		log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
	}

	if cfg.HasTemplate() {
		return
	}

	// identify from well known templates first, otherwise use the raw jinja template
	chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
	if found {
		// try to use the jinja template
		cfg.TemplateConfig.JinjaTemplate = true
		cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
	}

}

func identifyFamily(f *gguf.GGUFFile) familyType {

	// identify from well known templates first
	chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
	if found && chatTemplate.ValueString() != "" {
		if family, ok := knownTemplates[chatTemplate.ValueString()]; ok {
			return family
		}
	}

	// otherwise try to identify from the model properties
	arch := f.Architecture().Architecture
	eosTokenID := f.Tokenizer().EOSTokenID
	bosTokenID := f.Tokenizer().BOSTokenID

	isYI := arch == "llama" && bosTokenID == 1 && eosTokenID == 2
	// WTF! Mistral0.3 and isYi have same bosTokenID and eosTokenID

	llama3 := arch == "llama" && eosTokenID == 128009
	commandR := arch == "command-r" && eosTokenID == 255001
	qwen2 := arch == "qwen2"
	phi3 := arch == "phi-3"
	gemma := strings.HasPrefix(arch, "gemma") || strings.Contains(strings.ToLower(f.Metadata().Name), "gemma")
	deepseek2 := arch == "deepseek2"

	switch {
	case deepseek2:
		return DeepSeek2
	case gemma:
		return Gemma
	case llama3:
		return LLaMa3
	case commandR:
		return CommandR
	case phi3:
		return Phi3
	case qwen2, isYI:
		return ChatML
	default:
		return Unknown
	}
	// Instruct to use template from llama.cpp
	cfg.TemplateConfig.UseTokenizerTemplate = true
	cfg.FunctionsConfig.GrammarConfig.NoGrammar = true
	cfg.Options = append(cfg.Options, "use_jinja:true")
	cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT")
}

@@ -9,6 +9,7 @@ import (
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/downloader"
	"github.com/mudler/LocalAI/pkg/functions"
+	"github.com/mudler/cogito"
	"gopkg.in/yaml.v3"
)

@@ -16,30 +17,31 @@ const (
	RAND_SEED = -1
)

// @Description TTS configuration
type TTSConfig struct {

	// Voice wav path or id
-	Voice string `yaml:"voice" json:"voice"`
+	Voice string `yaml:"voice,omitempty" json:"voice,omitempty"`

-	AudioPath string `yaml:"audio_path" json:"audio_path"`
+	AudioPath string `yaml:"audio_path,omitempty" json:"audio_path,omitempty"`
}
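
The `,omitempty` tags added here (and repeated throughout the structs below) change serialization: zero-valued fields are dropped from marshalled YAML/JSON instead of being written out as empty keys, so exported model configs stay minimal. A self-contained sketch of the effect (an illustrative standalone program, not part of this diff):

package main

import (
	"encoding/json"
	"fmt"
)

// Same field, with and without omitempty.
type before struct {
	Voice string `json:"voice"`
}

type after struct {
	Voice string `json:"voice,omitempty"`
}

func main() {
	a, _ := json.Marshal(before{}) // zero value still emitted
	b, _ := json.Marshal(after{})  // zero value omitted
	fmt.Println(string(a))         // {"voice":""}
	fmt.Println(string(b))         // {}
}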
// ModelConfig represents a model configuration
// @Description ModelConfig represents a model configuration
type ModelConfig struct {
	modelConfigFile string `yaml:"-" json:"-"`
-	schema.PredictionOptions `yaml:"parameters" json:"parameters"`
-	Name string `yaml:"name" json:"name"`
+	schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"`
+	Name string `yaml:"name,omitempty" json:"name,omitempty"`

-	F16 *bool `yaml:"f16" json:"f16"`
-	Threads *int `yaml:"threads" json:"threads"`
-	Debug *bool `yaml:"debug" json:"debug"`
-	Roles map[string]string `yaml:"roles" json:"roles"`
-	Embeddings *bool `yaml:"embeddings" json:"embeddings"`
-	Backend string `yaml:"backend" json:"backend"`
-	TemplateConfig TemplateConfig `yaml:"template" json:"template"`
-	KnownUsecaseStrings []string `yaml:"known_usecases" json:"known_usecases"`
+	F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"`
+	Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"`
+	Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"`
+	Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"`
+	Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"`
+	Backend string `yaml:"backend,omitempty" json:"backend,omitempty"`
+	TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"`
+	KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"`
	KnownUsecases *ModelConfigUsecases `yaml:"-" json:"-"`
-	Pipeline Pipeline `yaml:"pipeline" json:"pipeline"`
+	Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`

	PromptStrings, InputStrings []string `yaml:"-" json:"-"`
	InputToken [][]int `yaml:"-" json:"-"`
@@ -47,96 +49,101 @@ type ModelConfig struct {
	ResponseFormat string `yaml:"-" json:"-"`
	ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`

-	FunctionsConfig functions.FunctionsConfig `yaml:"function" json:"function"`
+	FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`

-	FeatureFlag FeatureFlag `yaml:"feature_flags" json:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
+	FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
	// LLM configs (GPT4ALL, Llama.cpp, ...)
	LLMConfig `yaml:",inline" json:",inline"`

	// Diffusers
-	Diffusers Diffusers `yaml:"diffusers" json:"diffusers"`
-	Step int `yaml:"step" json:"step"`
+	Diffusers Diffusers `yaml:"diffusers,omitempty" json:"diffusers,omitempty"`
+	Step int `yaml:"step,omitempty" json:"step,omitempty"`

	// GRPC Options
-	GRPC GRPC `yaml:"grpc" json:"grpc"`
+	GRPC GRPC `yaml:"grpc,omitempty" json:"grpc,omitempty"`

	// TTS specifics
-	TTSConfig `yaml:"tts" json:"tts"`
+	TTSConfig `yaml:"tts,omitempty" json:"tts,omitempty"`

	// CUDA
	// Explicitly enable CUDA or not (some backends might need it)
-	CUDA bool `yaml:"cuda" json:"cuda"`
+	CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`

-	DownloadFiles []File `yaml:"download_files" json:"download_files"`
+	DownloadFiles []File `yaml:"download_files,omitempty" json:"download_files,omitempty"`

-	Description string `yaml:"description" json:"description"`
-	Usage string `yaml:"usage" json:"usage"`
+	Description string `yaml:"description,omitempty" json:"description,omitempty"`
+	Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`

-	Options []string `yaml:"options" json:"options"`
-	Overrides []string `yaml:"overrides" json:"overrides"`
+	Options []string `yaml:"options,omitempty" json:"options,omitempty"`
+	Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`

-	MCP MCPConfig `yaml:"mcp" json:"mcp"`
-	Agent AgentConfig `yaml:"agent" json:"agent"`
+	MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
+	Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
}

// @Description MCP configuration
type MCPConfig struct {
-	Servers string `yaml:"remote" json:"remote"`
-	Stdio string `yaml:"stdio" json:"stdio"`
+	Servers string `yaml:"remote,omitempty" json:"remote,omitempty"`
+	Stdio string `yaml:"stdio,omitempty" json:"stdio,omitempty"`
}

// @Description Agent configuration
type AgentConfig struct {
-	MaxAttempts int `yaml:"max_attempts" json:"max_attempts"`
-	MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
-	EnableReasoning bool `yaml:"enable_reasoning" json:"enable_reasoning"`
-	EnablePlanning bool `yaml:"enable_planning" json:"enable_planning"`
-	EnableMCPPrompts bool `yaml:"enable_mcp_prompts" json:"enable_mcp_prompts"`
-	EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator" json:"enable_plan_re_evaluator"`
+	MaxAttempts int `yaml:"max_attempts,omitempty" json:"max_attempts,omitempty"`
+	MaxIterations int `yaml:"max_iterations,omitempty" json:"max_iterations,omitempty"`
+	EnableReasoning bool `yaml:"enable_reasoning,omitempty" json:"enable_reasoning,omitempty"`
+	EnablePlanning bool `yaml:"enable_planning,omitempty" json:"enable_planning,omitempty"`
+	EnableMCPPrompts bool `yaml:"enable_mcp_prompts,omitempty" json:"enable_mcp_prompts,omitempty"`
+	EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator,omitempty" json:"enable_plan_re_evaluator,omitempty"`
}

-func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers]) {
+func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
	var remote MCPGenericConfig[MCPRemoteServers]
	var stdio MCPGenericConfig[MCPSTDIOServers]

	if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
-		return remote, stdio
+		return remote, stdio, err
	}

	if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
-		return remote, stdio
+		return remote, stdio, err
	}

-	return remote, stdio
+	return remote, stdio, nil
}
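
MCPConfigFromYAML now surfaces unmarshal failures instead of silently returning partially-filled configs. A minimal caller sketch (the surrounding function and error wrapping are hypothetical; only the new signature comes from this diff):

// remote.Servers and stdio.Servers are only trustworthy when err is nil.
remote, stdio, err := cfg.MCP.MCPConfigFromYAML()
if err != nil {
	return fmt.Errorf("invalid MCP server configuration: %w", err)
}
_ = remote
_ = stdio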
// @Description MCP generic configuration
type MCPGenericConfig[T any] struct {
-	Servers T `yaml:"mcpServers" json:"mcpServers"`
+	Servers T `yaml:"mcpServers,omitempty" json:"mcpServers,omitempty"`
}
type MCPRemoteServers map[string]MCPRemoteServer
type MCPSTDIOServers map[string]MCPSTDIOServer

// @Description MCP remote server configuration
type MCPRemoteServer struct {
-	URL string `json:"url"`
-	Token string `json:"token"`
+	URL string `json:"url,omitempty"`
+	Token string `json:"token,omitempty"`
}

// @Description MCP STDIO server configuration
type MCPSTDIOServer struct {
-	Args []string `json:"args"`
-	Env map[string]string `json:"env"`
-	Command string `json:"command"`
+	Args []string `json:"args,omitempty"`
+	Env map[string]string `json:"env,omitempty"`
+	Command string `json:"command,omitempty"`
}

// Pipeline defines other models to use for audio-to-audio
// @Description Pipeline defines other models to use for audio-to-audio
type Pipeline struct {
-	TTS string `yaml:"tts" json:"tts"`
-	LLM string `yaml:"llm" json:"llm"`
-	Transcription string `yaml:"transcription" json:"transcription"`
-	VAD string `yaml:"vad" json:"vad"`
+	TTS string `yaml:"tts,omitempty" json:"tts,omitempty"`
+	LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
+	Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
+	VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
}

// @Description File configuration for model downloads
type File struct {
-	Filename string `yaml:"filename" json:"filename"`
-	SHA256 string `yaml:"sha256" json:"sha256"`
-	URI downloader.URI `yaml:"uri" json:"uri"`
+	Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`
+	SHA256 string `yaml:"sha256,omitempty" json:"sha256,omitempty"`
+	URI downloader.URI `yaml:"uri,omitempty" json:"uri,omitempty"`
}

type FeatureFlag map[string]*bool
@@ -148,126 +155,136 @@ func (ff FeatureFlag) Enabled(s string) bool {
	return false
}

// @Description GRPC configuration
type GRPC struct {
-	Attempts int `yaml:"attempts" json:"attempts"`
-	AttemptsSleepTime int `yaml:"attempts_sleep_time" json:"attempts_sleep_time"`
+	Attempts int `yaml:"attempts,omitempty" json:"attempts,omitempty"`
+	AttemptsSleepTime int `yaml:"attempts_sleep_time,omitempty" json:"attempts_sleep_time,omitempty"`
}

// @Description Diffusers configuration
type Diffusers struct {
-	CUDA bool `yaml:"cuda" json:"cuda"`
-	PipelineType string `yaml:"pipeline_type" json:"pipeline_type"`
-	SchedulerType string `yaml:"scheduler_type" json:"scheduler_type"`
-	EnableParameters string `yaml:"enable_parameters" json:"enable_parameters"` // A list of comma separated parameters to specify
-	IMG2IMG bool `yaml:"img2img" json:"img2img"` // Image to Image Diffuser
-	ClipSkip int `yaml:"clip_skip" json:"clip_skip"` // Skip every N frames
-	ClipModel string `yaml:"clip_model" json:"clip_model"` // Clip model to use
-	ClipSubFolder string `yaml:"clip_subfolder" json:"clip_subfolder"` // Subfolder to use for clip model
-	ControlNet string `yaml:"control_net" json:"control_net"`
+	CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`
+	PipelineType string `yaml:"pipeline_type,omitempty" json:"pipeline_type,omitempty"`
+	SchedulerType string `yaml:"scheduler_type,omitempty" json:"scheduler_type,omitempty"`
+	EnableParameters string `yaml:"enable_parameters,omitempty" json:"enable_parameters,omitempty"` // A list of comma separated parameters to specify
+	IMG2IMG bool `yaml:"img2img,omitempty" json:"img2img,omitempty"` // Image to Image Diffuser
+	ClipSkip int `yaml:"clip_skip,omitempty" json:"clip_skip,omitempty"` // Skip every N frames
+	ClipModel string `yaml:"clip_model,omitempty" json:"clip_model,omitempty"` // Clip model to use
+	ClipSubFolder string `yaml:"clip_subfolder,omitempty" json:"clip_subfolder,omitempty"` // Subfolder to use for clip model
+	ControlNet string `yaml:"control_net,omitempty" json:"control_net,omitempty"`
}

// LLMConfig is a struct that holds the configuration that are
// generic for most of the LLM backends.
// @Description LLMConfig is a struct that holds the configuration that are generic for most of the LLM backends.
type LLMConfig struct {
-	SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
-	TensorSplit string `yaml:"tensor_split" json:"tensor_split"`
-	MainGPU string `yaml:"main_gpu" json:"main_gpu"`
-	RMSNormEps float32 `yaml:"rms_norm_eps" json:"rms_norm_eps"`
-	NGQA int32 `yaml:"ngqa" json:"ngqa"`
-	PromptCachePath string `yaml:"prompt_cache_path" json:"prompt_cache_path"`
-	PromptCacheAll bool `yaml:"prompt_cache_all" json:"prompt_cache_all"`
-	PromptCacheRO bool `yaml:"prompt_cache_ro" json:"prompt_cache_ro"`
-	MirostatETA *float64 `yaml:"mirostat_eta" json:"mirostat_eta"`
-	MirostatTAU *float64 `yaml:"mirostat_tau" json:"mirostat_tau"`
-	Mirostat *int `yaml:"mirostat" json:"mirostat"`
-	NGPULayers *int `yaml:"gpu_layers" json:"gpu_layers"`
-	MMap *bool `yaml:"mmap" json:"mmap"`
-	MMlock *bool `yaml:"mmlock" json:"mmlock"`
-	LowVRAM *bool `yaml:"low_vram" json:"low_vram"`
-	Reranking *bool `yaml:"reranking" json:"reranking"`
-	Grammar string `yaml:"grammar" json:"grammar"`
-	StopWords []string `yaml:"stopwords" json:"stopwords"`
-	Cutstrings []string `yaml:"cutstrings" json:"cutstrings"`
-	ExtractRegex []string `yaml:"extract_regex" json:"extract_regex"`
-	TrimSpace []string `yaml:"trimspace" json:"trimspace"`
-	TrimSuffix []string `yaml:"trimsuffix" json:"trimsuffix"`
+	SystemPrompt string `yaml:"system_prompt,omitempty" json:"system_prompt,omitempty"`
+	TensorSplit string `yaml:"tensor_split,omitempty" json:"tensor_split,omitempty"`
+	MainGPU string `yaml:"main_gpu,omitempty" json:"main_gpu,omitempty"`
+	RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
+	NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
+	PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
+	PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
+	PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
+	MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
+	MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
+	Mirostat *int `yaml:"mirostat,omitempty" json:"mirostat,omitempty"`
+	NGPULayers *int `yaml:"gpu_layers,omitempty" json:"gpu_layers,omitempty"`
+	MMap *bool `yaml:"mmap,omitempty" json:"mmap,omitempty"`
+	MMlock *bool `yaml:"mmlock,omitempty" json:"mmlock,omitempty"`
+	LowVRAM *bool `yaml:"low_vram,omitempty" json:"low_vram,omitempty"`
+	Reranking *bool `yaml:"reranking,omitempty" json:"reranking,omitempty"`
+	Grammar string `yaml:"grammar,omitempty" json:"grammar,omitempty"`
+	StopWords []string `yaml:"stopwords,omitempty" json:"stopwords,omitempty"`
+	Cutstrings []string `yaml:"cutstrings,omitempty" json:"cutstrings,omitempty"`
+	ExtractRegex []string `yaml:"extract_regex,omitempty" json:"extract_regex,omitempty"`
+	TrimSpace []string `yaml:"trimspace,omitempty" json:"trimspace,omitempty"`
+	TrimSuffix []string `yaml:"trimsuffix,omitempty" json:"trimsuffix,omitempty"`

-	ContextSize *int `yaml:"context_size" json:"context_size"`
-	NUMA bool `yaml:"numa" json:"numa"`
-	LoraAdapter string `yaml:"lora_adapter" json:"lora_adapter"`
-	LoraBase string `yaml:"lora_base" json:"lora_base"`
-	LoraAdapters []string `yaml:"lora_adapters" json:"lora_adapters"`
-	LoraScales []float32 `yaml:"lora_scales" json:"lora_scales"`
-	LoraScale float32 `yaml:"lora_scale" json:"lora_scale"`
-	NoMulMatQ bool `yaml:"no_mulmatq" json:"no_mulmatq"`
-	DraftModel string `yaml:"draft_model" json:"draft_model"`
-	NDraft int32 `yaml:"n_draft" json:"n_draft"`
-	Quantization string `yaml:"quantization" json:"quantization"`
-	LoadFormat string `yaml:"load_format" json:"load_format"`
-	GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization" json:"gpu_memory_utilization"` // vLLM
-	TrustRemoteCode bool `yaml:"trust_remote_code" json:"trust_remote_code"` // vLLM
-	EnforceEager bool `yaml:"enforce_eager" json:"enforce_eager"` // vLLM
-	SwapSpace int `yaml:"swap_space" json:"swap_space"` // vLLM
-	MaxModelLen int `yaml:"max_model_len" json:"max_model_len"` // vLLM
-	TensorParallelSize int `yaml:"tensor_parallel_size" json:"tensor_parallel_size"` // vLLM
-	DisableLogStatus bool `yaml:"disable_log_stats" json:"disable_log_stats"` // vLLM
-	DType string `yaml:"dtype" json:"dtype"` // vLLM
-	LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM
-	MMProj string `yaml:"mmproj" json:"mmproj"`
+	ContextSize *int `yaml:"context_size,omitempty" json:"context_size,omitempty"`
+	NUMA bool `yaml:"numa,omitempty" json:"numa,omitempty"`
+	LoraAdapter string `yaml:"lora_adapter,omitempty" json:"lora_adapter,omitempty"`
+	LoraBase string `yaml:"lora_base,omitempty" json:"lora_base,omitempty"`
+	LoraAdapters []string `yaml:"lora_adapters,omitempty" json:"lora_adapters,omitempty"`
+	LoraScales []float32 `yaml:"lora_scales,omitempty" json:"lora_scales,omitempty"`
+	LoraScale float32 `yaml:"lora_scale,omitempty" json:"lora_scale,omitempty"`
+	NoMulMatQ bool `yaml:"no_mulmatq,omitempty" json:"no_mulmatq,omitempty"`
+	DraftModel string `yaml:"draft_model,omitempty" json:"draft_model,omitempty"`
+	NDraft int32 `yaml:"n_draft,omitempty" json:"n_draft,omitempty"`
+	Quantization string `yaml:"quantization,omitempty" json:"quantization,omitempty"`
+	LoadFormat string `yaml:"load_format,omitempty" json:"load_format,omitempty"`
+	GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization,omitempty" json:"gpu_memory_utilization,omitempty"` // vLLM
+	TrustRemoteCode bool `yaml:"trust_remote_code,omitempty" json:"trust_remote_code,omitempty"` // vLLM
+	EnforceEager bool `yaml:"enforce_eager,omitempty" json:"enforce_eager,omitempty"` // vLLM
+	SwapSpace int `yaml:"swap_space,omitempty" json:"swap_space,omitempty"` // vLLM
+	MaxModelLen int `yaml:"max_model_len,omitempty" json:"max_model_len,omitempty"` // vLLM
+	TensorParallelSize int `yaml:"tensor_parallel_size,omitempty" json:"tensor_parallel_size,omitempty"` // vLLM
+	DisableLogStatus bool `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM
+	DType string `yaml:"dtype,omitempty" json:"dtype,omitempty"` // vLLM
+	LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM
+	MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`

-	FlashAttention *string `yaml:"flash_attention" json:"flash_attention"`
-	NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
-	CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
-	CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
+	FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"`
+	NoKVOffloading bool `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"`
+	CacheTypeK string `yaml:"cache_type_k,omitempty" json:"cache_type_k,omitempty"`
+	CacheTypeV string `yaml:"cache_type_v,omitempty" json:"cache_type_v,omitempty"`

-	RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"`
-	ModelType string `yaml:"type" json:"type"`
+	RopeScaling string `yaml:"rope_scaling,omitempty" json:"rope_scaling,omitempty"`
+	ModelType string `yaml:"type,omitempty" json:"type,omitempty"`

-	YarnExtFactor float32 `yaml:"yarn_ext_factor" json:"yarn_ext_factor"`
-	YarnAttnFactor float32 `yaml:"yarn_attn_factor" json:"yarn_attn_factor"`
-	YarnBetaFast float32 `yaml:"yarn_beta_fast" json:"yarn_beta_fast"`
-	YarnBetaSlow float32 `yaml:"yarn_beta_slow" json:"yarn_beta_slow"`
+	YarnExtFactor float32 `yaml:"yarn_ext_factor,omitempty" json:"yarn_ext_factor,omitempty"`
+	YarnAttnFactor float32 `yaml:"yarn_attn_factor,omitempty" json:"yarn_attn_factor,omitempty"`
+	YarnBetaFast float32 `yaml:"yarn_beta_fast,omitempty" json:"yarn_beta_fast,omitempty"`
+	YarnBetaSlow float32 `yaml:"yarn_beta_slow,omitempty" json:"yarn_beta_slow,omitempty"`

-	CFGScale float32 `yaml:"cfg_scale" json:"cfg_scale"` // Classifier-Free Guidance Scale
+	CFGScale float32 `yaml:"cfg_scale,omitempty" json:"cfg_scale,omitempty"` // Classifier-Free Guidance Scale
}

// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
// @Description LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
type LimitMMPerPrompt struct {
-	LimitImagePerPrompt int `yaml:"image" json:"image"`
-	LimitVideoPerPrompt int `yaml:"video" json:"video"`
-	LimitAudioPerPrompt int `yaml:"audio" json:"audio"`
+	LimitImagePerPrompt int `yaml:"image,omitempty" json:"image,omitempty"`
+	LimitVideoPerPrompt int `yaml:"video,omitempty" json:"video,omitempty"`
+	LimitAudioPerPrompt int `yaml:"audio,omitempty" json:"audio,omitempty"`
}

// TemplateConfig is a struct that holds the configuration of the templating system
// @Description TemplateConfig is a struct that holds the configuration of the templating system
type TemplateConfig struct {
	// Chat is the template used in the chat completion endpoint
-	Chat string `yaml:"chat" json:"chat"`
+	Chat string `yaml:"chat,omitempty" json:"chat,omitempty"`

	// ChatMessage is the template used for chat messages
-	ChatMessage string `yaml:"chat_message" json:"chat_message"`
+	ChatMessage string `yaml:"chat_message,omitempty" json:"chat_message,omitempty"`

	// Completion is the template used for completion requests
-	Completion string `yaml:"completion" json:"completion"`
+	Completion string `yaml:"completion,omitempty" json:"completion,omitempty"`

	// Edit is the template used for edit completion requests
-	Edit string `yaml:"edit" json:"edit"`
+	Edit string `yaml:"edit,omitempty" json:"edit,omitempty"`

	// Functions is the template used when tools are present in the client requests
-	Functions string `yaml:"function" json:"function"`
+	Functions string `yaml:"function,omitempty" json:"function,omitempty"`

	// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
	// Note: this is mostly consumed for backends such as vllm and transformers
	// that can use the tokenizers specified in the JSON config files of the models
-	UseTokenizerTemplate bool `yaml:"use_tokenizer_template" json:"use_tokenizer_template"`
+	UseTokenizerTemplate bool `yaml:"use_tokenizer_template,omitempty" json:"use_tokenizer_template,omitempty"`

	// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
	// It defaults to \n
-	JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character" json:"join_chat_messages_by_character"`
+	JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character,omitempty" json:"join_chat_messages_by_character,omitempty"`

-	Multimodal string `yaml:"multimodal" json:"multimodal"`
+	Multimodal string `yaml:"multimodal,omitempty" json:"multimodal,omitempty"`

	JinjaTemplate bool `yaml:"jinja_template" json:"jinja_template"`
-	ReplyPrefix string `yaml:"reply_prefix" json:"reply_prefix"`
+	ReplyPrefix string `yaml:"reply_prefix,omitempty" json:"reply_prefix,omitempty"`
}

func (c *ModelConfig) syncKnownUsecasesFromString() {
	c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
	// Make sure the usecases are valid, we rewrite with what we identified
	c.KnownUsecaseStrings = []string{}
	for k, usecase := range GetAllModelConfigUsecases() {
		if c.HasUsecases(usecase) {
			c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
		}
	}
}

func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
@@ -278,14 +295,7 @@ func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
	}
	*c = ModelConfig(aux)

-	c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
-	// Make sure the usecases are valid, we rewrite with what we identified
-	c.KnownUsecaseStrings = []string{}
-	for k, usecase := range GetAllModelConfigUsecases() {
-		if c.HasUsecases(usecase) {
-			c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
-		}
-	}
+	c.syncKnownUsecasesFromString()
	return nil
}

@@ -462,6 +472,7 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
	}

	guessDefaultsFromFile(cfg, lo.modelPath, ctx)
+	cfg.syncKnownUsecasesFromString()
}

func (c *ModelConfig) Validate() bool {
@@ -492,7 +503,7 @@ func (c *ModelConfig) Validate() bool {
}

func (c *ModelConfig) HasTemplate() bool {
-	return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
+	return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate
}
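
With UseTokenizerTemplate folded into HasTemplate, a config that only opts into the tokenizer's built-in template now counts as templated. A quick sketch of the changed behavior (fragment, values illustrative):

cfg := ModelConfig{
	TemplateConfig: TemplateConfig{UseTokenizerTemplate: true},
}
// Before this change: false (no chat/completion/edit templates set).
// After this change: true, so tokenizer-templated models pass template checks.
fmt.Println(cfg.HasTemplate())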
func (c *ModelConfig) GetModelConfigFile() string {
@@ -573,7 +584,7 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecases) bool {
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
	if (u & FLAG_CHAT) == FLAG_CHAT {
-		if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
+		if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
			return false
		}
	}
@@ -658,3 +669,40 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {

	return true
}

// BuildCogitoOptions generates cogito options from the model configuration.
func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
	cogitoOpts := []cogito.Option{
		cogito.WithIterations(3),  // default to 3 iterations
		cogito.WithMaxAttempts(3), // default to 3 attempts
		cogito.WithForceReasoning(),
	}

	// Apply agent configuration options
	if c.Agent.EnableReasoning {
		cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
	}

	if c.Agent.EnablePlanning {
		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
	}

	if c.Agent.EnableMCPPrompts {
		cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
	}

	if c.Agent.EnablePlanReEvaluator {
		cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
	}

	if c.Agent.MaxIterations != 0 {
		cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
	}

	if c.Agent.MaxAttempts != 0 {
		cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
	}

	return cogitoOpts
}
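
A usage sketch for BuildCogitoOptions (the surrounding agent setup is hypothetical; only the function and option names come from this diff):

// Hypothetical caller: pass the derived options to cogito when building an
// agent session (the exact cogito entrypoint is not shown in this diff).
opts := modelConfig.BuildCogitoOptions()
// With Agent.MaxIterations = 10, opts contains both cogito.WithIterations(3)
// and cogito.WithIterations(10); the later, user-supplied option is expected
// to take precedence when cogito applies the options in order.
_ = opts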
@@ -3,7 +3,9 @@
package gallery

import (
+	"context"
	"encoding/json"
+	"errors"
	"fmt"
	"os"
	"path/filepath"
@@ -68,7 +70,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
}

// InstallBackendFromGallery installs a backend from the gallery.
-func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
+func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
	if !force {
		// check if we already have the backend installed
		backends, err := ListSystemBackends(systemState)
@@ -108,7 +110,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
	log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")

	// Then, let's install the best backend
-	if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
+	if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
		return err
	}

@@ -133,10 +135,10 @@
		return nil
	}

-	return InstallBackend(systemState, modelLoader, backend, downloadStatus)
+	return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
}

-func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
+func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
	// Create base path if it doesn't exist
	err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
	if err != nil {
@@ -163,11 +165,17 @@
	}
	} else {
		uri := downloader.URI(config.URI)
-		if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
+		if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
			success := false
			// Try to download from mirrors
			for _, mirror := range config.Mirrors {
-				if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
+				// Check for cancellation before trying next mirror
+				select {
+				case <-ctx.Done():
+					return ctx.Err()
+				default:
+				}
+				if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
					success = true
					break
				}
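
Threading ctx through the install path makes long backend downloads cancellable, including the mirror fallback loop above. A caller sketch (timeout, variable names, and imports such as time are assumed to be in scope):

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()

// Cancelling ctx aborts the in-flight download; the select above also stops
// the loop before each mirror retry.
err := InstallBackendFromGallery(ctx, galleries, systemState, ml, "llama-cpp", downloadStatus, false)
_ = err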
@@ -310,8 +318,10 @@ func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error)
			}
		}
	}
-	} else {
+	} else if !errors.Is(err, os.ErrNotExist) {
		log.Warn().Err(err).Msg("Failed to read system backends, proceeding with user-managed backends")
+	} else if errors.Is(err, os.ErrNotExist) {
+		log.Debug().Msg("No system backends found")
	}

	// User-managed backends and alias collection

@@ -1,6 +1,7 @@
package gallery

import (
+	"context"
	"encoding/json"
	"os"
	"path/filepath"
@@ -55,7 +56,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
	)
	must(err)
	sysDefault.GPUVendor = "" // force default selection
	backs, err := ListSystemBackends(sysDefault)
	must(err)
	aliasBack, ok := backs.Get("llama-cpp")
	Expect(ok).To(BeTrue())
@@ -77,7 +78,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
	must(err)
	sysNvidia.GPUVendor = "nvidia"
	sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
	backs, err = ListSystemBackends(sysNvidia)
	must(err)
	aliasBack, ok = backs.Get("llama-cpp")
	Expect(ok).To(BeTrue())

@@ -116,13 +117,13 @@ var _ = Describe("Gallery Backends", func() {

	Describe("InstallBackendFromGallery", func() {
		It("should return error when backend is not found", func() {
-			err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
+			err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
		})

		It("should install backend from gallery", func() {
-			err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
+			err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
			Expect(err).ToNot(HaveOccurred())
			Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
		})
@@ -298,7 +299,7 @@
			VRAM: 1000000000000,
			Backend: system.Backend{BackendsPath: tempDir},
		}
-		err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
+		err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
		Expect(err).NotTo(HaveOccurred())

		metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -378,7 +379,7 @@
			VRAM: 1000000000000,
			Backend: system.Backend{BackendsPath: tempDir},
		}
-		err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
+		err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
		Expect(err).NotTo(HaveOccurred())

		metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -462,7 +463,7 @@
			VRAM: 1000000000000,
			Backend: system.Backend{BackendsPath: tempDir},
		}
-		err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
+		err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
		Expect(err).NotTo(HaveOccurred())

		metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -561,7 +562,7 @@
			system.WithBackendPath(newPath),
		)
		Expect(err).NotTo(HaveOccurred())
-		err = InstallBackend(systemState, ml, &backend, nil)
+		err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
		Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
		Expect(newPath).To(BeADirectory())
	})
@@ -593,7 +594,7 @@
			system.WithBackendPath(tempDir),
		)
		Expect(err).NotTo(HaveOccurred())
-		err = InstallBackend(systemState, ml, &backend, nil)
+		err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
		dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
@@ -626,7 +627,7 @@

		Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())

-		err = InstallBackend(systemState, ml, &backend, nil)
+		err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
	})
@@ -647,7 +648,7 @@
			system.WithBackendPath(tempDir),
		)
		Expect(err).NotTo(HaveOccurred())
-		err = InstallBackend(systemState, ml, &backend, nil)
+		err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
		Expect(err).ToNot(HaveOccurred())
		Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())

@@ -1,6 +1,7 @@
package gallery

import (
+	"context"
	"fmt"
	"os"
	"path/filepath"
@@ -28,6 +29,19 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
	return config, nil
}

+func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
+	var config T
+	uri := downloader.URI(url)
+	err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
+		return yaml.Unmarshal(d, &config)
+	})
+	if err != nil {
+		log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
+		return config, err
+	}
+	return config, nil
+}
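
A usage sketch for the new context-aware variant (the type parameter, URL, and basePath are illustrative):

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// GalleryConfig is a hypothetical YAML-decodable type; the type parameter
// must be stated explicitly since it cannot be inferred from the arguments.
cfg, err := GetGalleryConfigFromURLWithContext[GalleryConfig](ctx, "https://example.com/gallery.yaml", basePath)
_ = cfg
_ = err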
func ReadConfigFile[T any](filePath string) (*T, error) {
	// Read the YAML file
	yamlFile, err := os.ReadFile(filePath)
@@ -61,12 +75,15 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] {
	term = strings.ToLower(term)
	for _, m := range gm {
		if fuzzy.Match(term, strings.ToLower(m.GetName())) ||
			fuzzy.Match(term, strings.ToLower(m.GetDescription())) ||
			fuzzy.Match(term, strings.ToLower(m.GetGallery().Name)) ||
+			strings.Contains(strings.ToLower(m.GetName()), term) ||
+			strings.Contains(strings.ToLower(m.GetDescription()), term) ||
+			strings.Contains(strings.ToLower(m.GetGallery().Name), term) ||
			strings.Contains(strings.ToLower(strings.Join(m.GetTags(), ",")), term) {
			filteredModels = append(filteredModels, m)
		}
	}

	return filteredModels
}
core/gallery/importers/importers.go (new file, 67 lines)
@@ -0,0 +1,67 @@
package importers

import (
	"encoding/json"
	"strings"

	"github.com/rs/zerolog/log"

	"github.com/mudler/LocalAI/core/gallery"
	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)

var defaultImporters = []Importer{
	&LlamaCPPImporter{},
	&MLXImporter{},
	&VLLMImporter{},
	&TransformersImporter{},
}

type Details struct {
	HuggingFace *hfapi.ModelDetails
	URI         string
	Preferences json.RawMessage
}

type Importer interface {
	Match(details Details) bool
	Import(details Details) (gallery.ModelConfig, error)
}

func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) {
	var err error
	var modelConfig gallery.ModelConfig

	hf := hfapi.NewClient()

	hfrepoID := strings.ReplaceAll(uri, "huggingface://", "")
	hfrepoID = strings.ReplaceAll(hfrepoID, "hf://", "")
	hfrepoID = strings.ReplaceAll(hfrepoID, "https://huggingface.co/", "")

	hfDetails, err := hf.GetModelDetails(hfrepoID)
	if err != nil {
		// maybe not a HF repository
		// TODO: maybe we can check if the URI is a valid HF repository
		log.Debug().Str("uri", uri).Msg("Failed to get model details, maybe not a HF repository")
	} else {
		log.Debug().Str("uri", uri).Msg("Got model details")
		log.Debug().Any("details", hfDetails).Msg("Model details")
	}

	details := Details{
		HuggingFace: hfDetails,
		URI:         uri,
		Preferences: preferences,
	}

	for _, importer := range defaultImporters {
		if importer.Match(details) {
			modelConfig, err = importer.Import(details)
			if err != nil {
				continue
			}
			break
		}
	}
	return modelConfig, err
}
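
A usage sketch of the discovery flow, mirroring the tests below (URI and preferences taken from them):

prefs := json.RawMessage(`{"quantizations": "Q8_0"}`)
mc, err := importers.DiscoverModelConfig("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF", prefs)
if err == nil {
	fmt.Println(mc.Name) // "Qwen3-VL-2B-Instruct-GGUF"
}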
core/gallery/importers/importers_suite_test.go (new file, 13 lines)
@@ -0,0 +1,13 @@
package importers_test

import (
	"testing"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

func TestImporters(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Importers test suite")
}

core/gallery/importers/importers_test.go (new file, 215 lines)
@@ -0,0 +1,215 @@
package importers_test

import (
	"encoding/json"
	"fmt"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("DiscoverModelConfig", func() {

	Context("With only a repository URI", func() {
		It("should discover and import using LlamaCPPImporter", func() {
			uri := "https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
			Expect(modelConfig.Name).To(Equal("LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/resolve/main/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].SHA256).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"), fmt.Sprintf("Model config: %+v", modelConfig))
		})

		It("should discover and import using LlamaCPPImporter", func() {
			uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
			Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].Filename).To(Equal("mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
		})

		It("should discover and import using LlamaCPPImporter", func() {
			uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
			preferences := json.RawMessage(`{ "quantizations": "Q8_0", "mmproj_quantizations": "f16" }`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
			Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].Filename).To(Equal("mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
		})
	})

	Context("with .gguf URI", func() {
		It("should discover and import using LlamaCPPImporter", func() {
			uri := "https://example.com/my-model.gguf"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("my-model.gguf"))
			Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
		})

		It("should use custom preferences when provided", func() {
			uri := "https://example.com/my-model.gguf"
			preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-name"))
			Expect(modelConfig.Description).To(Equal("Custom description"))
		})
	})

	Context("with mlx-community URI", func() {
		It("should discover and import using MLXImporter", func() {
			uri := "https://huggingface.co/mlx-community/test-model"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
		})

		It("should use custom preferences when provided", func() {
			uri := "https://huggingface.co/mlx-community/test-model"
			preferences := json.RawMessage(`{"name": "custom-mlx", "description": "Custom MLX description"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-mlx"))
			Expect(modelConfig.Description).To(Equal("Custom MLX description"))
		})
	})

	Context("with backend preference", func() {
		It("should use llama-cpp backend when specified", func() {
			uri := "https://example.com/model"
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
		})

		It("should use mlx backend when specified", func() {
			uri := "https://example.com/model"
			preferences := json.RawMessage(`{"backend": "mlx"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
		})

		It("should use mlx-vlm backend when specified", func() {
			uri := "https://example.com/model"
			preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
		})
	})

	Context("with HuggingFace URI formats", func() {
		It("should handle huggingface:// prefix", func() {
			uri := "huggingface://mlx-community/test-model"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
		})

		It("should handle hf:// prefix", func() {
			uri := "hf://mlx-community/test-model"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
		})

		It("should handle https://huggingface.co/ prefix", func() {
			uri := "https://huggingface.co/mlx-community/test-model"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
		})
	})

	Context("with invalid or non-matching URI", func() {
		It("should return error when no importer matches", func() {
			uri := "https://example.com/unknown-model.bin"
			preferences := json.RawMessage(`{}`)

			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			// When no importer matches, the function returns empty config and error
			// The exact behavior depends on implementation, but typically an error is returned
			Expect(modelConfig.Name).To(BeEmpty())
			Expect(err).To(HaveOccurred())
		})
	})

	Context("with invalid JSON preferences", func() {
		It("should return error when JSON is invalid even if URI matches", func() {
			uri := "https://example.com/model.gguf"
			preferences := json.RawMessage(`invalid json`)

			// Even though Match() returns true for .gguf extension,
			// Import() will fail when trying to unmarshal invalid JSON preferences
			modelConfig, err := importers.DiscoverModelConfig(uri, preferences)

			Expect(err).To(HaveOccurred())
			Expect(modelConfig.Name).To(BeEmpty())
		})
	})
})

core/gallery/importers/llama-cpp.go (new file, 209 lines)
@@ -0,0 +1,209 @@
package importers

import (
	"encoding/json"
	"path/filepath"
	"slices"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/functions"
	"go.yaml.in/yaml/v2"
)

var _ Importer = &LlamaCPPImporter{}

type LlamaCPPImporter struct{}

func (i *LlamaCPPImporter) Match(details Details) bool {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return false
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return false
	}

	if preferencesMap["backend"] == "llama-cpp" {
		return true
	}

	if strings.HasSuffix(details.URI, ".gguf") {
		return true
	}

	if details.HuggingFace != nil {
		for _, file := range details.HuggingFace.Files {
			if strings.HasSuffix(file.Path, ".gguf") {
				return true
			}
		}
	}

	return false
}

func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return gallery.ModelConfig{}, err
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	name, ok := preferencesMap["name"].(string)
	if !ok {
		name = filepath.Base(details.URI)
	}

	description, ok := preferencesMap["description"].(string)
	if !ok {
		description = "Imported from " + details.URI
	}

	preferredQuantizations, _ := preferencesMap["quantizations"].(string)
	quants := []string{"q4_k_m"}
	if preferredQuantizations != "" {
		quants = strings.Split(preferredQuantizations, ",")
	}

	mmprojQuants, _ := preferencesMap["mmproj_quantizations"].(string)
	mmprojQuantsList := []string{"fp16"}
	if mmprojQuants != "" {
		mmprojQuantsList = strings.Split(mmprojQuants, ",")
	}

	embeddings, _ := preferencesMap["embeddings"].(string)

	modelConfig := config.ModelConfig{
		Name:                name,
		Description:         description,
		KnownUsecaseStrings: []string{"chat"},
		Options:             []string{"use_jinja:true"},
		Backend:             "llama-cpp",
		TemplateConfig: config.TemplateConfig{
			UseTokenizerTemplate: true,
		},
		FunctionsConfig: functions.FunctionsConfig{
			GrammarConfig: functions.GrammarConfig{
				NoGrammar: true,
			},
		},
	}

	if embeddings != "" && (strings.ToLower(embeddings) == "true" || strings.ToLower(embeddings) == "yes") {
		trueV := true
		modelConfig.Embeddings = &trueV
	}

	cfg := gallery.ModelConfig{
		Name:        name,
		Description: description,
	}

	if strings.HasSuffix(details.URI, ".gguf") {
		cfg.Files = append(cfg.Files, gallery.File{
			URI:      details.URI,
			Filename: filepath.Base(details.URI),
		})
		modelConfig.PredictionOptions = schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: filepath.Base(details.URI),
			},
		}
	} else if details.HuggingFace != nil {
		// We want to pick the files matching the chosen quantizations,
		// or fall back to the last mmproj/gguf file found.
		var lastMMProjFile *gallery.File
		var lastGGUFFile *gallery.File
		foundPreferredQuant := false
		foundPreferredMMProjQuant := false

		for _, file := range details.HuggingFace.Files {
			// Collect the preferred mmproj quantizations
			if strings.Contains(strings.ToLower(file.Path), "mmproj") {
				lastMMProjFile = &gallery.File{
					URI:      file.URL,
					Filename: filepath.Join("mmproj", filepath.Base(file.Path)),
					SHA256:   file.SHA256,
				}
				if slices.ContainsFunc(mmprojQuantsList, func(quant string) bool {
					return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
				}) {
					cfg.Files = append(cfg.Files, *lastMMProjFile)
					foundPreferredMMProjQuant = true
				}
			} else if strings.HasSuffix(strings.ToLower(file.Path), "gguf") {
				lastGGUFFile = &gallery.File{
					URI:      file.URL,
					Filename: filepath.Base(file.Path),
					SHA256:   file.SHA256,
				}
				// Collect the files of the preferred quantizations
				if slices.ContainsFunc(quants, func(quant string) bool {
					return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
				}) {
					foundPreferredQuant = true
					cfg.Files = append(cfg.Files, *lastGGUFFile)
				}
			}
		}

		// Make sure to add at least one file if not already present (the latest one seen)
		if lastMMProjFile != nil && !foundPreferredMMProjQuant {
			if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
				return f.Filename == lastMMProjFile.Filename
			}) {
				cfg.Files = append(cfg.Files, *lastMMProjFile)
			}
		}

		if lastGGUFFile != nil && !foundPreferredQuant {
			if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
				return f.Filename == lastGGUFFile.Filename
			}) {
				cfg.Files = append(cfg.Files, *lastGGUFFile)
			}
		}

		// Find the first mmproj file and configure it in the config file
		for _, file := range cfg.Files {
			if !strings.Contains(strings.ToLower(file.Filename), "mmproj") {
				continue
			}
			modelConfig.MMProj = file.Filename
			break
		}

		// Find the first non-mmproj file and configure it in the config file
		for _, file := range cfg.Files {
			if strings.Contains(strings.ToLower(file.Filename), "mmproj") {
				continue
			}
			modelConfig.PredictionOptions = schema.PredictionOptions{
				BasicModelRequest: schema.BasicModelRequest{
					Model: file.Filename,
				},
			}
			break
		}
	}

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	cfg.ConfigFile = string(data)

	return cfg, nil
}
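A small usage sketch for the importer above; the .gguf URI and the quantization list are hypothetical, and this only exercises the direct-file branch of Import shown in this file:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/mudler/LocalAI/core/gallery/importers"
)

func main() {
	// Hypothetical direct-file import: a plain .gguf URI skips the
	// HuggingFace file scan and becomes the single model file.
	details := importers.Details{
		URI:         "https://example.com/qwen-q4_k_m.gguf",
		Preferences: json.RawMessage(`{"quantizations": "q4_k_m,q5_k_m"}`),
	}

	cfg, err := (&importers.LlamaCPPImporter{}).Import(details)
	if err != nil {
		panic(err)
	}
	fmt.Println(cfg.Files[0].Filename) // qwen-q4_k_m.gguf
}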
132  core/gallery/importers/llama-cpp_test.go  Normal file
@@ -0,0 +1,132 @@
package importers_test

import (
	"encoding/json"
	"fmt"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("LlamaCPPImporter", func() {
	var importer *LlamaCPPImporter

	BeforeEach(func() {
		importer = &LlamaCPPImporter{}
	})

	Context("Match", func() {
		It("should match when URI ends with .gguf", func() {
			details := Details{
				URI: "https://example.com/model.gguf",
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when backend preference is llama-cpp", func() {
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should not match when URI does not end with .gguf and no backend preference", func() {
			details := Details{
				URI: "https://example.com/model.bin",
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should not match when backend preference is different", func() {
			preferences := json.RawMessage(`{"backend": "mlx"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should return false when JSON preferences are invalid", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://example.com/model.gguf",
				Preferences: preferences,
			}

			// Invalid JSON causes Match to return false early
			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})
	})

	Context("Import", func() {
		It("should import model config with default name and description", func() {
			details := Details{
				URI: "https://example.com/my-model.gguf",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("my-model.gguf"))
			Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
			Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
		})

		It("should import model config with custom name and description from preferences", func() {
			preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
			details := Details{
				URI:         "https://example.com/my-model.gguf",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-model"))
			Expect(modelConfig.Description).To(Equal("Custom description"))
			Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
		})

		It("should handle invalid JSON preferences", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://example.com/my-model.gguf",
				Preferences: preferences,
			}

			_, err := importer.Import(details)
			Expect(err).To(HaveOccurred())
		})

		It("should extract filename correctly from URI with path", func() {
			details := importers.Details{
				URI: "https://example.com/path/to/model.gguf",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/path/to/model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
			Expect(modelConfig.Files[0].Filename).To(Equal("model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
		})
	})
})
94  core/gallery/importers/mlx.go  Normal file
@@ -0,0 +1,94 @@
package importers

import (
	"encoding/json"
	"path/filepath"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"go.yaml.in/yaml/v2"
)

var _ Importer = &MLXImporter{}

type MLXImporter struct{}

func (i *MLXImporter) Match(details Details) bool {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return false
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return false
	}

	b, ok := preferencesMap["backend"].(string)
	if ok && (b == "mlx" || b == "mlx-vlm") {
		return true
	}

	// All https://huggingface.co/mlx-community/* repositories
	if strings.Contains(details.URI, "mlx-community/") {
		return true
	}

	return false
}

func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return gallery.ModelConfig{}, err
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	name, ok := preferencesMap["name"].(string)
	if !ok {
		name = filepath.Base(details.URI)
	}

	description, ok := preferencesMap["description"].(string)
	if !ok {
		description = "Imported from " + details.URI
	}

	backend := "mlx"
	b, ok := preferencesMap["backend"].(string)
	if ok {
		backend = b
	}

	modelConfig := config.ModelConfig{
		Name:                name,
		Description:         description,
		KnownUsecaseStrings: []string{"chat"},
		Backend:             backend,
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: details.URI,
			},
		},
		TemplateConfig: config.TemplateConfig{
			UseTokenizerTemplate: true,
		},
	}

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	return gallery.ModelConfig{
		Name:        name,
		Description: description,
		ConfigFile:  string(data),
	}, nil
}
147  core/gallery/importers/mlx_test.go  Normal file
@@ -0,0 +1,147 @@
package importers_test

import (
	"encoding/json"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("MLXImporter", func() {
	var importer *importers.MLXImporter

	BeforeEach(func() {
		importer = &importers.MLXImporter{}
	})

	Context("Match", func() {
		It("should match when URI contains mlx-community/", func() {
			details := importers.Details{
				URI: "https://huggingface.co/mlx-community/test-model",
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when backend preference is mlx", func() {
			preferences := json.RawMessage(`{"backend": "mlx"}`)
			details := importers.Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when backend preference is mlx-vlm", func() {
			preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
			details := importers.Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should not match when URI does not contain mlx-community/ and no backend preference", func() {
			details := importers.Details{
				URI: "https://huggingface.co/other-org/test-model",
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should not match when backend preference is different", func() {
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
			details := importers.Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should return false when JSON preferences are invalid", func() {
			preferences := json.RawMessage(`invalid json`)
			details := importers.Details{
				URI:         "https://huggingface.co/mlx-community/test-model",
				Preferences: preferences,
			}

			// Invalid JSON causes Match to return false early
			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})
	})

	Context("Import", func() {
		It("should import model config with default name and description", func() {
			details := importers.Details{
				URI: "https://huggingface.co/mlx-community/test-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("test-model"))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/mlx-community/test-model"))
		})

		It("should import model config with custom name and description from preferences", func() {
			preferences := json.RawMessage(`{"name": "custom-mlx-model", "description": "Custom MLX description"}`)
			details := importers.Details{
				URI:         "https://huggingface.co/mlx-community/test-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-mlx-model"))
			Expect(modelConfig.Description).To(Equal("Custom MLX description"))
		})

		It("should use custom backend from preferences", func() {
			preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
			details := importers.Details{
				URI:         "https://huggingface.co/mlx-community/test-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
		})

		It("should handle invalid JSON preferences", func() {
			preferences := json.RawMessage(`invalid json`)
			details := importers.Details{
				URI:         "https://huggingface.co/mlx-community/test-model",
				Preferences: preferences,
			}

			_, err := importer.Import(details)
			Expect(err).To(HaveOccurred())
		})

		It("should extract filename correctly from URI with path", func() {
			details := importers.Details{
				URI: "https://huggingface.co/mlx-community/path/to/model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("model"))
		})
	})
})
110  core/gallery/importers/transformers.go  Normal file
@@ -0,0 +1,110 @@
package importers

import (
	"encoding/json"
	"path/filepath"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"go.yaml.in/yaml/v2"
)

var _ Importer = &TransformersImporter{}

type TransformersImporter struct{}

func (i *TransformersImporter) Match(details Details) bool {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return false
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return false
	}

	b, ok := preferencesMap["backend"].(string)
	if ok && b == "transformers" {
		return true
	}

	if details.HuggingFace != nil {
		for _, file := range details.HuggingFace.Files {
			if strings.Contains(file.Path, "tokenizer.json") ||
				strings.Contains(file.Path, "tokenizer_config.json") {
				return true
			}
		}
	}

	return false
}

func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, error) {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return gallery.ModelConfig{}, err
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	name, ok := preferencesMap["name"].(string)
	if !ok {
		name = filepath.Base(details.URI)
	}

	description, ok := preferencesMap["description"].(string)
	if !ok {
		description = "Imported from " + details.URI
	}

	backend := "transformers"
	b, ok := preferencesMap["backend"].(string)
	if ok {
		backend = b
	}

	modelType, ok := preferencesMap["type"].(string)
	if !ok {
		modelType = "AutoModelForCausalLM"
	}

	quantization, ok := preferencesMap["quantization"].(string)
	if !ok {
		quantization = ""
	}

	modelConfig := config.ModelConfig{
		Name:                name,
		Description:         description,
		KnownUsecaseStrings: []string{"chat"},
		Backend:             backend,
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: details.URI,
			},
		},
		TemplateConfig: config.TemplateConfig{
			UseTokenizerTemplate: true,
		},
	}
	modelConfig.ModelType = modelType
	modelConfig.Quantization = quantization

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	return gallery.ModelConfig{
		Name:        name,
		Description: description,
		ConfigFile:  string(data),
	}, nil
}
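A usage sketch for the preferences this importer understands; the repository URI and the preference values below are illustrative assumptions, not values taken from the repository:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/mudler/LocalAI/core/gallery/importers"
)

func main() {
	// "type" selects the AutoModel class and "quantization" is written
	// into the generated YAML; both values here are illustrative.
	prefs := json.RawMessage(`{"type": "SentenceTransformer", "quantization": "int8"}`)

	cfg, err := (&importers.TransformersImporter{}).Import(importers.Details{
		URI:         "https://huggingface.co/test/my-model",
		Preferences: prefs,
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(cfg.ConfigFile) // YAML containing "type: SentenceTransformer"
}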
219  core/gallery/importers/transformers_test.go  Normal file
@@ -0,0 +1,219 @@
package importers_test

import (
	"encoding/json"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/mudler/LocalAI/core/gallery/importers"
	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("TransformersImporter", func() {
	var importer *TransformersImporter

	BeforeEach(func() {
		importer = &TransformersImporter{}
	})

	Context("Match", func() {
		It("should match when backend preference is transformers", func() {
			preferences := json.RawMessage(`{"backend": "transformers"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain tokenizer.json", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "tokenizer.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain tokenizer_config.json", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "tokenizer_config.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should not match when URI has no tokenizer files and no backend preference", func() {
			details := Details{
				URI: "https://example.com/model.bin",
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should not match when backend preference is different", func() {
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should return false when JSON preferences are invalid", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})
	})

	Context("Import", func() {
		It("should import model config with default name and description", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("my-model"))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
		})

		It("should import model config with custom name and description from preferences", func() {
			preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-model"))
			Expect(modelConfig.Description).To(Equal("Custom description"))
		})

		It("should use custom model type from preferences", func() {
			preferences := json.RawMessage(`{"type": "SentenceTransformer"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("type: SentenceTransformer"))
		})

		It("should use default model type when not specified", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
		})

		It("should use custom backend from preferences", func() {
			preferences := json.RawMessage(`{"backend": "transformers"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
		})

		It("should use quantization from preferences", func() {
			preferences := json.RawMessage(`{"quantization": "int8"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("quantization: int8"))
		})

		It("should handle invalid JSON preferences", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			_, err := importer.Import(details)
			Expect(err).To(HaveOccurred())
		})

		It("should extract filename correctly from URI with path", func() {
			details := importers.Details{
				URI: "https://huggingface.co/test/path/to/model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("model"))
		})

		It("should include use_tokenizer_template in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
		})

		It("should include known_usecases in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
		})
	})
})
98  core/gallery/importers/vllm.go  Normal file
@@ -0,0 +1,98 @@
package importers

import (
	"encoding/json"
	"path/filepath"
	"strings"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/core/schema"
	"go.yaml.in/yaml/v2"
)

var _ Importer = &VLLMImporter{}

type VLLMImporter struct{}

func (i *VLLMImporter) Match(details Details) bool {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return false
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return false
	}

	b, ok := preferencesMap["backend"].(string)
	if ok && b == "vllm" {
		return true
	}

	if details.HuggingFace != nil {
		for _, file := range details.HuggingFace.Files {
			if strings.Contains(file.Path, "tokenizer.json") ||
				strings.Contains(file.Path, "tokenizer_config.json") {
				return true
			}
		}
	}

	return false
}

func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
	preferences, err := details.Preferences.MarshalJSON()
	if err != nil {
		return gallery.ModelConfig{}, err
	}
	preferencesMap := make(map[string]any)
	err = json.Unmarshal(preferences, &preferencesMap)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	name, ok := preferencesMap["name"].(string)
	if !ok {
		name = filepath.Base(details.URI)
	}

	description, ok := preferencesMap["description"].(string)
	if !ok {
		description = "Imported from " + details.URI
	}

	backend := "vllm"
	b, ok := preferencesMap["backend"].(string)
	if ok {
		backend = b
	}

	modelConfig := config.ModelConfig{
		Name:                name,
		Description:         description,
		KnownUsecaseStrings: []string{"chat"},
		Backend:             backend,
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: details.URI,
			},
		},
		TemplateConfig: config.TemplateConfig{
			UseTokenizerTemplate: true,
		},
	}

	data, err := yaml.Marshal(modelConfig)
	if err != nil {
		return gallery.ModelConfig{}, err
	}

	return gallery.ModelConfig{
		Name:        name,
		Description: description,
		ConfigFile:  string(data),
	}, nil
}
181  core/gallery/importers/vllm_test.go  Normal file
@@ -0,0 +1,181 @@
package importers_test

import (
	"encoding/json"

	"github.com/mudler/LocalAI/core/gallery/importers"
	. "github.com/mudler/LocalAI/core/gallery/importers"
	hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("VLLMImporter", func() {
	var importer *VLLMImporter

	BeforeEach(func() {
		importer = &VLLMImporter{}
	})

	Context("Match", func() {
		It("should match when backend preference is vllm", func() {
			preferences := json.RawMessage(`{"backend": "vllm"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain tokenizer.json", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "tokenizer.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should match when HuggingFace details contain tokenizer_config.json", func() {
			hfDetails := &hfapi.ModelDetails{
				Files: []hfapi.ModelFile{
					{Path: "tokenizer_config.json"},
				},
			}
			details := Details{
				URI:         "https://huggingface.co/test/model",
				HuggingFace: hfDetails,
			}

			result := importer.Match(details)
			Expect(result).To(BeTrue())
		})

		It("should not match when URI has no tokenizer files and no backend preference", func() {
			details := Details{
				URI: "https://example.com/model.bin",
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should not match when backend preference is different", func() {
			preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})

		It("should return false when JSON preferences are invalid", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://example.com/model",
				Preferences: preferences,
			}

			result := importer.Match(details)
			Expect(result).To(BeFalse())
		})
	})

	Context("Import", func() {
		It("should import model config with default name and description", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("my-model"))
			Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
		})

		It("should import model config with custom name and description from preferences", func() {
			preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("custom-model"))
			Expect(modelConfig.Description).To(Equal("Custom description"))
		})

		It("should use custom backend from preferences", func() {
			preferences := json.RawMessage(`{"backend": "vllm"}`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
		})

		It("should handle invalid JSON preferences", func() {
			preferences := json.RawMessage(`invalid json`)
			details := Details{
				URI:         "https://huggingface.co/test/my-model",
				Preferences: preferences,
			}

			_, err := importer.Import(details)
			Expect(err).To(HaveOccurred())
		})

		It("should extract filename correctly from URI with path", func() {
			details := importers.Details{
				URI: "https://huggingface.co/test/path/to/model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.Name).To(Equal("model"))
		})

		It("should include use_tokenizer_template in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
		})

		It("should include known_usecases in config", func() {
			details := Details{
				URI: "https://huggingface.co/test/my-model",
			}

			modelConfig, err := importer.Import(details)

			Expect(err).ToNot(HaveOccurred())
			Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
			Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
		})
	})
})
@@ -1,6 +1,7 @@
 package gallery
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"os"
@@ -72,6 +73,7 @@ type PromptTemplate struct {
 
 // Installs a model from the gallery
 func InstallModelFromGallery(
+	ctx context.Context,
 	modelGalleries, backendGalleries []config.Gallery,
 	systemState *system.SystemState,
 	modelLoader *model.ModelLoader,
@@ -84,7 +86,7 @@ func InstallModelFromGallery(
 
 	if len(model.URL) > 0 {
 		var err error
-		config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
+		config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
 		if err != nil {
 			return err
 		}
@@ -125,7 +127,7 @@ func InstallModelFromGallery(
 		return err
 	}
 
-	installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
+	installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
 	if err != nil {
 		return err
 	}
@@ -133,7 +135,7 @@ func InstallModelFromGallery(
 	if automaticallyInstallBackend && installedModel.Backend != "" {
 		log.Debug().Msgf("Installing backend %q", installedModel.Backend)
 
-		if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
+		if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
 			return err
 		}
 	}
@@ -154,7 +156,7 @@ func InstallModelFromGallery(
 	return applyModel(model)
 }
 
-func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
+func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
 	basePath := systemState.Model.ModelsPath
 	// Create base path if it doesn't exist
 	err := os.MkdirAll(basePath, 0750)
@@ -168,6 +170,13 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
 
 	// Download files and verify their SHA
 	for i, file := range config.Files {
+		// Check for cancellation before each file
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		default:
+		}
+
 		log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
 
 		if err := utils.VerifyPath(file.Filename, basePath); err != nil {
@@ -185,7 +194,7 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
 			}
 		}
 		uri := downloader.URI(file.URI)
-		if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
+		if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
 			return nil, err
 		}
 	}
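The ctx plumbing added above makes long model downloads cancellable. A sketch of the intended call pattern, assuming systemState and modelCfg are prepared elsewhere; the timeout value is illustrative:

package main

import (
	"context"
	"time"

	"github.com/mudler/LocalAI/core/gallery"
	"github.com/mudler/LocalAI/pkg/system"
)

// installWithTimeout is a sketch: systemState and modelCfg are assumed to be
// built by the caller, as in the tests below.
func installWithTimeout(systemState *system.SystemState, modelCfg *gallery.ModelConfig) error {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()

	// Cancellation is checked before each file and propagated into
	// DownloadFileWithContext, so a cancelled ctx aborts a multi-file
	// install between (and during) downloads.
	_, err := gallery.InstallModel(ctx, systemState, "", modelCfg,
		map[string]interface{}{},
		func(name, uri, sha string, progress float64) {}, true)
	return err
}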
@@ -1,6 +1,7 @@
 package gallery_test
 
 import (
+	"context"
 	"errors"
 	"os"
 	"path/filepath"
@@ -34,7 +35,7 @@ var _ = Describe("Model test", func() {
 			system.WithModelPath(tempdir),
 		)
 		Expect(err).ToNot(HaveOccurred())
-		_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+		_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
 		Expect(err).ToNot(HaveOccurred())
 
 		for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -88,7 +89,7 @@ var _ = Describe("Model test", func() {
 			Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
 			Expect(models[0].Installed).To(BeFalse())
 
-			err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
+			err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
 			Expect(err).ToNot(HaveOccurred())
 
 			dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
@@ -129,7 +130,7 @@ var _ = Describe("Model test", func() {
 			system.WithModelPath(tempdir),
 		)
 		Expect(err).ToNot(HaveOccurred())
-		_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+		_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
 		Expect(err).ToNot(HaveOccurred())
 
 		for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -149,7 +150,7 @@ var _ = Describe("Model test", func() {
 			system.WithModelPath(tempdir),
 		)
 		Expect(err).ToNot(HaveOccurred())
-		_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
+		_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
 		Expect(err).ToNot(HaveOccurred())
 
 		for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -179,7 +180,7 @@ var _ = Describe("Model test", func() {
 			system.WithModelPath(tempdir),
 		)
 		Expect(err).ToNot(HaveOccurred())
-		_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
 		Expect(err).To(HaveOccurred())
 	})
 })
238  core/http/app.go
@@ -4,30 +4,23 @@ import (
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/dave-gray101/v2keyauth"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/labstack/echo/v4/middleware"
|
||||
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/http/routes"
|
||||
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
|
||||
"github.com/gofiber/contrib/fiberzerolog"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/cors"
|
||||
"github.com/gofiber/fiber/v2/middleware/csrf"
|
||||
"github.com/gofiber/fiber/v2/middleware/favicon"
|
||||
"github.com/gofiber/fiber/v2/middleware/filesystem"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
|
||||
// swagger handler
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -49,86 +42,85 @@ var embedDirStatic embed.FS
|
||||
// @in header
|
||||
// @name Authorization
|
||||
|
||||
func API(application *application.Application) (*fiber.App, error) {
|
||||
func API(application *application.Application) (*echo.Echo, error) {
|
||||
e := echo.New()
|
||||
|
||||
fiberCfg := fiber.Config{
|
||||
Views: renderEngine(),
|
||||
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
|
||||
// We disable the Fiber startup message as it does not conform to structured logging.
|
||||
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
|
||||
DisableStartupMessage: true,
|
||||
// Override default error handler
|
||||
// Set body limit
|
||||
if application.ApplicationConfig().UploadLimitMB > 0 {
|
||||
e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB)))
|
||||
}
|
||||
|
||||
// Set error handler
|
||||
if !application.ApplicationConfig().OpaqueErrors {
|
||||
// Normally, return errors as JSON responses
|
||||
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
|
||||
// Status code defaults to 500
|
||||
code := fiber.StatusInternalServerError
|
||||
e.HTTPErrorHandler = func(err error, c echo.Context) {
|
||||
code := http.StatusInternalServerError
|
||||
var he *echo.HTTPError
|
||||
if errors.As(err, &he) {
|
||||
code = he.Code
|
||||
}
|
||||
|
||||
// Retrieve the custom status code if it's a *fiber.Error
|
||||
var e *fiber.Error
|
||||
if errors.As(err, &e) {
|
||||
code = e.Code
|
||||
// Handle 404 errors with HTML rendering when appropriate
|
||||
if code == http.StatusNotFound {
|
||||
notFoundHandler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Send custom error page
|
||||
return ctx.Status(code).JSON(
|
||||
schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: err.Error(), Code: code},
|
||||
},
|
||||
)
|
||||
c.JSON(code, schema.ErrorResponse{
|
||||
Error: &schema.APIError{Message: err.Error(), Code: code},
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// If OpaqueErrors are required, replace everything with a blank 500.
|
||||
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error {
|
||||
return ctx.Status(500).SendString("")
|
||||
e.HTTPErrorHandler = func(err error, c echo.Context) {
|
||||
code := http.StatusInternalServerError
|
||||
var he *echo.HTTPError
|
||||
if errors.As(err, &he) {
|
||||
code = he.Code
|
||||
}
|
||||
c.NoContent(code)
|
||||
}
|
||||
}
|
||||
|
||||
router := fiber.New(fiberCfg)
|
||||
// Set renderer
|
||||
e.Renderer = renderEngine()
|
||||
|
||||
router.Use(middleware.StripPathPrefix())
|
||||
// Hide banner
|
||||
e.HideBanner = true
|
||||
|
||||
// Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
|
||||
e.Pre(httpMiddleware.StripPathPrefix())
|
||||
|
||||
if application.ApplicationConfig().MachineTag != "" {
|
||||
router.Use(func(c *fiber.Ctx) error {
|
||||
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
|
||||
|
||||
return c.Next()
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
router.Use("/v1/realtime", func(c *fiber.Ctx) error {
|
||||
if websocket.IsWebSocketUpgrade(c) {
|
||||
// Returns true if the client requested upgrade to the WebSocket protocol
|
||||
return c.Next()
|
||||
// Custom logger middleware using zerolog
|
||||
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
req := c.Request()
|
||||
res := c.Response()
|
||||
start := log.Logger.Info()
|
||||
err := next(c)
|
||||
start.
|
||||
Str("method", req.Method).
|
||||
Str("path", req.URL.Path).
|
||||
Int("status", res.Status).
|
||||
Msg("HTTP request")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
|
||||
scheme := "http"
|
||||
if listenData.TLS {
|
||||
scheme = "https"
|
||||
}
|
||||
log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
|
||||
return nil
|
||||
})
|
||||
|
||||
// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
|
||||
logger := log.Logger
|
||||
router.Use(fiberzerolog.New(fiberzerolog.Config{
|
||||
Logger: &logger,
|
||||
}))
|
||||
|
||||
// Default middleware config
|
||||
|
||||
// Recover middleware
|
||||
if !application.ApplicationConfig().Debug {
|
||||
router.Use(recover.New())
|
||||
e.Use(middleware.Recover())
|
||||
}
|
||||
|
||||
// OpenTelemetry metrics for Prometheus export
|
||||
// Metrics middleware
|
||||
if !application.ApplicationConfig().DisableMetrics {
|
||||
metricsService, err := services.NewLocalAIMetricsService()
|
||||
if err != nil {
|
||||
@@ -136,35 +128,40 @@ func API(application *application.Application) (*fiber.App, error) {
|
||||
}
|
||||
|
||||
if metricsService != nil {
|
||||
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
router.Hooks().OnShutdown(func() error {
|
||||
return metricsService.Shutdown()
|
||||
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
|
||||
e.Server.RegisterOnShutdown(func() {
|
||||
metricsService.Shutdown()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Health Checks should always be exempt from auth, so register these first
|
||||
routes.HealthRoutes(router)
|
||||
routes.HealthRoutes(e)
|
||||
|
||||
kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
|
||||
if err != nil || kaConfig == nil {
|
||||
// Get key auth middleware
|
||||
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create key auth config: %w", err)
|
||||
}
|
||||
|
||||
httpFS := http.FS(embedDirStatic)
|
||||
// Favicon handler
|
||||
e.GET("/favicon.svg", func(c echo.Context) error {
|
||||
data, err := embedDirStatic.ReadFile("static/favicon.svg")
|
||||
if err != nil {
|
||||
return c.NoContent(http.StatusNotFound)
|
||||
}
|
||||
c.Response().Header().Set("Content-Type", "image/svg+xml")
|
||||
return c.Blob(http.StatusOK, "image/svg+xml", data)
|
||||
})
|
||||
|
||||
router.Use(favicon.New(favicon.Config{
|
||||
URL: "/favicon.svg",
|
||||
FileSystem: httpFS,
|
||||
File: "static/favicon.svg",
|
||||
}))
|
||||
|
||||
router.Use("/static", filesystem.New(filesystem.Config{
|
||||
Root: httpFS,
|
||||
PathPrefix: "static",
|
||||
Browse: true,
|
||||
}))
|
||||
// Static files - use fs.Sub to create a filesystem rooted at "static"
|
||||
staticFS, err := fs.Sub(embedDirStatic, "static")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create static filesystem: %w", err)
|
||||
}
|
||||
e.StaticFS("/static", staticFS)
|
||||
|
||||
// Generated content directories
|
||||
if application.ApplicationConfig().GeneratedContentDir != "" {
|
||||
os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
@@ -175,62 +172,53 @@ func API(application *application.Application) (*fiber.App, error) {
os.MkdirAll(imagePath, 0750)
os.MkdirAll(videoPath, 0750)

router.Static("/generated-audio", audioPath)
router.Static("/generated-images", imagePath)
router.Static("/generated-videos", videoPath)
e.Static("/generated-audio", audioPath)
e.Static("/generated-images", imagePath)
e.Static("/generated-videos", videoPath)
}

// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
router.Use(v2keyauth.New(*kaConfig))
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
e.Use(keyAuthMiddleware)

// CORS middleware
if application.ApplicationConfig().CORS {
var c func(ctx *fiber.Ctx) error
if application.ApplicationConfig().CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
corsConfig := middleware.CORSConfig{}
if application.ApplicationConfig().CORSAllowOrigins != "" {
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
}

router.Use(c)
e.Use(middleware.CORSWithConfig(corsConfig))
}

// CSRF middleware
if application.ApplicationConfig().CSRF {
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
router.Use(csrf.New())
e.Use(middleware.CSRF())
}

requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())

routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())

// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
var opcache *services.OpCache
if !application.ApplicationConfig().DisableWebUI {

// Create metrics store for tracking usage (before API routes registration)
metricsStore := services.NewInMemoryMetricsStore()

// Add metrics middleware BEFORE API routes so it can intercept them
router.Use(middleware.MetricsMiddleware(metricsStore))

// Register cleanup on shutdown
router.Hooks().OnShutdown(func() error {
metricsStore.Stop()
log.Info().Msg("Metrics store stopped")
return nil
})

// Create opcache for tracking UI operations
opcache := services.NewOpCache(application.GalleryService())
routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, metricsStore)
routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
opcache = services.NewOpCache(application.GalleryService())
}

routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator())
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
}
routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())

// Define a custom 404 handler
// Note: keep this at the bottom!
router.Use(notFoundHandler)
// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route

return router, nil
// Log startup message
e.Server.RegisterOnShutdown(func() {
log.Info().Msg("LocalAI API server shutting down")
})

return e, nil
}
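
The hunk above swaps Fiber's functional CORS/CSRF setup for Echo's config-struct style. A minimal, self-contained sketch of the same wiring, assuming echo v4 and its bundled middleware package (the newServer helper and its parameters are illustrative, not part of the diff):

package main

import (
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
)

func newServer(allowOrigins string, csrf bool) *echo.Echo {
	e := echo.New()

	// CORS: a zero-value config falls back to Echo's defaults (allow all origins).
	corsConfig := middleware.CORSConfig{}
	if allowOrigins != "" {
		corsConfig.AllowOrigins = strings.Split(allowOrigins, ",")
	}
	e.Use(middleware.CORSWithConfig(corsConfig))

	// CSRF tokens for state-modifying requests, mirroring the hunk above.
	if csrf {
		e.Use(middleware.CSRF())
	}
	return e
}

func main() {
	_ = newServer("http://localhost:3000", true)
}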
@@ -10,13 +10,14 @@ import (
"os"
"path/filepath"
"runtime"
"time"

"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/schema"

"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/system"
@@ -25,6 +26,7 @@ import (
"gopkg.in/yaml.v3"

openaigo "github.com/otiai10/openaigo"
"github.com/rs/zerolog/log"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
)
@@ -85,7 +87,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
@@ -266,7 +268,7 @@ const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b8

var _ = Describe("API test", func() {

var app *fiber.App
var app *echo.Echo
var client *openai.Client
var client2 *openaigo.Client
var c context.Context
@@ -339,7 +341,11 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())

go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()

defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -358,7 +364,9 @@ var _ = Describe("API test", func() {
AfterEach(func(sc SpecContext) {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
@@ -547,7 +555,11 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())

go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()

defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -566,7 +578,9 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
@@ -755,7 +769,11 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred())
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()

defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -773,7 +791,9 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
})
@@ -796,6 +816,83 @@ var _ = Describe("API test", func() {
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})

It("returns logprobs in chat completions when requested", func() {
topLogprobsVal := 3
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
LogProbs: true,
TopLogProbs: topLogprobsVal,
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())

Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())

// Verify logprobs are present and have correct structure
Expect(response.Choices[0].LogProbs).ToNot(BeNil())
Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())

Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))

foundatLeastToken := ""
foundAtLeastBytes := []byte{}
foundAtLeastTopLogprobBytes := []byte{}
foundatLeastTopLogprob := ""
// Verify logprobs content structure matches OpenAI format
for _, logprobContent := range response.Choices[0].LogProbs.Content {
// Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
if len(logprobContent.Bytes) > 0 {
foundAtLeastBytes = logprobContent.Bytes
}
if len(logprobContent.Token) > 0 {
foundatLeastToken = logprobContent.Token
}
Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))

// If top_logprobs is requested, verify top_logprobs array respects the limit
if len(logprobContent.TopLogProbs) > 0 {
// Should respect top_logprobs limit (3 in this test)
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
for _, topLogprob := range logprobContent.TopLogProbs {
if len(topLogprob.Bytes) > 0 {
foundAtLeastTopLogprobBytes = topLogprob.Bytes
}
if len(topLogprob.Token) > 0 {
foundatLeastTopLogprob = topLogprob.Token
}
Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
}
}
}

Expect(foundAtLeastBytes).ToNot(BeEmpty())
Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
Expect(foundatLeastToken).ToNot(BeEmpty())
Expect(foundatLeastTopLogprob).ToNot(BeEmpty())
})

It("applies logit_bias to chat completions when requested", func() {
// logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
// According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
logitBias := map[string]int{
"15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
}
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
LogitBias: logitBias,
})
Expect(err).ToNot(HaveOccurred())
Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
// If logit_bias is applied, the response should be generated successfully
// We can't easily verify the bias effect without knowing the actual token IDs for the model,
// but the fact that the request succeeds confirms the API accepts and processes logit_bias
})

It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
Expect(err).To(HaveOccurred())
@@ -1006,7 +1103,11 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())

go app.Listen("127.0.0.1:9090")
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()

defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -1022,7 +1123,9 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
err := app.Shutdown()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred())
}
})
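
The repeated start/shutdown pattern above is the standard Echo idiom for tests: Start blocks until the listener closes and returns http.ErrServerClosed on a clean Shutdown. A condensed, self-contained sketch (the port and the 5-second timeout mirror the diff; everything else is illustrative):

package main

import (
	"context"
	"net/http"
	"time"

	"github.com/labstack/echo/v4"
)

func main() {
	e := echo.New()

	// Serve in a goroutine; a clean Shutdown surfaces as http.ErrServerClosed,
	// which is not a failure.
	go func() {
		if err := e.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
			e.Logger.Error(err)
		}
	}()

	// ... exercise the API here ...

	// Graceful shutdown with a bounded deadline, as in the AfterEach blocks above.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := e.Shutdown(ctx); err != nil {
		e.Logger.Error(err)
	}
}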
@@ -1,7 +1,9 @@
package elevenlabs

import (
"github.com/gofiber/fiber/v2"
"path/filepath"

"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -15,17 +17,17 @@ import (
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {

input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}

cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}

log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
@@ -35,7 +37,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
if err != nil {
return err
}
return c.Download(filePath)
return c.Attachment(filePath, filepath.Base(filePath))

}
}

@@ -1,13 +1,14 @@
package elevenlabs

import (
"path/filepath"

"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"

"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)

@@ -17,19 +18,19 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {

voiceID := c.Params("voice-id")
voiceID := c.Param("voice-id")

input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
if !ok || input.ModelID == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}

cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}

log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received")
@@ -38,6 +39,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
if err != nil {
return err
}
return c.Download(filePath)
return c.Attachment(filePath, filepath.Base(filePath))
}
}
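
These two endpoints show the recurring Fiber-to-Echo translations in this change: c.Locals becomes c.Get, c.Params becomes c.Param, fiber.ErrBadRequest becomes echo.ErrBadRequest, and c.Download becomes c.Attachment. A minimal sketch of the target handler shape, assuming echo v4 (the handler name and the requestKey constant are illustrative stand-ins, not the real middleware key):

package main

import (
	"path/filepath"

	"github.com/labstack/echo/v4"
)

const requestKey = "localai_request" // illustrative stand-in for the middleware locals key

func fileEndpoint() echo.HandlerFunc {
	return func(c echo.Context) error {
		// Values stashed by earlier middleware are read back with c.Get.
		input, ok := c.Get(requestKey).(string)
		if !ok || input == "" {
			return echo.ErrBadRequest
		}
		// Path parameters use c.Param (singular) in Echo.
		_ = c.Param("voice-id")
		// c.Attachment sends the file with a Content-Disposition header,
		// the Echo counterpart of Fiber's c.Download.
		return c.Attachment(input, filepath.Base(input))
	}
}

func main() {
	_ = fileEndpoint()
}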
@@ -2,28 +2,32 @@ package explorer

import (
"encoding/base64"
"net/http"
"sort"
"strings"

"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/internal"
)

func Dashboard() func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
summary := fiber.Map{
func Dashboard() echo.HandlerFunc {
return func(c echo.Context) error {
summary := map[string]interface{}{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
"BaseURL": middleware.BaseURL(c),
}

if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
contentType := c.Request().Header.Get("Content-Type")
accept := c.Request().Header.Get("Accept")
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) {
// The client expects a JSON response
return c.Status(fiber.StatusOK).JSON(summary)
return c.JSON(http.StatusOK, summary)
} else {
// Render index
return c.Render("views/explorer", summary)
return c.Render(http.StatusOK, "views/explorer", summary)
}
}
}
@@ -39,8 +43,8 @@ type Network struct {
Token string `json:"token"`
}

func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error {
results := []Network{}
for _, token := range db.TokenList() {
networkData, exists := db.Get(token) // get the token data
@@ -61,44 +65,44 @@ func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
return len(results[i].Clusters) > len(results[j].Clusters)
})

return c.JSON(results)
return c.JSON(http.StatusOK, results)
}
}

func AddNetwork(db *explorer.Database) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func AddNetwork(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error {
request := new(AddNetworkRequest)
if err := c.BodyParser(request); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
if err := c.Bind(request); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
}

if request.Token == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token is required"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
}

if request.Name == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Name is required"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
}

if request.Description == "" {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Description is required"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
}

// TODO: check if token is valid, otherwise reject
// try to decode the token from base64
_, err := base64.StdEncoding.DecodeString(request.Token)
if err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid token"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
}

if _, exists := db.Get(request.Token); exists {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token already exists"})
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
}
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Cannot add token"})
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
}

return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Token added"})
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
}
}
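
The AddNetwork rewrite is the canonical Echo request-validation shape: Bind the body into a struct, then return early with c.JSON(status, map) on each failure. A condensed sketch under that assumption (the addNetworkRequest type is an illustrative stand-in for the real one):

package main

import (
	"net/http"

	"github.com/labstack/echo/v4"
)

// addNetworkRequest is an illustrative reduction of the real request type.
type addNetworkRequest struct {
	Token string `json:"token"`
	Name  string `json:"name"`
}

func addNetwork() echo.HandlerFunc {
	return func(c echo.Context) error {
		req := new(addNetworkRequest)
		// Bind decodes the body based on Content-Type (JSON here),
		// replacing Fiber's c.BodyParser.
		if err := c.Bind(req); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
		}
		if req.Token == "" {
			return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
		}
		return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
	}
}

func main() {
	_ = addNetwork()
}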
@@ -1,11 +1,12 @@
package jina

import (
"net/http"

"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"

"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
@@ -17,24 +18,36 @@ import (
// @Param request body schema.JINARerankRequest true "query params"
// @Success 200 {object} schema.JINARerankResponse "Response"
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {

input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}

cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}

log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")

var requestTopN int32
docs := int32(len(input.Documents))
if input.TopN == nil { // omit top_n to get all
requestTopN = docs
} else {
requestTopN = int32(*input.TopN)
if requestTopN < 1 {
return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
}
if requestTopN > docs { // make it more obvious for backends
requestTopN = docs
}
}
request := &proto.RerankRequest{
Query: input.Query,
TopN: int32(input.TopN),
TopN: requestTopN,
Documents: input.Documents,
}

@@ -58,6 +71,6 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
response.Usage.TotalTokens = int(results.Usage.TotalTokens)
response.Usage.PromptTokens = int(results.Usage.PromptTokens)

return c.Status(fiber.StatusOK).JSON(response)
return c.JSON(http.StatusOK, response)
}
}
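
Beyond the framework swap, this hunk changes top_n handling from a blind cast to an optional pointer with validation and clamping: nil means "return all documents", values below 1 are rejected, values above the document count are clamped. The logic in isolation (function and parameter names are illustrative):

package main

import "fmt"

// resolveTopN mirrors the validation added in the hunk above.
func resolveTopN(topN *int, numDocs int32) (int32, error) {
	if topN == nil { // omit top_n to get all
		return numDocs, nil
	}
	n := int32(*topN)
	if n < 1 {
		return 0, fmt.Errorf("top_n - should be greater than or equal to 1")
	}
	if n > numDocs { // make it more obvious for backends
		n = numDocs
	}
	return n, nil
}

func main() {
	three := 3
	n, _ := resolveTopN(&three, 2)
	fmt.Println(n) // 2: clamped to the number of documents
}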
@@ -4,11 +4,11 @@ import (
"encoding/json"
"fmt"

"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
@@ -39,13 +39,13 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /backends/jobs/{uuid} [get]
func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.backendApplier.GetStatus(c.Params("uuid"))
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
status := mgs.backendApplier.GetStatus(c.Param("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
return c.JSON(200, status)
}
}

@@ -53,9 +53,9 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) erro
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /backends/jobs [get]
func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.backendApplier.GetAllStatus())
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
return c.JSON(200, mgs.backendApplier.GetAllStatus())
}
}

@@ -64,11 +64,11 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) err
// @Param request body GalleryBackend true "query params"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/apply [post]
func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
input := new(GalleryBackend)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
if err := c.Bind(input); err != nil {
return err
}

@@ -76,13 +76,13 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
if err != nil {
return err
}
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend]{
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
ID: uuid.String(),
GalleryElementName: input.ID,
Galleries: mgs.galleries,
}

return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}

@@ -91,11 +91,11 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) err
// @Param name path string true "Backend name"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/delete/{name} [post]
func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendName := c.Params("name")
func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
backendName := c.Param("name")

mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend]{
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
Delete: true,
GalleryElementName: backendName,
Galleries: mgs.galleries,
@@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
return err
}

return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}

@@ -114,13 +114,13 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
// @Summary List all Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends [get]
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
backends, err := gallery.ListSystemBackends(systemState)
if err != nil {
return err
}
return c.JSON(backends.GetAll())
return c.JSON(200, backends.GetAll())
}
}

@@ -129,14 +129,14 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.Syst
// @Success 200 {object} []config.Gallery "Response"
// @Router /backends/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
return c.Blob(200, "application/json", dat)
}
}

@@ -144,12 +144,12 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
// @Summary List all available Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends/available [get]
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
if err != nil {
return err
}
return c.JSON(backends)
return c.JSON(200, backends)
}
}
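
Note that GalleryOp also picks up a second type parameter in this change (GalleryOp[gallery.GalleryBackend, any]). A stripped-down illustration of the pattern, with the field set reduced for brevity and all names assumed rather than taken from the real services package:

package main

import "fmt"

// galleryOp is an illustrative reduction of services.GalleryOp[T, R]:
// T is the gallery element being operated on, R its resolved configuration.
type galleryOp[T any, R any] struct {
	ID                 string
	GalleryElementName string
	Req                T
	Delete             bool
}

type galleryBackend struct{ Name string }

func main() {
	ch := make(chan galleryOp[galleryBackend, any], 1)
	// Backends carry no resolved config, so R is instantiated as any,
	// matching GalleryOp[gallery.GalleryBackend, any] in the diff.
	ch <- galleryOp[galleryBackend, any]{ID: "1234", GalleryElementName: "llama-cpp"}
	fmt.Println((<-ch).GalleryElementName)
}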
@@ -1,45 +1,45 @@
package localai

import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)

// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {

input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}

resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(resp)
}
}

// BackendShutdownEndpoint shuts down the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}

return bm.ShutdownModel(input.Model)
}
}
package localai

import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)

// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
return func(c echo.Context) error {

input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.Bind(input); err != nil {
return err
}

resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(200, resp)
}
}

// BackendShutdownEndpoint shuts down the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.Bind(input); err != nil {
return err
}

return bm.ShutdownModel(input.Model)
}
}
@@ -1,7 +1,7 @@
package localai

import (
"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -16,17 +16,17 @@ import (
// @Param request body schema.DetectionRequest true "query params"
// @Success 200 {object} schema.DetectionResponse "Response"
// @Router /v1/detection [post]
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {

input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}

cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
return echo.ErrBadRequest
}

log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
@@ -54,6 +54,6 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
}
}

return c.JSON(response)
return c.JSON(200, response)
}
}
@@ -2,11 +2,13 @@ package localai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/utils"
|
||||
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/internal"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
|
||||
@@ -14,15 +16,15 @@ import (
|
||||
)
|
||||
|
||||
// GetEditModelPage renders the edit model page with current configuration
|
||||
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
@@ -31,7 +33,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Success: false,
|
||||
Error: "Model configuration not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
modelConfigFile := modelConfig.GetModelConfigFile()
|
||||
@@ -40,7 +42,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Success: false,
|
||||
Error: "Model configuration file not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
configData, err := os.ReadFile(modelConfigFile)
|
||||
if err != nil {
|
||||
@@ -48,7 +50,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Success: false,
|
||||
Error: "Failed to read configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Render the edit page with the current configuration
|
||||
@@ -69,20 +71,20 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
|
||||
Version: internal.PrintableVersion(),
|
||||
}
|
||||
|
||||
return c.Render("views/model-editor", templateData)
|
||||
return c.Render(http.StatusOK, "views/model-editor", templateData)
|
||||
}
|
||||
}
|
||||
|
||||
// EditModelEndpoint handles updating existing model configurations
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
modelName := c.Param("name")
|
||||
if modelName == "" {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Model name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
modelConfig, exists := cl.GetModelConfig(modelName)
|
||||
@@ -91,17 +93,24 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Existing model configuration not found",
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
// Get the raw body
|
||||
body := c.Body()
|
||||
body, err := io.ReadAll(c.Request().Body)
|
||||
if err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to read request body: " + err.Error(),
|
||||
}
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
if len(body) == 0 {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Request body is empty",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Check content to see if it's a valid model config
|
||||
@@ -113,7 +122,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to parse YAML: " + err.Error(),
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
@@ -122,7 +131,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Name is required",
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Validate the configuration
|
||||
@@ -132,7 +141,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Error: "Validation failed",
|
||||
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
|
||||
}
|
||||
return c.Status(400).JSON(response)
|
||||
return c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Load the existing configuration
|
||||
@@ -142,7 +151,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Model configuration not trusted: " + err.Error(),
|
||||
}
|
||||
return c.Status(404).JSON(response)
|
||||
return c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
// Write new content to file
|
||||
@@ -151,16 +160,16 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to write configuration file: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Reload configurations
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Preload the model
|
||||
@@ -169,7 +178,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Success: false,
|
||||
Error: "Failed to preload model: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Return success response
|
||||
@@ -179,20 +188,20 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
|
||||
Filename: configPath,
|
||||
Config: req,
|
||||
}
|
||||
return c.JSON(response)
|
||||
return c.JSON(200, response)
|
||||
}
|
||||
}
|
||||
|
||||
// ReloadModelsEndpoint handles reloading model configurations from disk
|
||||
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
|
||||
return func(c *fiber.Ctx) error {
|
||||
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
// Reload configurations
|
||||
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
|
||||
response := ModelResponse{
|
||||
Success: false,
|
||||
Error: "Failed to reload configurations: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Preload the models
|
||||
@@ -201,7 +210,7 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
|
||||
Success: false,
|
||||
Error: "Failed to preload models: " + err.Error(),
|
||||
}
|
||||
return c.Status(500).JSON(response)
|
||||
return c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Return success response
|
||||
@@ -209,6 +218,6 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
|
||||
Success: true,
|
||||
Message: "Model configurations reloaded successfully",
|
||||
}
|
||||
return c.Status(fiber.StatusOK).JSON(response)
|
||||
return c.JSON(http.StatusOK, response)
|
||||
}
|
||||
}
|
||||
|
||||
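
Fiber exposes the raw body as c.Body(); Echo hands you the underlying *http.Request, so the diff reads it with io.ReadAll and gains an extra error path. The pattern in isolation, as a minimal sketch (the handler name and error payload shape are illustrative):

package main

import (
	"io"
	"net/http"

	"github.com/labstack/echo/v4"
)

func rawBodyHandler() echo.HandlerFunc {
	return func(c echo.Context) error {
		// Echo has no c.Body() shortcut; read the request body directly.
		body, err := io.ReadAll(c.Request().Body)
		if err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": "Failed to read request body: " + err.Error()})
		}
		if len(body) == 0 {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": "Request body is empty"})
		}
		return c.JSON(http.StatusOK, map[string]int{"bytes": len(body)})
	}
}

func main() {
	_ = rawBodyHandler()
}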
@@ -2,12 +2,14 @@ package localai_test

import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"

"github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/pkg/system"
@@ -15,6 +17,14 @@ import (
. "github.com/onsi/gomega"
)

// testRenderer is a simple renderer for tests that returns JSON
type testRenderer struct{}

func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
// For tests, just return the data as JSON
return json.NewEncoder(w).Encode(data)
}

var _ = Describe("Edit Model test", func() {

var tempDir string
@@ -40,33 +50,35 @@ var _ = Describe("Edit Model test", func() {
//modelLoader := model.NewModelLoader(systemState, true)
modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)

// Define Fiber app.
app := fiber.New()
app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
// Define Echo app and register all routes upfront
app := echo.New()
// Set up a simple renderer for the test
app.Renderer = &testRenderer{}
app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig))

requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)

req := httptest.NewRequest("PUT", "/import-model", requestBody)
resp, err := app.Test(req, 5000)
Expect(err).ToNot(HaveOccurred())
req := httptest.NewRequest("POST", "/import-model", requestBody)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)

body, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
body, err := io.ReadAll(rec.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
Expect(rec.Code).To(Equal(http.StatusOK))

app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
req = httptest.NewRequest("GET", "/edit-model/foo", nil)
rec = httptest.NewRecorder()
app.ServeHTTP(rec, req)

req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
resp, _ = app.Test(req, 1)

body, err = io.ReadAll(resp.Body)
defer resp.Body.Close()
body, err = io.ReadAll(rec.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
// The response contains the model configuration with backend field
Expect(string(body)).To(ContainSubstring(`"backend":"foo"`))
Expect(string(body)).To(ContainSubstring(`"name":"foo"`))
Expect(rec.Code).To(Equal(http.StatusOK))
})
})
})
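
Echo has no equivalent of Fiber's app.Test helper, so the test above switches to the standard library's httptest, driving the app in-process as a plain http.Handler. The pattern in isolation (the route and handler here are illustrative):

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"

	"github.com/labstack/echo/v4"
)

func main() {
	app := echo.New()
	app.POST("/import-model", func(c echo.Context) error {
		return c.String(http.StatusOK, "ok")
	})

	// echo.Echo implements http.Handler, so httptest can drive it
	// without opening a real network listener.
	req := httptest.NewRequest("POST", "/import-model", strings.NewReader(`{}`))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	app.ServeHTTP(rec, req)

	fmt.Println(rec.Code, rec.Body.String()) // 200 ok
}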
@@ -1,160 +1,160 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ModelGalleryEndpointService struct {
|
||||
galleries []config.Gallery
|
||||
backendGalleries []config.Gallery
|
||||
modelPath string
|
||||
galleryApplier *services.GalleryService
|
||||
}
|
||||
|
||||
type GalleryModel struct {
|
||||
ID string `json:"id"`
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
backendGalleries: backendGalleries,
|
||||
modelPath: systemState.Model.ModelsPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOpStatusEndpoint returns the job status
|
||||
// @Summary Returns the job status
|
||||
// @Success 200 {object} services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs/{uuid} [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
|
||||
if status == nil {
|
||||
return fmt.Errorf("could not find any status for ID")
|
||||
}
|
||||
return c.JSON(status)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllStatusEndpoint returns all the jobs status progress
|
||||
// @Summary Returns all the jobs status progress
|
||||
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
|
||||
// @Router /models/jobs [get]
|
||||
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
return c.JSON(mgs.galleryApplier.GetAllStatus())
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
|
||||
// @Summary Install models to LocalAI.
|
||||
// @Param request body GalleryModel true "query params"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/apply [post]
|
||||
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input := new(GalleryModel)
|
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel]{
|
||||
Req: input.GalleryModel,
|
||||
ID: uuid.String(),
|
||||
GalleryElementName: input.ID,
|
||||
Galleries: mgs.galleries,
|
||||
BackendGalleries: mgs.backendGalleries,
|
||||
}
|
||||
|
||||
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
|
||||
// @Summary delete models to LocalAI.
|
||||
// @Param name path string true "Model name"
|
||||
// @Success 200 {object} schema.GalleryResponse "Response"
|
||||
// @Router /models/delete/{name} [post]
|
||||
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
modelName := c.Params("name")
|
||||
|
||||
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel]{
|
||||
Delete: true,
|
||||
GalleryElementName: modelName,
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewUUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
|
||||
// @Summary List installable models.
|
||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
||||
// @Router /models/available [get]
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))
|
||||
|
||||
m := []gallery.Metadata{}
|
||||
|
||||
for _, mm := range models {
|
||||
m = append(m, mm.Metadata)
|
||||
}
|
||||
|
||||
log.Debug().Msgf("Models %#v", m)
|
||||
|
||||
dat, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal models: %w", err)
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
|
||||
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
|
||||
// @Summary List all Galleries
|
||||
// @Success 200 {object} []config.Gallery "Response"
|
||||
// @Router /models/galleries [get]
|
||||
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
|
||||
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
|
||||
dat, err := json.Marshal(mgs.galleries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Send(dat)
|
||||
}
|
||||
}
|
||||
package localai

import (
    "encoding/json"
    "fmt"

    "github.com/labstack/echo/v4"
    "github.com/google/uuid"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/gallery"
    "github.com/mudler/LocalAI/core/http/middleware"
    "github.com/mudler/LocalAI/core/schema"
    "github.com/mudler/LocalAI/core/services"
    "github.com/mudler/LocalAI/pkg/system"
    "github.com/rs/zerolog/log"
)

type ModelGalleryEndpointService struct {
    galleries        []config.Gallery
    backendGalleries []config.Gallery
    modelPath        string
    galleryApplier   *services.GalleryService
}

type GalleryModel struct {
    ID string `json:"id"`
    gallery.GalleryModel
}

func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
    return ModelGalleryEndpointService{
        galleries:        galleries,
        backendGalleries: backendGalleries,
        modelPath:        systemState.Model.ModelsPath,
        galleryApplier:   galleryApplier,
    }
}

// GetOpStatusEndpoint returns the job status
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /models/jobs/{uuid} [get]
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
    return func(c echo.Context) error {
        status := mgs.galleryApplier.GetStatus(c.Param("uuid"))
        if status == nil {
            return fmt.Errorf("could not find any status for ID")
        }
        return c.JSON(200, status)
    }
}

// GetAllStatusEndpoint returns the status progress of all jobs
// @Summary Returns the status progress of all jobs
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /models/jobs [get]
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
    return func(c echo.Context) error {
        return c.JSON(200, mgs.galleryApplier.GetAllStatus())
    }
}

// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
// @Summary Install models to LocalAI.
// @Param request body GalleryModel true "query params"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/apply [post]
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc {
    return func(c echo.Context) error {
        input := new(GalleryModel)
        // Get input data from the request body
        if err := c.Bind(input); err != nil {
            return err
        }

        uuid, err := uuid.NewUUID()
        if err != nil {
            return err
        }
        mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
            Req:                input.GalleryModel,
            ID:                 uuid.String(),
            GalleryElementName: input.ID,
            Galleries:          mgs.galleries,
            BackendGalleries:   mgs.backendGalleries,
        }

        return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
    }
}

// DeleteModelGalleryEndpoint deletes models from a LocalAI instance
// @Summary Delete models from LocalAI.
// @Param name path string true "Model name"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/delete/{name} [post]
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc {
    return func(c echo.Context) error {
        modelName := c.Param("name")

        mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
            Delete:             true,
            GalleryElementName: modelName,
        }

        uuid, err := uuid.NewUUID()
        if err != nil {
            return err
        }

        return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
    }
}

// ListModelFromGalleryEndpoint lists the available models for installation from the active galleries
// @Summary List installable models.
// @Success 200 {object} []gallery.GalleryModel "Response"
// @Router /models/available [get]
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
    return func(c echo.Context) error {

        models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
        if err != nil {
            log.Error().Err(err).Msg("could not list models from galleries")
            return err
        }

        log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))

        m := []gallery.Metadata{}

        for _, mm := range models {
            m = append(m, mm.Metadata)
        }

        log.Debug().Msgf("Models %#v", m)

        dat, err := json.Marshal(m)
        if err != nil {
            return fmt.Errorf("could not marshal models: %w", err)
        }
        return c.Blob(200, "application/json", dat)
    }
}

// ListModelGalleriesEndpoint lists the available galleries configured in LocalAI
// @Summary List all Galleries
// @Success 200 {object} []config.Gallery "Response"
// @Router /models/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc {
    return func(c echo.Context) error {
        log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
        dat, err := json.Marshal(mgs.galleries)
        if err != nil {
            return err
        }
        return c.Blob(200, "application/json", dat)
    }
}
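A minimal wiring sketch, not part of this change set: how the Echo handlers above might be registered, following the routes declared in the swagger annotations. The `e`, `galleries`, `backendGalleries`, `systemState`, and `galleryService` names are assumptions; the real registration lives elsewhere in the tree.

    e := echo.New()
    mgs := localai.CreateModelGalleryEndpointService(galleries, backendGalleries, systemState, galleryService)
    e.GET("/models/jobs/:uuid", mgs.GetOpStatusEndpoint()) // :uuid matches c.Param("uuid")
    e.GET("/models/jobs", mgs.GetAllStatusEndpoint())
    e.POST("/models/apply", mgs.ApplyModelGalleryEndpoint())
    e.POST("/models/delete/:name", mgs.DeleteModelGalleryEndpoint())
    e.GET("/models/available", mgs.ListModelFromGalleryEndpoint(systemState))
    e.GET("/models/galleries", mgs.ListModelGalleriesEndpoint())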
@@ -1,7 +1,7 @@
package localai

import (
    "github.com/gofiber/fiber/v2"
    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/backend"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/http/middleware"
@@ -21,17 +21,17 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/tokenMetrics [get]
// @Router /tokenMetrics [get]
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {

        input := new(schema.TokenMetricsRequest)

        // Get input data from the request body
        if err := c.BodyParser(input); err != nil {
        if err := c.Bind(input); err != nil {
            return err
        }

        modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
        modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
        if !ok || modelFile != "" {
            modelFile = input.Model
            log.Warn().Msgf("Model not found in context: %s", input.Model)
@@ -52,6 +52,6 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
        if err != nil {
            return err
        }
        return c.JSON(response)
        return c.JSON(200, response)
    }
}
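The same mechanical translation recurs throughout these hunks. A side-by-side sketch of the pattern, illustrative only (schema.Request and response are placeholders, not code from this change set):

    // Fiber style (before):
    func Endpoint() func(c *fiber.Ctx) error {
        return func(c *fiber.Ctx) error {
            input := new(schema.Request)
            if err := c.BodyParser(input); err != nil { // Fiber body binding
                return err
            }
            return c.JSON(response) // status code defaults to 200
        }
    }

    // Echo style (after):
    func Endpoint() echo.HandlerFunc {
        return func(c echo.Context) error {
            input := new(schema.Request)
            if err := c.Bind(input); err != nil { // Echo body binding
                return err
            }
            return c.JSON(200, response) // status code is explicit
        }
    }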
@@ -2,33 +2,97 @@ package localai

import (
    "encoding/json"
    "fmt"
    "io"
    "net/http"
    "os"
    "path/filepath"
    "strings"

    "github.com/gofiber/fiber/v2"
    "github.com/google/uuid"
    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/gallery"
    "github.com/mudler/LocalAI/core/gallery/importers"
    httpUtils "github.com/mudler/LocalAI/core/http/middleware"
    "github.com/mudler/LocalAI/core/schema"
    "github.com/mudler/LocalAI/core/services"
    "github.com/mudler/LocalAI/pkg/utils"

    "gopkg.in/yaml.v3"
)

// ImportModelURIEndpoint handles creating new model configurations from a URI
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc {
    return func(c echo.Context) error {

        input := new(schema.ImportModelRequest)

        if err := c.Bind(input); err != nil {
            return err
        }

        modelConfig, err := importers.DiscoverModelConfig(input.URI, input.Preferences)
        if err != nil {
            return fmt.Errorf("failed to discover model config: %w", err)
        }

        uuid, err := uuid.NewUUID()
        if err != nil {
            return err
        }

        // Determine gallery ID for tracking - use model name if available, otherwise use URI
        galleryID := input.URI
        if modelConfig.Name != "" {
            galleryID = modelConfig.Name
        }

        // Register operation in opcache if available (for UI progress tracking)
        if opcache != nil {
            opcache.Set(galleryID, uuid.String())
        }

        galleryService.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
            Req: gallery.GalleryModel{
                Overrides: map[string]interface{}{},
            },
            ID:                 uuid.String(),
            GalleryElementName: galleryID,
            GalleryElement:     &modelConfig,
            BackendGalleries:   appConfig.BackendGalleries,
        }

        return c.JSON(200, schema.GalleryResponse{
            ID:        uuid.String(),
            StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()),
        })
    }
}

// ImportModelEndpoint handles creating new model configurations
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
    return func(c *fiber.Ctx) error {
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        // Get the raw body
        body := c.Body()
        body, err := io.ReadAll(c.Request().Body)
        if err != nil {
            response := ModelResponse{
                Success: false,
                Error:   "Failed to read request body: " + err.Error(),
            }
            return c.JSON(http.StatusBadRequest, response)
        }
        if len(body) == 0 {
            response := ModelResponse{
                Success: false,
                Error:   "Request body is empty",
            }
            return c.Status(400).JSON(response)
            return c.JSON(http.StatusBadRequest, response)
        }

        // Check content type to determine how to parse
        contentType := string(c.Context().Request.Header.ContentType())
        contentType := c.Request().Header.Get("Content-Type")
        var modelConfig config.ModelConfig
        var err error

        if strings.Contains(contentType, "application/json") {
            // Parse JSON
@@ -37,7 +101,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
                Success: false,
                Error:   "Failed to parse JSON: " + err.Error(),
            }
            return c.Status(400).JSON(response)
            return c.JSON(http.StatusBadRequest, response)
        }
        } else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
            // Parse YAML
@@ -46,18 +110,18 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
                Success: false,
                Error:   "Failed to parse YAML: " + err.Error(),
            }
            return c.Status(400).JSON(response)
            return c.JSON(http.StatusBadRequest, response)
        }
        } else {
            // Try to auto-detect format
            if strings.TrimSpace(string(body))[0] == '{' {
            if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
                // Looks like JSON
                if err := json.Unmarshal(body, &modelConfig); err != nil {
                    response := ModelResponse{
                        Success: false,
                        Error:   "Failed to parse JSON: " + err.Error(),
                    }
                    return c.Status(400).JSON(response)
                    return c.JSON(http.StatusBadRequest, response)
                }
            } else {
                // Assume YAML
@@ -66,7 +130,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
                    Success: false,
                    Error:   "Failed to parse YAML: " + err.Error(),
                }
                return c.Status(400).JSON(response)
                return c.JSON(http.StatusBadRequest, response)
            }
        }
    }
@@ -77,7 +141,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
            Success: false,
            Error:   "Name is required",
        }
        return c.Status(400).JSON(response)
        return c.JSON(http.StatusBadRequest, response)
    }

    // Set defaults
@@ -89,7 +153,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
            Success: false,
            Error:   "Invalid configuration",
        }
        return c.Status(400).JSON(response)
        return c.JSON(http.StatusBadRequest, response)
    }

    // Create the configuration file
@@ -99,7 +163,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
            Success: false,
            Error:   "Model path not trusted: " + err.Error(),
        }
        return c.Status(400).JSON(response)
        return c.JSON(http.StatusBadRequest, response)
    }

    // Marshal to YAML for storage
@@ -109,7 +173,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
            Success: false,
            Error:   "Failed to marshal configuration: " + err.Error(),
        }
        return c.Status(500).JSON(response)
        return c.JSON(http.StatusInternalServerError, response)
    }

    // Write the file
@@ -118,7 +182,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
            Success: false,
            Error:   "Failed to write configuration file: " + err.Error(),
        }
        return c.Status(500).JSON(response)
        return c.JSON(http.StatusInternalServerError, response)
    }
    // Reload configurations
    if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
@@ -126,7 +190,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
            Success: false,
            Error:   "Failed to reload configurations: " + err.Error(),
        }
        return c.Status(500).JSON(response)
        return c.JSON(http.StatusInternalServerError, response)
    }

    // Preload the model
@@ -135,7 +199,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
            Success: false,
            Error:   "Failed to preload model: " + err.Error(),
        }
        return c.Status(500).JSON(response)
        return c.JSON(http.StatusInternalServerError, response)
    }
    // Return success response
    response := ModelResponse{
@@ -143,6 +207,6 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
        Message:  "Model configuration created successfully",
        Filename: filepath.Base(configPath),
    }
    return c.JSON(response)
    return c.JSON(200, response)
    }
}
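A hypothetical client call against the handler above, assuming it is mounted at /models/import; the mount point is not shown in this hunk, and the YAML fields beyond name are omitted since name is the only field the handler requires.

    package main

    import (
        "bytes"
        "fmt"
        "io"
        "net/http"
    )

    func main() {
        // Minimal YAML model definition; the x-yaml content type
        // triggers the YAML parsing branch in the handler.
        cfg := []byte("name: my-model\n")
        resp, err := http.Post("http://localhost:8080/models/import", // assumed route
            "application/x-yaml", bytes.NewReader(cfg))
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        body, _ := io.ReadAll(resp.Body)
        fmt.Println(resp.Status, string(body))
    }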
core/http/endpoints/localai/mcp.go (new file, 323 lines)
@@ -0,0 +1,323 @@
package localai

import (
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "strings"
    "time"

    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/config"
    mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
    "github.com/mudler/LocalAI/core/http/middleware"
    "github.com/mudler/LocalAI/core/schema"
    "github.com/mudler/LocalAI/core/templates"
    "github.com/mudler/LocalAI/pkg/model"
    "github.com/mudler/cogito"
    "github.com/rs/zerolog/log"
)

// MCP SSE Event Types
type MCPReasoningEvent struct {
    Type    string `json:"type"`
    Content string `json:"content"`
}

type MCPToolCallEvent struct {
    Type      string                 `json:"type"`
    Name      string                 `json:"name"`
    Arguments map[string]interface{} `json:"arguments"`
    Reasoning string                 `json:"reasoning"`
}

type MCPToolResultEvent struct {
    Type   string `json:"type"`
    Name   string `json:"name"`
    Result string `json:"result"`
}

type MCPStatusEvent struct {
    Type    string `json:"type"`
    Message string `json:"message"`
}

type MCPAssistantEvent struct {
    Type    string `json:"type"`
    Content string `json:"content"`
}

type MCPErrorEvent struct {
    Type    string `json:"type"`
    Message string `json:"message"`
}

// MCPStreamEndpoint is the SSE streaming endpoint for MCP chat completions
// @Summary Stream MCP chat completions with reasoning, tool calls, and results
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/mcp/chat/completions [post]
func MCPStreamEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        ctx := c.Request().Context()
        created := int(time.Now().Unix())

        // Handle Correlation
        id := c.Request().Header.Get("X-Correlation-ID")
        if id == "" {
            id = fmt.Sprintf("mcp-%d", time.Now().UnixNano())
        }

        input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
        if !ok || input.Model == "" {
            return echo.ErrBadRequest
        }

        config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        if !ok || config == nil {
            return echo.ErrBadRequest
        }

        if config.MCP.Servers == "" && config.MCP.Stdio == "" {
            return fmt.Errorf("no MCP servers configured")
        }

        // Get MCP config from model config
        remote, stdio, err := config.MCP.MCPConfigFromYAML()
        if err != nil {
            return fmt.Errorf("failed to get MCP config: %w", err)
        }

        // Check if we already have tools cached, otherwise establish an initial connection
        sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
        if err != nil {
            return fmt.Errorf("failed to get MCP sessions: %w", err)
        }

        if len(sessions) == 0 {
            return fmt.Errorf("no working MCP servers found")
        }

        // Build fragment from messages
        fragment := cogito.NewEmptyFragment()
        for _, message := range input.Messages {
            fragment = fragment.AddMessage(message.Role, message.StringContent)
        }

        port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:]
        apiKey := ""
        if len(appConfig.ApiKeys) > 0 {
            apiKey = appConfig.ApiKeys[0]
        }

        ctxWithCancellation, cancel := context.WithCancel(ctx)
        defer cancel()

        // TODO: instead of connecting to the API, we should just wire this internally
        // and act like completion.go.
        // We can do this as cogito expects an interface and we can create one that
        // we satisfy to just call internally ComputeChoices
        defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)

        // Build cogito options using the consolidated method
        cogitoOpts := config.BuildCogitoOptions()
        cogitoOpts = append(
            cogitoOpts,
            cogito.WithContext(ctxWithCancellation),
            cogito.WithMCPs(sessions...),
        )
        // Check if streaming is requested
        toStream := input.Stream

        if !toStream {
            // Non-streaming mode: execute synchronously and return JSON response
            cogitoOpts = append(
                cogitoOpts,
                cogito.WithStatusCallback(func(s string) {
                    log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
                }),
                cogito.WithReasoningCallback(func(s string) {
                    log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
                }),
                cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
                    log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("reasoning", t.Reasoning).Interface("arguments", t.Arguments).Msg("[model agent] Tool call")
                    return true
                }),
                cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
                    log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("result", t.Result).Interface("tool_arguments", t.ToolArguments).Msg("[model agent] Tool call result")
                }),
            )

            f, err := cogito.ExecuteTools(
                defaultLLM, fragment,
                cogitoOpts...,
            )
            if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
                return err
            }

            f, err = defaultLLM.Ask(ctxWithCancellation, f)
            if err != nil {
                return err
            }

            resp := &schema.OpenAIResponse{
                ID:      id,
                Created: created,
                Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
                Choices: []schema.Choice{{Message: &schema.Message{Role: "assistant", Content: &f.LastMessage().Content}}},
                Object:  "chat.completion",
            }

            jsonResult, _ := json.Marshal(resp)
            log.Debug().Msgf("Response: %s", jsonResult)

            // Return the prediction in the response body
            return c.JSON(200, resp)
        }

        // Streaming mode: use SSE
        // Set up SSE headers
        c.Response().Header().Set("Content-Type", "text/event-stream")
        c.Response().Header().Set("Cache-Control", "no-cache")
        c.Response().Header().Set("Connection", "keep-alive")
        c.Response().Header().Set("X-Correlation-ID", id)

        // Create channel for streaming events
        events := make(chan interface{})
        ended := make(chan error, 1)

        // Set up callbacks for streaming
        statusCallback := func(s string) {
            events <- MCPStatusEvent{
                Type:    "status",
                Message: s,
            }
        }

        reasoningCallback := func(s string) {
            events <- MCPReasoningEvent{
                Type:    "reasoning",
                Content: s,
            }
        }

        toolCallCallback := func(t *cogito.ToolChoice) bool {
            events <- MCPToolCallEvent{
                Type:      "tool_call",
                Name:      t.Name,
                Arguments: t.Arguments,
                Reasoning: t.Reasoning,
            }
            return true
        }

        toolCallResultCallback := func(t cogito.ToolStatus) {
            events <- MCPToolResultEvent{
                Type:   "tool_result",
                Name:   t.Name,
                Result: t.Result,
            }
        }

        cogitoOpts = append(cogitoOpts,
            cogito.WithStatusCallback(statusCallback),
            cogito.WithReasoningCallback(reasoningCallback),
            cogito.WithToolCallBack(toolCallCallback),
            cogito.WithToolCallResultCallback(toolCallResultCallback),
        )

        // Execute tools in a goroutine
        go func() {
            defer close(events)

            f, err := cogito.ExecuteTools(
                defaultLLM, fragment,
                cogitoOpts...,
            )
            if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
                events <- MCPErrorEvent{
                    Type:    "error",
                    Message: fmt.Sprintf("Failed to execute tools: %v", err),
                }
                ended <- err
                return
            }

            // Get final response
            f, err = defaultLLM.Ask(ctxWithCancellation, f)
            if err != nil {
                events <- MCPErrorEvent{
                    Type:    "error",
                    Message: fmt.Sprintf("Failed to get response: %v", err),
                }
                ended <- err
                return
            }

            // Stream final assistant response
            content := f.LastMessage().Content
            events <- MCPAssistantEvent{
                Type:    "assistant",
                Content: content,
            }

            ended <- nil
        }()

        // Stream events to client
    LOOP:
        for {
            select {
            case <-ctx.Done():
                // Context was cancelled (client disconnected or request cancelled)
                log.Debug().Msgf("Request context cancelled, stopping stream")
                cancel()
                break LOOP
            case event := <-events:
                if event == nil {
                    // Channel closed
                    break LOOP
                }
                eventData, err := json.Marshal(event)
                if err != nil {
                    log.Debug().Msgf("Failed to marshal event: %v", err)
                    continue
                }
                log.Debug().Msgf("Sending event: %s", string(eventData))
                _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(eventData))
                if err != nil {
                    log.Debug().Msgf("Sending event failed: %v", err)
                    cancel()
                    return err
                }
                c.Response().Flush()
            case err := <-ended:
                if err == nil {
                    // Send done signal
                    fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
                    c.Response().Flush()
                    break LOOP
                }
                log.Error().Msgf("Stream ended with error: %v", err)
                errorEvent := MCPErrorEvent{
                    Type:    "error",
                    Message: err.Error(),
                }
                errorData, marshalErr := json.Marshal(errorEvent)
                if marshalErr != nil {
                    fmt.Fprintf(c.Response().Writer, "data: {\"type\":\"error\",\"message\":\"Internal error\"}\n\n")
                } else {
                    fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
                }
                fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
                c.Response().Flush()
                return nil
            }
        }

        log.Debug().Msgf("Stream ended")
        return nil
    }
}
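A sketch of consuming this SSE stream from a Go client, based on the wire format the handler emits (data: frames, a typed JSON payload per event, and a final [DONE] marker); the model name and host are placeholders.

    package main

    import (
        "bufio"
        "fmt"
        "net/http"
        "strings"
    )

    func main() {
        // Request body per the endpoint above; "my-model" is a placeholder.
        payload := strings.NewReader(`{"model":"my-model","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
        resp, err := http.Post("http://localhost:8080/v1/mcp/chat/completions", "application/json", payload)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        // Each SSE frame is a "data: {...}" line; the stream ends with "data: [DONE]".
        scanner := bufio.NewScanner(resp.Body)
        for scanner.Scan() {
            line := scanner.Text()
            if !strings.HasPrefix(line, "data: ") {
                continue
            }
            data := strings.TrimPrefix(line, "data: ")
            if data == "[DONE]" {
                break
            }
            // JSON with "type": status|reasoning|tool_call|tool_result|assistant|error
            fmt.Println(data)
        }
    }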
@@ -1,46 +1,47 @@
package localai

import (
    "time"

    "github.com/gofiber/fiber/v2"
    "github.com/gofiber/fiber/v2/middleware/adaptor"
    "github.com/mudler/LocalAI/core/services"
    "github.com/prometheus/client_golang/prometheus/promhttp"
)

// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
// @Summary Prometheus metrics endpoint
// @Param request body config.Gallery true "Gallery details"
// @Router /metrics [get]
func LocalAIMetricsEndpoint() fiber.Handler {
    return adaptor.HTTPHandler(promhttp.Handler())
}

type apiMiddlewareConfig struct {
    Filter         func(c *fiber.Ctx) bool
    metricsService *services.LocalAIMetricsService
}

func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
    cfg := apiMiddlewareConfig{
        metricsService: metrics,
        Filter: func(c *fiber.Ctx) bool {
            return c.Path() == "/metrics"
        },
    }

    return func(c *fiber.Ctx) error {
        if cfg.Filter != nil && cfg.Filter(c) {
            return c.Next()
        }
        path := c.Path()
        method := c.Method()

        start := time.Now()
        err := c.Next()
        elapsed := float64(time.Since(start)) / float64(time.Second)
        cfg.metricsService.ObserveAPICall(method, path, elapsed)
        return err
    }
}
package localai

import (
    "time"

    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/services"
    "github.com/prometheus/client_golang/prometheus/promhttp"
)

// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
// @Summary Prometheus metrics endpoint
// @Param request body config.Gallery true "Gallery details"
// @Router /metrics [get]
func LocalAIMetricsEndpoint() echo.HandlerFunc {
    return echo.WrapHandler(promhttp.Handler())
}

type apiMiddlewareConfig struct {
    Filter         func(c echo.Context) bool
    metricsService *services.LocalAIMetricsService
}

func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc {
    cfg := apiMiddlewareConfig{
        metricsService: metrics,
        Filter: func(c echo.Context) bool {
            return c.Path() == "/metrics"
        },
    }

    return func(next echo.HandlerFunc) echo.HandlerFunc {
        return func(c echo.Context) error {
            if cfg.Filter != nil && cfg.Filter(c) {
                return next(c)
            }
            path := c.Path()
            method := c.Request().Method

            start := time.Now()
            err := next(c)
            elapsed := float64(time.Since(start)) / float64(time.Second)
            cfg.metricsService.ObserveAPICall(method, path, elapsed)
            return err
        }
    }
}
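A minimal wiring sketch, assuming names that are not part of this hunk (the actual setup lives in the HTTP app bootstrap): the middleware times every call, while /metrics itself is excluded by the Filter and served by the wrapped Prometheus handler.

    package main

    import (
        "github.com/labstack/echo/v4"
        "github.com/mudler/LocalAI/core/http/endpoints/localai"
        "github.com/mudler/LocalAI/core/services"
    )

    func newRouter(metrics *services.LocalAIMetricsService) *echo.Echo {
        e := echo.New()
        // Observe method/path/duration of every API call except /metrics.
        e.Use(localai.LocalAIMetricsAPIMiddleware(metrics))
        // Serve the Prometheus scrape endpoint via the wrapped stdlib handler.
        e.GET("/metrics", localai.LocalAIMetricsEndpoint())
        return e
    }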
@@ -1,7 +1,7 @@
package localai

import (
    "github.com/gofiber/fiber/v2"
    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/p2p"
    "github.com/mudler/LocalAI/core/schema"
@@ -11,10 +11,10 @@ import (
// @Summary Returns available P2P nodes
// @Success 200 {object} []schema.P2PNodesResponse "Response"
// @Router /api/p2p [get]
func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
    // Render index
    return func(c *fiber.Ctx) error {
        return c.JSON(schema.P2PNodesResponse{
    return func(c echo.Context) error {
        return c.JSON(200, schema.P2PNodesResponse{
            Nodes:          p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
            FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
        })
@@ -25,6 +25,6 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
// @Summary Show the P2P token
// @Success 200 {string} string "Response"
// @Router /api/p2p/token [get]
func ShowP2PToken(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
    return func(c *fiber.Ctx) error { return c.Send([]byte(appConfig.P2PToken)) }
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) }
}
@@ -1,61 +0,0 @@
package localai

import (
    "github.com/gofiber/fiber/v2"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/gallery"
    "github.com/mudler/LocalAI/core/http/utils"
    "github.com/mudler/LocalAI/core/services"
    "github.com/mudler/LocalAI/internal"
    "github.com/mudler/LocalAI/pkg/model"
)

// SettingsEndpoint handles the settings page which shows detailed model/backend management
func SettingsEndpoint(appConfig *config.ApplicationConfig,
    cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
        modelConfigs := cl.GetAllModelsConfigs()
        galleryConfigs := map[string]*gallery.ModelConfig{}

        installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState)
        if err != nil {
            return err
        }

        for _, m := range modelConfigs {
            cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
            if err != nil {
                continue
            }
            galleryConfigs[m.Name] = cfg
        }

        loadedModels := ml.ListLoadedModels()
        loadedModelsMap := map[string]bool{}
        for _, m := range loadedModels {
            loadedModelsMap[m.ID] = true
        }

        modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)

        // Get model statuses to display in the UI the operation in progress
        processingModels, taskTypes := opcache.GetStatus()

        summary := fiber.Map{
            "Title":             "LocalAI - Settings & Management",
            "Version":           internal.PrintableVersion(),
            "BaseURL":           utils.BaseURL(c),
            "Models":            modelsWithoutConfig,
            "ModelsConfig":      modelConfigs,
            "GalleryConfig":     galleryConfigs,
            "ApplicationConfig": appConfig,
            "ProcessingModels":  processingModels,
            "TaskTypes":         taskTypes,
            "LoadedModels":      loadedModelsMap,
            "InstalledBackends": installedBackends,
        }

        // Render settings page
        return c.Render("views/settings", summary)
    }
}
@@ -1,7 +1,7 @@
package localai

import (
    "github.com/gofiber/fiber/v2"
    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/backend"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/schema"
@@ -9,11 +9,11 @@ import (
    "github.com/mudler/LocalAI/pkg/store"
)

func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        input := new(schema.StoresSet)

        if err := c.BodyParser(input); err != nil {
        if err := c.Bind(input); err != nil {
            return err
        }

@@ -28,20 +28,20 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
            vals[i] = []byte(v)
        }

        err = store.SetCols(c.Context(), sb, input.Keys, vals)
        err = store.SetCols(c.Request().Context(), sb, input.Keys, vals)
        if err != nil {
            return err
        }

        return c.Send(nil)
        return c.NoContent(200)
    }
}

func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        input := new(schema.StoresDelete)

        if err := c.BodyParser(input); err != nil {
        if err := c.Bind(input); err != nil {
            return err
        }

@@ -51,19 +51,19 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
        }
        defer sl.Close()

        if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
        if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil {
            return err
        }

        return c.Send(nil)
        return c.NoContent(200)
    }
}

func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        input := new(schema.StoresGet)

        if err := c.BodyParser(input); err != nil {
        if err := c.Bind(input); err != nil {
            return err
        }

@@ -73,7 +73,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
        }
        defer sl.Close()

        keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
        keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys)
        if err != nil {
            return err
        }
@@ -87,15 +87,15 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
            res.Values[i] = string(v)
        }

        return c.JSON(res)
        return c.JSON(200, res)
    }
}

func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        input := new(schema.StoresFind)

        if err := c.BodyParser(input); err != nil {
        if err := c.Bind(input); err != nil {
            return err
        }

@@ -105,7 +105,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
        }
        defer sl.Close()

        keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
        keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk)
        if err != nil {
            return err
        }
@@ -120,6 +120,6 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
            res.Values[i] = string(v)
        }

        return c.JSON(res)
        return c.JSON(200, res)
    }
}
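A hypothetical round trip against the stores endpoints above. Both the mount points (/stores/set, /stores/get) and the JSON field names are assumptions inferred from the handler names and the schema.StoresSet/StoresGet types, neither of which is spelled out in this hunk.

    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
        "net/http"
    )

    func main() {
        // Store one vector key with an associated value (field names assumed).
        set, _ := json.Marshal(map[string]any{
            "keys":   [][]float32{{0.1, 0.2}},
            "values": []string{"hello"},
        })
        if _, err := http.Post("http://localhost:8080/stores/set", "application/json", bytes.NewReader(set)); err != nil {
            panic(err)
        }

        // Read the value back by the same key.
        get, _ := json.Marshal(map[string]any{"keys": [][]float32{{0.1, 0.2}}})
        resp, err := http.Post("http://localhost:8080/stores/get", "application/json", bytes.NewReader(get))
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        var out map[string]any
        json.NewDecoder(resp.Body).Decode(&out)
        fmt.Println(out)
    }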
@@ -1,7 +1,7 @@
package localai

import (
    "github.com/gofiber/fiber/v2"
    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/schema"
    "github.com/mudler/LocalAI/pkg/model"
@@ -11,8 +11,8 @@ import (
// @Summary Show the LocalAI instance information
// @Success 200 {object} schema.SystemInformationResponse "Response"
// @Router /system [get]
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        availableBackends := []string{}
        loadedModels := ml.ListLoadedModels()
        for b := range appConfig.ExternalGRPCBackends {
@@ -26,7 +26,7 @@ func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConf
        for _, m := range loadedModels {
            sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID})
        }
        return c.JSON(
        return c.JSON(200,
            schema.SystemInformationResponse{
                Backends: availableBackends,
                Models:   sysmodels,
@@ -1,7 +1,7 @@
package localai

import (
    "github.com/gofiber/fiber/v2"
    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/backend"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/http/middleware"
@@ -14,22 +14,22 @@ import (
// @Param request body schema.TokenizeRequest true "Request"
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(ctx *fiber.Ctx) error {
        input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
        if !ok || input.Model == "" {
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        if !ok || cfg == nil {
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
        if err != nil {
            return err
        }
        return ctx.JSON(tokenResponse)
        return c.JSON(200, tokenResponse)
    }
}
@@ -1,12 +1,14 @@
package localai

import (
    "path/filepath"

    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/backend"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/http/middleware"
    "github.com/mudler/LocalAI/pkg/model"

    "github.com/gofiber/fiber/v2"
    "github.com/mudler/LocalAI/core/schema"
    "github.com/rs/zerolog/log"

@@ -22,16 +24,16 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/audio/speech [post]
// @Router /tts [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
        input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
        if !ok || input.Model == "" {
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        if !ok || cfg == nil {
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request received")
@@ -59,6 +61,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
            return err
        }

        return c.Download(filePath)
        return c.Attachment(filePath, filepath.Base(filePath))
    }
}
@@ -1,7 +1,7 @@
package localai

import (
    "github.com/gofiber/fiber/v2"
    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/backend"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/http/middleware"
@@ -16,26 +16,26 @@ import (
// @Param request body schema.VADRequest true "query params"
// @Success 200 {object} proto.VADResponse "Response"
// @Router /vad [post]
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
        input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
        if !ok || input.Model == "" {
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        if !ok || cfg == nil {
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request received")

        resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
        resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg)

        if err != nil {
            return err
        }

        return c.JSON(resp)
        return c.JSON(200, resp)
    }
}
@@ -7,19 +7,20 @@ import (
    "fmt"
    "io"
    "net/http"
    "net/url"
    "os"
    "path/filepath"
    "strings"
    "time"

    "github.com/google/uuid"
    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/http/middleware"
    "github.com/mudler/LocalAI/core/schema"

    "github.com/mudler/LocalAI/core/backend"

    "github.com/gofiber/fiber/v2"
    model "github.com/mudler/LocalAI/pkg/model"
    "github.com/rs/zerolog/log"
)
@@ -64,18 +65,18 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.VideoRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /video [post]
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
        input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
    return func(c echo.Context) error {
        input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
        if !ok || input.Model == "" {
            log.Error().Msg("Video Endpoint - Invalid Input")
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        if !ok || config == nil {
            log.Error().Msg("Video Endpoint - Invalid Config")
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        src := ""
@@ -164,7 +165,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
            return err
        }

        baseURL := c.BaseURL()
        baseURL := middleware.BaseURL(c)

        fn, err := backend.VideoGeneration(
            height,
@@ -201,7 +202,10 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
            item.B64JSON = base64.StdEncoding.EncodeToString(data)
        } else {
            base := filepath.Base(output)
            item.URL = baseURL + "/generated-videos/" + base
            item.URL, err = url.JoinPath(baseURL, "generated-videos", base)
            if err != nil {
                return err
            }
        }

        id := uuid.New().String()
@@ -216,6 +220,6 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
        log.Debug().Msgf("Response: %s", jsonResult)

        // Return the prediction in the response body
        return c.JSON(resp)
        return c.JSON(200, resp)
    }
}
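The switch from string concatenation to url.JoinPath avoids malformed URLs when the base URL already carries a trailing slash. A quick standalone illustration (Go 1.19+):

    package main

    import (
        "fmt"
        "net/url"
    )

    func main() {
        // Naive concatenation can produce a double slash.
        fmt.Println("http://localhost:8080/" + "/generated-videos/" + "out.mp4")
        // http://localhost:8080//generated-videos/out.mp4

        // url.JoinPath normalizes the separators.
        u, _ := url.JoinPath("http://localhost:8080/", "generated-videos", "out.mp4")
        fmt.Println(u)
        // http://localhost:8080/generated-videos/out.mp4
    }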
@@ -1,18 +1,20 @@
package localai

import (
    "github.com/gofiber/fiber/v2"
    "strings"

    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/gallery"
    "github.com/mudler/LocalAI/core/http/utils"
    "github.com/mudler/LocalAI/core/http/middleware"
    "github.com/mudler/LocalAI/core/services"
    "github.com/mudler/LocalAI/internal"
    "github.com/mudler/LocalAI/pkg/model"
)

func WelcomeEndpoint(appConfig *config.ApplicationConfig,
    cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
    return func(c *fiber.Ctx) error {
    cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc {
    return func(c echo.Context) error {
        modelConfigs := cl.GetAllModelsConfigs()
        galleryConfigs := map[string]*gallery.ModelConfig{}

@@ -40,10 +42,10 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
        // Get model statuses so the UI can display the operations in progress
        processingModels, taskTypes := opcache.GetStatus()

        summary := fiber.Map{
        summary := map[string]interface{}{
            "Title":         "LocalAI API - " + internal.PrintableVersion(),
            "Version":       internal.PrintableVersion(),
            "BaseURL":       utils.BaseURL(c),
            "BaseURL":       middleware.BaseURL(c),
            "Models":        modelsWithoutConfig,
            "ModelsConfig":  modelConfigs,
            "GalleryConfig": galleryConfigs,
@@ -54,12 +56,21 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
            "InstalledBackends": installedBackends,
        }

        if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
        contentType := c.Request().Header.Get("Content-Type")
        accept := c.Request().Header.Get("Accept")
        // Default to HTML if Accept header is empty (browser behavior)
        // Only return JSON if explicitly requested or Content-Type is application/json
        if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) {
            // The client expects a JSON response
            return c.Status(fiber.StatusOK).JSON(summary)
            return c.JSON(200, summary)
        } else {
            // Render index
            return c.Render("views/index", summary)
            // Check if this is the manage route
            templateName := "views/index"
            if strings.HasSuffix(c.Request().URL.Path, "/manage") || c.Request().URL.Path == "/manage" {
                templateName = "views/manage"
            }
            // Render appropriate template
            return c.Render(200, templateName, summary)
        }
    }
}
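A small client sketch exercising the content negotiation above: explicitly asking for JSON takes the non-HTML branch, while a browser (Accept: text/html, or an empty Accept header) gets the rendered index. Host and port are placeholders.

    package main

    import (
        "fmt"
        "io"
        "net/http"
    )

    func main() {
        req, _ := http.NewRequest("GET", "http://localhost:8080/", nil)
        req.Header.Set("Accept", "application/json") // force the JSON branch
        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        body, _ := io.ReadAll(resp.Body)
        fmt.Println(string(body))
    }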
@@ -1,14 +1,12 @@
package openai

import (
    "bufio"
    "bytes"
    "encoding/json"
    "fmt"
    "time"

    "github.com/gofiber/fiber/v2"
    "github.com/google/uuid"
    "github.com/labstack/echo/v4"
    "github.com/mudler/LocalAI/core/backend"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/http/middleware"
@@ -19,7 +17,6 @@ import (
    "github.com/mudler/LocalAI/pkg/model"

    "github.com/rs/zerolog/log"
    "github.com/valyala/fasthttp"
)

// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
@@ -27,7 +24,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/chat/completions [post]
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc {
    var id, textContentToReturn string
    var created int

@@ -36,7 +33,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
        ID:      id,
        Created: created,
        Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
        Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
        Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
        Object:  "chat.completion.chunk",
    }
    responses <- initialMessage
@@ -56,7 +53,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
        ID:      id,
        Created: created,
        Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
        Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
        Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0, FinishReason: nil}},
        Object:  "chat.completion.chunk",
        Usage:   usage,
    }
@@ -90,7 +87,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
        ID:      id,
        Created: created,
        Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
        Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
        Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
        Object:  "chat.completion.chunk",
    }
    responses <- initialMessage
@@ -114,7 +111,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
        ID:      id,
        Created: created,
        Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
        Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
        Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0, FinishReason: nil}},
        Object:  "chat.completion.chunk",
        Usage:   usage,
    }
@@ -142,7 +139,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
                },
            },
        },
    }}},
    },
        Index:        0,
        FinishReason: nil,
    }},
    Object: "chat.completion.chunk",
    }
    responses <- initialMessage
@@ -165,7 +165,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
                },
            },
        },
    }}},
    },
        Index:        0,
        FinishReason: nil,
    }},
    Object: "chat.completion.chunk",
    }
    }
@@ -175,21 +178,21 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
        return err
    }

    return func(c *fiber.Ctx) error {
    return func(c echo.Context) error {
        textContentToReturn = ""
        id = uuid.New().String()
        created = int(time.Now().Unix())

        input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
        input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
        if !ok || input.Model == "" {
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        extraUsage := c.Get("Extra-Usage", "") != ""
        extraUsage := c.Request().Header.Get("Extra-Usage") != ""

        config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
        if !ok || config == nil {
            return fiber.ErrBadRequest
            return echo.ErrBadRequest
        }

        log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
@@ -217,6 +220,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
            noActionDescription = config.FunctionsConfig.NoActionDescriptionName
        }

        // If we are using a response format, we need to generate a grammar for it
        if config.ResponseFormatMap != nil {
            d := schema.ChatCompletionResponseFormat{}
            dat, err := json.Marshal(config.ResponseFormatMap)
@@ -260,6 +264,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
        }

        switch {
        // Generates grammar with LocalAI's internal engine
        case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn:
            noActionGrammar := functions.Function{
                Name: noActionName,
@@ -283,7 +288,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
                funcs = funcs.Select(config.FunctionToCall())
            }

            // Update input grammar
            // Update input grammar or json_schema based on use_llama_grammar option
            jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
            g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
            if err == nil {
@@ -298,6 +303,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
            } else {
                log.Error().Err(err).Msg("Failed generating grammar")
            }

        default:
            // Force picking one of the functions by the request
            if config.FunctionToCall() != "" {
@@ -316,7 +322,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator

        // If we are using the tokenizer template, we don't need to process the messages
        // unless we are processing functions
        if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
        if !config.TemplateConfig.UseTokenizerTemplate {
            predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn)

            log.Debug().Msgf("Prompt (after templating): %s", predInput)
@@ -329,13 +335,10 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
        case toStream:

            log.Debug().Msgf("Stream request received")
            c.Context().SetContentType("text/event-stream")
            //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
            // c.Set("Content-Type", "text/event-stream")
            c.Set("Cache-Control", "no-cache")
            c.Set("Connection", "keep-alive")
            c.Set("Transfer-Encoding", "chunked")
            c.Set("X-Correlation-ID", id)
            c.Response().Header().Set("Content-Type", "text/event-stream")
            c.Response().Header().Set("Cache-Control", "no-cache")
            c.Response().Header().Set("Connection", "keep-alive")
            c.Response().Header().Set("X-Correlation-ID", id)

            responses := make(chan schema.OpenAIResponse)
            ended := make(chan error, 1)
@@ -348,89 +351,101 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
                }
            }()

            c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
                usage := &schema.OpenAIUsage{}
                toolsCalled := false
            usage := &schema.OpenAIUsage{}
            toolsCalled := false

            LOOP:
                for {
                    select {
                    case ev := <-responses:
                        if len(ev.Choices) == 0 {
                            log.Debug().Msgf("No choices in the response, skipping")
                            continue
                        }
                        usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
                        if len(ev.Choices[0].Delta.ToolCalls) > 0 {
                            toolsCalled = true
                        }
                        var buf bytes.Buffer
                        enc := json.NewEncoder(&buf)
                        enc.Encode(ev)
                        log.Debug().Msgf("Sending chunk: %s", buf.String())
                        _, err := fmt.Fprintf(w, "data: %v\n", buf.String())
                        if err != nil {
                            log.Debug().Msgf("Sending chunk failed: %v", err)
                            input.Cancel()
                        }
                        w.Flush()
                    case err := <-ended:
                        if err == nil {
                            break LOOP
                        }
                        log.Error().Msgf("Stream ended with error: %v", err)

                        resp := &schema.OpenAIResponse{
                            ID:      id,
                            Created: created,
                            Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
                            Choices: []schema.Choice{
                                {
                                    FinishReason: "stop",
                                    Index:        0,
                                    Delta:        &schema.Message{Content: "Internal error: " + err.Error()},
                                }},
                            Object: "chat.completion.chunk",
                            Usage:  *usage,
                        }
                        respData, _ := json.Marshal(resp)

                        w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
                        w.WriteString("data: [DONE]\n\n")
                        w.Flush()

                        return
        LOOP:
            for {
                select {
                case <-input.Context.Done():
                    // Context was cancelled (client disconnected or request cancelled)
                    log.Debug().Msgf("Request context cancelled, stopping stream")
                    input.Cancel()
                    break LOOP
                case ev := <-responses:
                    if len(ev.Choices) == 0 {
                        log.Debug().Msgf("No choices in the response, skipping")
                        continue
                    }
                    usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
                    if len(ev.Choices[0].Delta.ToolCalls) > 0 {
                        toolsCalled = true
                    }
                    respData, err := json.Marshal(ev)
                    if err != nil {
                        log.Debug().Msgf("Failed to marshal response: %v", err)
                        input.Cancel()
                        continue
                    }
                    log.Debug().Msgf("Sending chunk: %s", string(respData))
                    _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
                    if err != nil {
                        log.Debug().Msgf("Sending chunk failed: %v", err)
                        input.Cancel()
                        return err
                    }
                    c.Response().Flush()
                case err := <-ended:
                    if err == nil {
                        break LOOP
                    }
                    log.Error().Msgf("Stream ended with error: %v", err)

                    stopReason := FinishReasonStop
                    resp := &schema.OpenAIResponse{
                        ID:      id,
                        Created: created,
                        Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
                        Choices: []schema.Choice{
                            {
                                FinishReason: &stopReason,
                                Index:        0,
                                Delta:        &schema.Message{Content: "Internal error: " + err.Error()},
                            }},
                        Object: "chat.completion.chunk",
                        Usage:  *usage,
                    }
                    respData, marshalErr := json.Marshal(resp)
                    if marshalErr != nil {
                        log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
                        // Send a simple error message as fallback
                        fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
                    } else {
                        fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
                    }
                    fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
                    c.Response().Flush()

                    return nil
                }
            }

            finishReason := "stop"
            if toolsCalled && len(input.Tools) > 0 {
                finishReason = "tool_calls"
            } else if toolsCalled {
                finishReason = "function_call"
            }
            finishReason := FinishReasonStop
            if toolsCalled && len(input.Tools) > 0 {
                finishReason = FinishReasonToolCalls
            } else if toolsCalled {
                finishReason = FinishReasonFunctionCall
            }

            resp := &schema.OpenAIResponse{
                ID:      id,
                Created: created,
                Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
                Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: finishReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{Content: &textContentToReturn},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
|
||||
w.WriteString("data: [DONE]\n\n")
|
||||
w.Flush()
|
||||
log.Debug().Msgf("Stream ended")
|
||||
}))
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []schema.Choice{
|
||||
{
|
||||
FinishReason: &finishReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{},
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
Usage: *usage,
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
|
||||
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
|
||||
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
|
||||
c.Response().Flush()
|
||||
log.Debug().Msgf("Stream ended")
|
||||
return nil
|
||||
|
||||
// no streaming mode
|
||||
@@ -439,7 +454,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
tokenCallback := func(s string, c *[]schema.Choice) {
|
||||
if !shouldUseFn {
|
||||
// no function is called, just reply and use stop as finish reason
|
||||
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||
stopReason := FinishReasonStop
|
||||
*c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -457,12 +473,14 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
return
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
*c = append(*c, schema.Choice{
|
||||
FinishReason: "stop",
|
||||
FinishReason: &stopReason,
|
||||
Message: &schema.Message{Role: "assistant", Content: &result}})
|
||||
default:
|
||||
toolCallsReason := FinishReasonToolCalls
|
||||
toolChoice := schema.Choice{
|
||||
FinishReason: "tool_calls",
|
||||
FinishReason: &toolCallsReason,
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
},
|
||||
@@ -486,8 +504,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
)
|
||||
} else {
|
||||
// otherwise we return more choices directly (deprecated)
|
||||
functionCallReason := FinishReasonFunctionCall
|
||||
*c = append(*c, schema.Choice{
|
||||
FinishReason: "function_call",
|
||||
FinishReason: &functionCallReason,
|
||||
Message: &schema.Message{
|
||||
Role: "assistant",
|
||||
Content: &textContentToReturn,
|
||||
@@ -508,6 +527,9 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
}
|
||||
|
||||
// Echo properly supports context cancellation via c.Request().Context()
|
||||
// No workaround needed!
|
||||
|
||||
result, tokenUsage, err := ComputeChoices(
|
||||
input,
|
||||
predInput,
|
||||
@@ -543,7 +565,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
log.Debug().Msgf("Response: %s", respData)
|
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp)
|
||||
return c.JSON(200, resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -597,7 +619,48 @@ func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, in
|
||||
audios = append(audios, m.StringAudios...)
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil)
|
||||
// Serialize tools and tool_choice to JSON strings
|
||||
toolsJSON := ""
|
||||
if len(input.Tools) > 0 {
|
||||
toolsBytes, err := json.Marshal(input.Tools)
|
||||
if err == nil {
|
||||
toolsJSON = string(toolsBytes)
|
||||
}
|
||||
}
|
||||
toolChoiceJSON := ""
|
||||
if input.ToolsChoice != nil {
|
||||
toolChoiceBytes, err := json.Marshal(input.ToolsChoice)
|
||||
if err == nil {
|
||||
toolChoiceJSON = string(toolChoiceBytes)
|
||||
}
|
||||
}
|
||||
|
||||
// Extract logprobs from request
|
||||
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
|
||||
var logprobs *int
|
||||
var topLogprobs *int
|
||||
if input.Logprobs.IsEnabled() {
|
||||
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
|
||||
if input.TopLogprobs != nil {
|
||||
topLogprobs = input.TopLogprobs
|
||||
// For backend compatibility, set logprobs to the top_logprobs value
|
||||
logprobs = input.TopLogprobs
|
||||
} else {
|
||||
// Default to 1 if logprobs is true but top_logprobs not specified
|
||||
val := 1
|
||||
logprobs = &val
|
||||
topLogprobs = &val
|
||||
}
|
||||
}
|
||||
|
||||
// Extract logit_bias from request
|
||||
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
|
||||
var logitBias map[string]float64
|
||||
if len(input.LogitBias) > 0 {
|
||||
logitBias = input.LogitBias
|
||||
}
|
||||
|
||||
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("model inference failed")
|
||||
return "", err
|
||||
|
||||
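The streaming path above is the heart of the Fiber-to-Echo migration: instead of handing a callback to fasthttp's SetBodyStreamWriter, the handler now writes SSE frames straight to the response and flushes after each chunk. A minimal sketch of that pattern, assuming Echo v4 (the route and payload below are illustrative, not taken from the diff):

package main

import (
	"fmt"
	"net/http"
	"time"

	"github.com/labstack/echo/v4"
)

// streamHandler shows the SSE pattern used above: set the event-stream
// headers, write "data: ..." frames, and flush after every chunk so the
// client sees output as it is produced.
func streamHandler(c echo.Context) error {
	c.Response().Header().Set("Content-Type", "text/event-stream")
	c.Response().Header().Set("Cache-Control", "no-cache")
	c.Response().Header().Set("Connection", "keep-alive")
	c.Response().WriteHeader(http.StatusOK)

	for i := 0; i < 3; i++ {
		if _, err := fmt.Fprintf(c.Response().Writer, "data: chunk %d\n\n", i); err != nil {
			return err // client went away; stop streaming
		}
		c.Response().Flush() // echo.Response implements http.Flusher
		time.Sleep(100 * time.Millisecond)
	}
	fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
	c.Response().Flush()
	return nil
}

func main() {
	e := echo.New()
	e.GET("/stream", streamHandler)
	e.Logger.Fatal(e.Start(":8080"))
}

Returning the write error (rather than only cancelling, as the old Fiber code did) lets Echo tear the handler down as soon as the client disconnects.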
@@ -1,25 +1,22 @@
 package openai

 import (
-	"bufio"
-	"bytes"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"time"

+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/middleware"

-	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/templates"
 	"github.com/mudler/LocalAI/pkg/functions"
 	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/rs/zerolog/log"
-	"github.com/valyala/fasthttp"
 )

 // CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
@@ -27,7 +24,7 @@ import (
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/completions [post]
-func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
 	process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
 		tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
 			created := int(time.Now().Unix())
@@ -47,8 +44,9 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 			Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
 			Choices: []schema.Choice{
 				{
-					Index: 0,
-					Text:  s,
+					Index:        0,
+					Text:         s,
+					FinishReason: nil,
 				},
 			},
 			Object: "text_completion",
@@ -64,22 +62,25 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 		return err
 	}

-	return func(c *fiber.Ctx) error {
+	return func(c echo.Context) error {

 		created := int(time.Now().Unix())

 		// Handle Correlation
-		id := c.Get("X-Correlation-ID", uuid.New().String())
-		extraUsage := c.Get("Extra-Usage", "") != ""
+		id := c.Request().Header.Get("X-Correlation-ID")
+		if id == "" {
+			id = uuid.New().String()
+		}
+		extraUsage := c.Request().Header.Get("Extra-Usage") != ""

-		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
 		if !ok || input.Model == "" {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

-		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
 		if !ok || config == nil {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

 		if config.ResponseFormatMap != nil {
@@ -97,15 +98,10 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva

 		if input.Stream {
 			log.Debug().Msgf("Stream request received")
-			c.Context().SetContentType("text/event-stream")
-			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
-			//c.Set("Content-Type", "text/event-stream")
-			c.Set("Cache-Control", "no-cache")
-			c.Set("Connection", "keep-alive")
-			c.Set("Transfer-Encoding", "chunked")
-		}
+			c.Response().Header().Set("Content-Type", "text/event-stream")
+			c.Response().Header().Set("Cache-Control", "no-cache")
+			c.Response().Header().Set("Connection", "keep-alive")

-		if input.Stream {
 			if len(config.PromptStrings) > 1 {
 				return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
 			}
@@ -130,53 +126,78 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 			ended <- process(id, predInput, input, config, ml, responses, extraUsage)
 		}()

-		c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
-		LOOP:
-			for {
-				select {
-				case ev := <-responses:
-					if len(ev.Choices) == 0 {
-						log.Debug().Msgf("No choices in the response, skipping")
-						continue
-					}
-					var buf bytes.Buffer
-					enc := json.NewEncoder(&buf)
-					enc.Encode(ev)
-
-					log.Debug().Msgf("Sending chunk: %s", buf.String())
-					fmt.Fprintf(w, "data: %v\n", buf.String())
-					w.Flush()
-				case err := <-ended:
-					if err == nil {
-						break LOOP
-					}
-					log.Error().Msgf("Stream ended with error: %v", err)
-					fmt.Fprintf(w, "data: %v\n", "Internal error: "+err.Error())
-					w.Flush()
-				}
-			}
-
-			resp := &schema.OpenAIResponse{
-				ID:      id,
-				Created: created,
-				Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
-				Choices: []schema.Choice{
-					{
-						Index:        0,
-						FinishReason: "stop",
-					},
-				},
-				Object: "text_completion",
-			}
-			respData, _ := json.Marshal(resp)
-
-			w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
-			w.WriteString("data: [DONE]\n\n")
-			w.Flush()
-		}))
-		return <-ended
+	LOOP:
+		for {
+			select {
+			case ev := <-responses:
+				if len(ev.Choices) == 0 {
+					log.Debug().Msgf("No choices in the response, skipping")
+					continue
+				}
+				respData, err := json.Marshal(ev)
+				if err != nil {
+					log.Debug().Msgf("Failed to marshal response: %v", err)
+					continue
+				}
+
+				log.Debug().Msgf("Sending chunk: %s", string(respData))
+				_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
+				if err != nil {
+					return err
+				}
+				c.Response().Flush()
+			case err := <-ended:
+				if err == nil {
+					break LOOP
+				}
+				log.Error().Msgf("Stream ended with error: %v", err)
+
+				stopReason := FinishReasonStop
+				errorResp := schema.OpenAIResponse{
+					ID:      id,
+					Created: created,
+					Model:   input.Model,
+					Choices: []schema.Choice{
+						{
+							Index:        0,
+							FinishReason: &stopReason,
+							Text:         "Internal error: " + err.Error(),
+						},
+					},
+					Object: "text_completion",
+				}
+				errorData, marshalErr := json.Marshal(errorResp)
+				if marshalErr != nil {
+					log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
+					// Send a simple error message as fallback
+					fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
+				} else {
+					fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
+				}
+				c.Response().Flush()
+				return nil
+			}
+		}
+
+		stopReason := FinishReasonStop
+		resp := &schema.OpenAIResponse{
+			ID:      id,
+			Created: created,
+			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+			Choices: []schema.Choice{
+				{
+					Index:        0,
+					FinishReason: &stopReason,
+				},
+			},
+			Object: "text_completion",
+		}
+		respData, _ := json.Marshal(resp)
+
+		fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
+		fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
+		c.Response().Flush()
+		return nil
 		}

 		var result []schema.Choice
@@ -197,7 +218,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva

 		r, tokenUsage, err := ComputeChoices(
 			input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
-				*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
+				stopReason := FinishReasonStop
+				*c = append(*c, schema.Choice{Text: s, FinishReason: &stopReason, Index: k})
 			}, nil)
 		if err != nil {
 			return err
@@ -231,6 +253,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
 		log.Debug().Msgf("Response: %s", jsonResult)

 		// Return the prediction in the response body
-		return c.JSON(resp)
+		return c.JSON(200, resp)
 	}
 }
core/http/endpoints/openai/constants.go (new file, +8)
@@ -0,0 +1,8 @@
+package openai
+
+// Finish reason constants for OpenAI API responses
+const (
+	FinishReasonStop         = "stop"
+	FinishReasonToolCalls    = "tool_calls"
+	FinishReasonFunctionCall = "function_call"
+)
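These constants replace string literals that were previously scattered across the endpoints. Alongside them, `FinishReason` on `schema.Choice` becomes a string pointer, which is why every call site now introduces a local variable and takes its address: a nil pointer means "not finished yet", which the intermediate stream chunks rely on. A minimal sketch of the pattern (the `Choice` struct here is a simplified stand-in for `schema.Choice`, not its real definition):

package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-in for schema.Choice: a *string with omitempty lets
// intermediate stream chunks omit finish_reason entirely.
type Choice struct {
	Index        int     `json:"index"`
	FinishReason *string `json:"finish_reason,omitempty"`
}

const FinishReasonStop = "stop"

func main() {
	// Intermediate chunk: no finish reason yet.
	mid, _ := json.Marshal(Choice{Index: 0})
	fmt.Println(string(mid)) // {"index":0}

	// Final chunk: take the address of a local copy of the constant,
	// exactly as the endpoints do with stopReason := FinishReasonStop
	// (Go does not allow taking the address of a constant directly).
	stopReason := FinishReasonStop
	end, _ := json.Marshal(Choice{Index: 0, FinishReason: &stopReason})
	fmt.Println(string(end)) // {"index":0,"finish_reason":"stop"}
}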
@@ -4,11 +4,11 @@ import (
 	"encoding/json"
 	"time"

+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/middleware"

-	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/schema"

@@ -23,20 +23,20 @@ import (
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/edits [post]
-func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {

-	return func(c *fiber.Ctx) error {
+	return func(c echo.Context) error {

-		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
 		if !ok || input.Model == "" {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}
 		// Opt-in extra usage flag
-		extraUsage := c.Get("Extra-Usage", "") != ""
+		extraUsage := c.Request().Header.Get("Extra-Usage") != ""

-		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
 		if !ok || config == nil {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

 		log.Debug().Msgf("Edit Endpoint Input : %+v", input)
@@ -98,6 +98,6 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
 		log.Debug().Msgf("Response: %s", jsonResult)

 		// Return the prediction in the response body
-		return c.JSON(resp)
+		return c.JSON(200, resp)
 	}
 }
@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"time"

+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/middleware"
@@ -12,7 +13,6 @@ import (
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/schema"

-	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
 )

@@ -21,16 +21,16 @@ import (
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/embeddings [post]
-func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-	return func(c *fiber.Ctx) error {
-		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
 		if !ok || input.Model == "" {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

-		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
 		if !ok || config == nil {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

 		log.Debug().Msgf("Parameter Config: %+v", config)
@@ -78,6 +78,6 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
 		log.Debug().Msgf("Response: %s", jsonResult)

 		// Return the prediction in the response body
-		return c.JSON(resp)
+		return c.JSON(200, resp)
 	}
 }
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"net/url"
 	"os"
 	"path/filepath"
 	"strconv"
@@ -14,13 +15,13 @@ import (
 	"time"

 	"github.com/google/uuid"
+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"

 	"github.com/mudler/LocalAI/core/backend"

-	"github.com/gofiber/fiber/v2"
 	model "github.com/mudler/LocalAI/pkg/model"
 	"github.com/rs/zerolog/log"
 )
@@ -65,18 +66,18 @@ func downloadFile(url string) (string, error) {
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /v1/images/generations [post]
-func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-	return func(c *fiber.Ctx) error {
-		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
 		if !ok || input.Model == "" {
 			log.Error().Msg("Image Endpoint - Invalid Input")
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

-		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
 		if !ok || config == nil {
 			log.Error().Msg("Image Endpoint - Invalid Config")
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

 		// Process input images (for img2img/inpainting)
@@ -188,7 +189,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
 			return err
 		}

-		baseURL := c.BaseURL()
+		baseURL := middleware.BaseURL(c)

 		// Use the first input image as src if available, otherwise use the original src
 		inputSrc := src
@@ -215,7 +216,10 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
 			item.B64JSON = base64.StdEncoding.EncodeToString(data)
 		} else {
 			base := filepath.Base(output)
-			item.URL = baseURL + "/generated-images/" + base
+			item.URL, err = url.JoinPath(baseURL, "generated-images", base)
+			if err != nil {
+				return err
+			}
 		}

 		result = append(result, *item)
@@ -234,7 +238,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
 		log.Debug().Msgf("Response: %s", jsonResult)

 		// Return the prediction in the response body
-		return c.JSON(resp)
+		return c.JSON(200, resp)
 	}
 }
@@ -1,6 +1,8 @@
 package openai

 import (
+	"encoding/json"
+
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"

@@ -37,8 +39,50 @@ func ComputeChoices(
 		audios = append(audios, m.StringAudios...)
 	}

+	// Serialize tools and tool_choice to JSON strings
+	toolsJSON := ""
+	if len(req.Tools) > 0 {
+		toolsBytes, err := json.Marshal(req.Tools)
+		if err == nil {
+			toolsJSON = string(toolsBytes)
+		}
+	}
+	toolChoiceJSON := ""
+	if req.ToolsChoice != nil {
+		toolChoiceBytes, err := json.Marshal(req.ToolsChoice)
+		if err == nil {
+			toolChoiceJSON = string(toolChoiceBytes)
+		}
+	}
+
+	// Extract logprobs from request
+	// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
+	var logprobs *int
+	var topLogprobs *int
+	if req.Logprobs.IsEnabled() {
+		// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
+		if req.TopLogprobs != nil {
+			topLogprobs = req.TopLogprobs
+			// For backend compatibility, set logprobs to the top_logprobs value
+			logprobs = req.TopLogprobs
+		} else {
+			// Default to 1 if logprobs is true but top_logprobs not specified
+			val := 1
+			logprobs = &val
+			topLogprobs = &val
+		}
+	}
+
+	// Extract logit_bias from request
+	// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
+	var logitBias map[string]float64
+	if len(req.LogitBias) > 0 {
+		logitBias = req.LogitBias
+	}
+
 	// get the model function to call for the result
-	predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback)
+	predFunc, err := backend.ModelInference(
+		req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
 	if err != nil {
 		return result, backend.TokenUsage{}, err
 	}
@@ -59,6 +103,11 @@ func ComputeChoices(
 		finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
 		cb(finetunedResponse, &result)

+		// Add logprobs to the last choice if present
+		if prediction.Logprobs != nil && len(result) > 0 {
+			result[len(result)-1].Logprobs = prediction.Logprobs
+		}
+
 		//result = append(result, Choice{Text: prediction})

 	}
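The logprobs plumbing is duplicated verbatim between `handleQuestion` and `ComputeChoices`; the defaulting rule is small enough to capture in one place. A sketch of that rule in isolation (`resolveLogprobs` is a hypothetical helper, not part of the diff, which inlines the same logic):

package main

import "fmt"

// resolveLogprobs mirrors the defaulting rule used above: when the
// OpenAI-style boolean logprobs flag is on, top_logprobs (0-20) decides
// how many alternatives per token to return, defaulting to 1.
func resolveLogprobs(enabled bool, topLogprobs *int) (logprobs, top *int) {
	if !enabled {
		return nil, nil // feature off: both stay nil
	}
	if topLogprobs != nil {
		// Backend compatibility: logprobs carries the top_logprobs value.
		return topLogprobs, topLogprobs
	}
	one := 1
	return &one, &one
}

func main() {
	three := 3
	lp, top := resolveLogprobs(true, &three)
	fmt.Println(*lp, *top) // 3 3

	lp, top = resolveLogprobs(true, nil)
	fmt.Println(*lp, *top) // 1 1
}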
@@ -1,7 +1,7 @@
 package openai

 import (
-	"github.com/gofiber/fiber/v2"
+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
@@ -12,14 +12,15 @@ import (
 // @Summary List and describe the various models available in the API.
 // @Success 200 {object} schema.ModelsDataResponse "Response"
 // @Router /v1/models [get]
-func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
-	return func(c *fiber.Ctx) error {
+func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+	return func(c echo.Context) error {
 		// If blank, no filter is applied.
-		filter := c.Query("filter")
+		filter := c.QueryParam("filter")

 		// By default, exclude any loose files that are already referenced by a configuration file.
 		var policy services.LooseFilePolicy
-		if c.QueryBool("excludeConfigured", true) {
+		excludeConfigured := c.QueryParam("excludeConfigured")
+		if excludeConfigured == "" || excludeConfigured == "true" {
 			policy = services.SKIP_IF_CONFIGURED
 		} else {
 			policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
@@ -41,7 +42,7 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
 		dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
 	}

-	return c.JSON(schema.ModelsDataResponse{
+	return c.JSON(200, schema.ModelsDataResponse{
 		Object: "list",
 		Data:   dataModels,
 	})
@@ -1,17 +1,18 @@
 package openai

 import (
 	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"strings"
 	"time"

+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/config"
 	mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
 	"github.com/mudler/LocalAI/core/http/middleware"

-	"github.com/gofiber/fiber/v2"
 	"github.com/google/uuid"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/templates"
@@ -25,24 +26,27 @@ import (
 // @Param request body schema.OpenAIRequest true "query params"
 // @Success 200 {object} schema.OpenAIResponse "Response"
 // @Router /mcp/v1/completions [post]
-func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
+func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
 	// We do not support streaming mode (Yet?)
-	return func(c *fiber.Ctx) error {
+	return func(c echo.Context) error {
 		created := int(time.Now().Unix())

-		ctx := c.Context()
+		ctx := c.Request().Context()

 		// Handle Correlation
-		id := c.Get("X-Correlation-ID", uuid.New().String())
-
-		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
-		if !ok || input.Model == "" {
-			return fiber.ErrBadRequest
+		id := c.Request().Header.Get("X-Correlation-ID")
+		if id == "" {
+			id = uuid.New().String()
 		}

-		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+		if !ok || input.Model == "" {
+			return echo.ErrBadRequest
+		}
+
+		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
 		if !ok || config == nil {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

 		if config.MCP.Servers == "" && config.MCP.Stdio == "" {
@@ -50,12 +54,15 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
 		}

 		// Get MCP config from model config
-		remote, stdio := config.MCP.MCPConfigFromYAML()
+		remote, stdio, err := config.MCP.MCPConfigFromYAML()
+		if err != nil {
+			return fmt.Errorf("failed to get MCP config: %w", err)
+		}

 		// Check if we have tools in cache, or we have to have an initial connection
 		sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to get MCP sessions: %w", err)
 		}

 		if len(sessions) == 0 {
@@ -73,46 +80,37 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
 		if appConfig.ApiKeys != nil {
 			apiKey = appConfig.ApiKeys[0]
 		}

-		ctxWithCancellation, cancel := context.WithCancel(ctx)
-		defer cancel()
-
 		// TODO: instead of connecting to the API, we should just wire this internally
 		// and act like completion.go.
 		// We can do this as cogito expects an interface and we can create one that
 		// we satisfy to just call internally ComputeChoices
 		defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)

-		cogitoOpts := []cogito.Option{
-			cogito.WithContext(ctxWithCancellation),
-			cogito.WithMCPs(sessions...),
-			cogito.WithStatusCallback(func(s string) {
-				log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
-			}),
-			cogito.WithIterations(3),  // default to 3 iterations
-			cogito.WithMaxAttempts(3), // default to 3 attempts
-			cogito.WithForceReasoning(),
-		}
-
-		if config.Agent.EnableReasoning {
-			cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
-		}
-
-		if config.Agent.EnablePlanning {
-			cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
-		}
-
-		if config.Agent.EnableMCPPrompts {
-			cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
-		}
-
-		if config.Agent.EnablePlanReEvaluator {
-			cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
-		}
-
-		if config.Agent.MaxIterations != 0 {
-			cogitoOpts = append(cogitoOpts, cogito.WithIterations(config.Agent.MaxIterations))
-		}
-
-		if config.Agent.MaxAttempts != 0 {
-			cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(config.Agent.MaxAttempts))
-		}
+		// Build cogito options using the consolidated method
+		cogitoOpts := config.BuildCogitoOptions()
+
+		cogitoOpts = append(
+			cogitoOpts,
+			cogito.WithContext(ctx),
+			cogito.WithMCPs(sessions...),
+			cogito.WithStatusCallback(func(s string) {
+				log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
+			}),
+			cogito.WithReasoningCallback(func(s string) {
+				log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
+			}),
+			cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
+				// Note: config.Name added as the first argument so the format
+				// string's four verbs all have a value (the extracted diff was
+				// missing the model name argument).
+				log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", config.Name, t.Name, t.Reasoning, t.Arguments)
+				return true
+			}),
+			cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
+				log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments)
+			}),
+		)

 		f, err := cogito.ExecuteTools(
 			defaultLLM, fragment,
@@ -139,6 +137,6 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
 		log.Debug().Msgf("Response: %s", jsonResult)

 		// Return the prediction in the response body
-		return c.JSON(resp)
+		return c.JSON(200, resp)
 	}
 }
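The refactor above folds the inline `if config.Agent.*` ladder into a single `config.BuildCogitoOptions()` call. Only the call site appears in this diff, so the following is a hypothetical reconstruction of what that consolidated builder could look like, inferred from the option ladder it replaces (the `AgentConfig`/`ModelConfig` stubs are simplified stand-ins for the real config types):

package config

import cogito "github.com/mudler/cogito"

// AgentConfig is a simplified stand-in for the real agent settings struct.
type AgentConfig struct {
	EnableReasoning       bool
	EnablePlanning        bool
	EnableMCPPrompts      bool
	EnablePlanReEvaluator bool
	MaxIterations         int
	MaxAttempts           int
}

// ModelConfig is likewise reduced to the field this sketch needs.
type ModelConfig struct {
	Agent AgentConfig
}

// BuildCogitoOptions is a hypothetical reconstruction: defaults first,
// then per-model overrides, mirroring the removed inline ladder. Options
// appended later are assumed to override the earlier defaults.
func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
	opts := []cogito.Option{
		cogito.WithIterations(3),  // default to 3 iterations
		cogito.WithMaxAttempts(3), // default to 3 attempts
		cogito.WithForceReasoning(),
	}
	if c.Agent.EnableReasoning {
		opts = append(opts, cogito.EnableToolReasoner)
	}
	if c.Agent.EnablePlanning {
		opts = append(opts, cogito.EnableAutoPlan)
	}
	if c.Agent.EnableMCPPrompts {
		opts = append(opts, cogito.EnableMCPPrompts)
	}
	if c.Agent.EnablePlanReEvaluator {
		opts = append(opts, cogito.EnableAutoPlanReEvaluator)
	}
	if c.Agent.MaxIterations != 0 {
		opts = append(opts, cogito.WithIterations(c.Agent.MaxIterations))
	}
	if c.Agent.MaxAttempts != 0 {
		opts = append(opts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
	}
	return opts
}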
@@ -10,9 +10,11 @@ import (
 	"sync"
 	"time"

+	"net/http"
+
 	"github.com/go-audio/audio"
-	"github.com/gofiber/fiber/v2"
-	"github.com/gofiber/websocket/v2"
+	"github.com/gorilla/websocket"
+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/application"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
@@ -167,32 +169,50 @@ type Model interface {
 	PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
 }

+var upgrader = websocket.Upgrader{
+	CheckOrigin: func(r *http.Request) bool {
+		return true // Allow all origins
+	},
+}
+
 // TODO: Implement ephemeral keys to allow these endpoints to be used
-func RealtimeSessions(application *application.Application) fiber.Handler {
-	return func(ctx *fiber.Ctx) error {
-		return ctx.SendStatus(501)
+func RealtimeSessions(application *application.Application) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		return c.NoContent(501)
 	}
 }

-func RealtimeTranscriptionSession(application *application.Application) fiber.Handler {
-	return func(ctx *fiber.Ctx) error {
-		return ctx.SendStatus(501)
+func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		return c.NoContent(501)
 	}
 }

-func Realtime(application *application.Application) fiber.Handler {
-	return websocket.New(registerRealtime(application))
+func Realtime(application *application.Application) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
+		if err != nil {
+			return err
+		}
+		defer ws.Close()
+
+		// Extract query parameters from Echo context before passing to websocket handler
+		model := c.QueryParam("model")
+		if model == "" {
+			model = "gpt-4o"
+		}
+		intent := c.QueryParam("intent")
+
+		registerRealtime(application, model, intent)(ws)
+		return nil
+	}
 }

-func registerRealtime(application *application.Application) func(c *websocket.Conn) {
+func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) {
 	return func(c *websocket.Conn) {

 		evaluator := application.TemplatesEvaluator()
 		log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())

-		model := c.Query("model", "gpt-4o")
-
-		intent := c.Query("intent")
 		if intent != "transcription" {
 			sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter")
 		}
@@ -1067,12 +1087,13 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
 	// For example, the model might return a special token or JSON indicating a function call

 	/*
-		predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
+		predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil, "", "", nil, nil, nil)

 		result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
 			if !shouldUseFn {
 				// no function is called, just reply and use stop as finish reason
-				*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
+				stopReason := FinishReasonStop
+				*c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
 				return
 			}

@@ -1099,7 +1120,8 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
 		}

 		if len(input.Tools) > 0 {
-			toolChoice.FinishReason = "tool_calls"
+			toolCallsReason := FinishReasonToolCalls
+			toolChoice.FinishReason = &toolCallsReason
 		}

 		for _, ss := range results {
@@ -1120,8 +1142,9 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
 			)
 		} else {
 			// otherwise we return more choices directly
+			functionCallReason := FinishReasonFunctionCall
 			*c = append(*c, schema.Choice{
-				FinishReason: "function_call",
+				FinishReason: &functionCallReason,
 				Message: &schema.Message{
 					Role:    "assistant",
 					Content: &textContentToReturn,
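One design note on the gorilla/websocket upgrader introduced above: `CheckOrigin` returning `true` accepts upgrades from any origin, which matches the previous Fiber behavior. If the server is ever exposed directly to browsers, tightening that check is a one-liner; a hedged sketch (the allowed host is an invented example, not a value from the diff):

package openai

import (
	"net/http"

	"github.com/gorilla/websocket"
)

// restrictedUpgrader sketches a stricter origin policy for gorilla/websocket.
// Browsers always send an Origin header on WebSocket upgrades, so rejecting
// unknown origins blocks cross-site pages while still admitting non-browser
// clients, which send no Origin header at all.
var restrictedUpgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		origin := r.Header.Get("Origin")
		return origin == "" || origin == "https://localai.example.com"
	},
}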
@@ -7,13 +7,13 @@ import (
 	"path"
 	"path/filepath"

+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/backend"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"
 	model "github.com/mudler/LocalAI/pkg/model"

-	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
 )

@@ -24,19 +24,19 @@ import (
 // @Param file formData file true "file"
 // @Success 200 {object} map[string]string "Response"
 // @Router /v1/audio/transcriptions [post]
-func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-	return func(c *fiber.Ctx) error {
-		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
 		if !ok || input.Model == "" {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

-		config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+		config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
 		if !ok || config == nil {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}

-		diarize := c.FormValue("diarize", "false") != "false"
+		diarize := c.FormValue("diarize") != "false"

 		// retrieve the file data from the request
 		file, err := c.FormFile("file")
@@ -76,6 +76,6 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app

 		log.Debug().Msgf("Transcribed: %+v", tr)
 		// TODO: handle different outputs here
-		return c.Status(http.StatusOK).JSON(tr)
+		return c.JSON(http.StatusOK, tr)
 	}
 }
@@ -6,7 +6,7 @@ import (
 	"strconv"
 	"strings"

-	"github.com/gofiber/fiber/v2"
+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/localai"
 	"github.com/mudler/LocalAI/core/http/middleware"
@@ -14,20 +14,24 @@ import (
 	model "github.com/mudler/LocalAI/pkg/model"
 )

-func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
-	return func(c *fiber.Ctx) error {
-		input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
+	return func(c echo.Context) error {
+		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
 		if !ok || input == nil {
-			return fiber.ErrBadRequest
+			return echo.ErrBadRequest
 		}
 		var raw map[string]interface{}
-		if body := c.Body(); len(body) > 0 {
+		// Read the whole body with io.ReadAll (calling Read on a zero-length
+		// slice, as the extracted diff did, consumes nothing); assumes "io"
+		// is imported.
+		var body []byte
+		if c.Request().Body != nil {
+			body, _ = io.ReadAll(c.Request().Body)
+		}
+		if len(body) > 0 {
 			_ = json.Unmarshal(body, &raw)
 		}
 		// Build VideoRequest using shared mapper
 		vr := MapOpenAIToVideo(input, raw)
-		// Place VideoRequest into locals so localai.VideoEndpoint can consume it
-		c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
+		// Place VideoRequest into context so localai.VideoEndpoint can consume it
+		c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
 		// Delegate to existing localai handler
 		return localai.VideoEndpoint(cl, ml, appConfig)(c)
 	}
@@ -1,48 +1,50 @@
 package http

 import (
+	"io/fs"
 	"net/http"

-	"github.com/gofiber/fiber/v2"
-	"github.com/gofiber/fiber/v2/middleware/favicon"
-	"github.com/gofiber/fiber/v2/middleware/filesystem"
+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/explorer"
 	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/http/routes"
+	"github.com/rs/zerolog/log"
 )

-func Explorer(db *explorer.Database) *fiber.App {
+func Explorer(db *explorer.Database) *echo.Echo {
+	e := echo.New()

-	fiberCfg := fiber.Config{
-		Views: renderEngine(),
-		// We disable the Fiber startup message as it does not conform to structured logging.
-		// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
-		DisableStartupMessage: false,
-		// Override default error handler
-	}
+	// Set renderer
+	e.Renderer = renderEngine()
+
+	// Hide banner
+	e.HideBanner = true
+
+	e.Pre(middleware.StripPathPrefix())
+	routes.RegisterExplorerRoutes(e, db)
+
+	// Favicon handler
+	e.GET("/favicon.svg", func(c echo.Context) error {
+		data, err := embedDirStatic.ReadFile("static/favicon.svg")
+		if err != nil {
+			return c.NoContent(http.StatusNotFound)
+		}
+		c.Response().Header().Set("Content-Type", "image/svg+xml")
+		return c.Blob(http.StatusOK, "image/svg+xml", data)
+	})
+
+	// Static files - use fs.Sub to create a filesystem rooted at "static"
+	staticFS, err := fs.Sub(embedDirStatic, "static")
+	if err != nil {
+		// Log error but continue - static files might not work
+		log.Error().Err(err).Msg("failed to create static filesystem")
+	} else {
+		e.StaticFS("/static", staticFS)
+	}

-	app := fiber.New(fiberCfg)
-
-	app.Use(middleware.StripPathPrefix())
-	routes.RegisterExplorerRoutes(app, db)
-
-	httpFS := http.FS(embedDirStatic)
-
-	app.Use(favicon.New(favicon.Config{
-		URL:        "/favicon.svg",
-		FileSystem: httpFS,
-		File:       "static/favicon.svg",
-	}))
-
-	app.Use("/static", filesystem.New(filesystem.Config{
-		Root:       httpFS,
-		PathPrefix: "static",
-		Browse:     true,
-	}))

 	// Define a custom 404 handler
 	// Note: keep this at the bottom!
-	app.Use(notFoundHandler)
+	e.GET("/*", notFoundHandler)

-	return app
+	return e
 }
@@ -3,50 +3,108 @@ package middleware
 import (
 	"crypto/subtle"
 	"errors"
+	"net/http"
 	"strings"

-	"github.com/dave-gray101/v2keyauth"
-	"github.com/gofiber/fiber/v2"
-	"github.com/gofiber/fiber/v2/middleware/keyauth"
+	"github.com/labstack/echo/v4"
+	"github.com/labstack/echo/v4/middleware"
 	"github.com/mudler/LocalAI/core/config"
-	"github.com/mudler/LocalAI/core/http/utils"
 	"github.com/mudler/LocalAI/core/schema"
 )

-// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
-// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
-// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
+var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")

-func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
-	customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key", "cookie:token"}, keyauth.ConfigDefault.AuthScheme)
-	if err != nil {
-		return nil, err
-	}
+// GetKeyAuthConfig returns Echo's KeyAuth middleware configuration
+func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) {
+	// Create validator function
+	validator := getApiKeyValidationFunction(applicationConfig)

-	return &v2keyauth.Config{
-		CustomKeyLookup: customLookup,
-		Next:            getApiKeyRequiredFilterFunction(applicationConfig),
-		Validator:       getApiKeyValidationFunction(applicationConfig),
-		ErrorHandler:    getApiKeyErrorHandler(applicationConfig),
-		AuthScheme:      "Bearer",
+	// Create error handler
+	errorHandler := getApiKeyErrorHandler(applicationConfig)
+
+	// Create Next function (skip middleware for certain requests)
+	skipper := getApiKeyRequiredFilterFunction(applicationConfig)
+
+	// Wrap it with our custom key lookup that checks multiple sources
+	return func(next echo.HandlerFunc) echo.HandlerFunc {
+		return func(c echo.Context) error {
+			if len(applicationConfig.ApiKeys) == 0 {
+				return next(c)
+			}
+
+			// Skip if skipper says so
+			if skipper != nil && skipper(c) {
+				return next(c)
+			}
+
+			// Try to extract key from multiple sources
+			key, err := extractKeyFromMultipleSources(c)
+			if err != nil {
+				return errorHandler(err, c)
+			}
+
+			// Validate the key
+			valid, err := validator(key, c)
+			if err != nil || !valid {
+				return errorHandler(ErrMissingOrMalformedAPIKey, c)
+			}
+
+			// Store key in context for later use
+			c.Set("api_key", key)
+
+			return next(c)
+		}
 	}, nil
 }

+// extractKeyFromMultipleSources checks multiple sources for the API key
+// in order: Authorization header, x-api-key header, xi-api-key header, token cookie
+func extractKeyFromMultipleSources(c echo.Context) (string, error) {
+	// Check Authorization header first
+	auth := c.Request().Header.Get("Authorization")
+	if auth != "" {
+		// Check for Bearer scheme
+		if strings.HasPrefix(auth, "Bearer ") {
+			return strings.TrimPrefix(auth, "Bearer "), nil
+		}
+		// If no Bearer prefix, return as-is (for backward compatibility)
+		return auth, nil
+	}
+
+	// Check x-api-key header
+	if key := c.Request().Header.Get("x-api-key"); key != "" {
+		return key, nil
+	}
+
+	// Check xi-api-key header
+	if key := c.Request().Header.Get("xi-api-key"); key != "" {
+		return key, nil
+	}
+
+	// Check token cookie
+	cookie, err := c.Cookie("token")
+	if err == nil && cookie != nil && cookie.Value != "" {
+		return cookie.Value, nil
+	}
+
+	return "", ErrMissingOrMalformedAPIKey
+}

-func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
-	return func(ctx *fiber.Ctx, err error) error {
-		if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
+func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
+	return func(err error, c echo.Context) error {
+		if errors.Is(err, ErrMissingOrMalformedAPIKey) {
 			if len(applicationConfig.ApiKeys) == 0 {
-				return ctx.Next() // if no keys are set up, any error we get here is not an error.
+				return nil // if no keys are set up, any error we get here is not an error.
 			}
-			ctx.Set("WWW-Authenticate", "Bearer")
+			c.Response().Header().Set("WWW-Authenticate", "Bearer")
 			if applicationConfig.OpaqueErrors {
-				return ctx.SendStatus(401)
+				return c.NoContent(http.StatusUnauthorized)
 			}

 			// Check if the request content type is JSON
-			contentType := string(ctx.Context().Request.Header.ContentType())
+			contentType := c.Request().Header.Get("Content-Type")
 			if strings.Contains(contentType, "application/json") {
-				return ctx.Status(401).JSON(schema.ErrorResponse{
+				return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
 					Error: &schema.APIError{
 						Message: "An authentication key is required",
 						Code:    401,
@@ -55,50 +113,69 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.Er
 				})
 			}

-			return ctx.Status(401).Render("views/login", fiber.Map{
-				"BaseURL": utils.BaseURL(ctx),
+			return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{
+				"BaseURL": BaseURL(c),
 			})
 		}
 		if applicationConfig.OpaqueErrors {
-			return ctx.SendStatus(500)
+			return c.NoContent(http.StatusInternalServerError)
 		}
 		return err
 	}
 }

-func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
-
+func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) {
 	if applicationConfig.UseSubtleKeyComparison {
-		return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
+		return func(key string, c echo.Context) (bool, error) {
 			if len(applicationConfig.ApiKeys) == 0 {
 				return true, nil // If no keys are setup, accept everything
 			}
 			for _, validKey := range applicationConfig.ApiKeys {
-				if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
+				if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
 					return true, nil
 				}
 			}
-			return false, v2keyauth.ErrMissingOrMalformedAPIKey
+			return false, ErrMissingOrMalformedAPIKey
 		}
 	}

-	return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
+	return func(key string, c echo.Context) (bool, error) {
 		if len(applicationConfig.ApiKeys) == 0 {
 			return true, nil // If no keys are setup, accept everything
 		}
 		for _, validKey := range applicationConfig.ApiKeys {
-			if apiKey == validKey {
+			if key == validKey {
 				return true, nil
 			}
 		}
-		return false, v2keyauth.ErrMissingOrMalformedAPIKey
+		return false, ErrMissingOrMalformedAPIKey
 	}
 }

-func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
-	if applicationConfig.DisableApiKeyRequirementForHttpGet {
-		return func(c *fiber.Ctx) bool {
-			if c.Method() != "GET" {
+func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper {
+	return func(c echo.Context) bool {
+		path := c.Request().URL.Path
+
+		// Always skip authentication for static files
+		if strings.HasPrefix(path, "/static/") {
+			return true
+		}
+
+		// Always skip authentication for generated content
+		if strings.HasPrefix(path, "/generated-audio/") ||
+			strings.HasPrefix(path, "/generated-images/") ||
+			strings.HasPrefix(path, "/generated-videos/") {
+			return true
+		}
+
+		// Skip authentication for favicon
+		if path == "/favicon.svg" {
+			return true
+		}
+
+		// Handle GET request exemptions if enabled
+		if applicationConfig.DisableApiKeyRequirementForHttpGet {
+			if c.Request().Method != http.MethodGet {
 				return false
 			}
 			for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
@@ -106,8 +183,8 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig
 				return true
 			}
 		}
 		return false
 	}

 		return false
 	}
-	return func(c *fiber.Ctx) bool { return false }
 }
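The multi-source lookup above replaces the declarative `v2keyauth` configuration with a hand-rolled Echo middleware, so the extraction order is now explicit code rather than a lookup string. A quick sanity check of that order with Go's httptest (a sketch; the key value is illustrative):

package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/labstack/echo/v4"
)

// TestKeySources sketches how each supported source carries the API key
// into extractKeyFromMultipleSources; values here are illustrative.
func TestKeySources(t *testing.T) {
	e := echo.New()

	for _, set := range []func(*http.Request){
		func(r *http.Request) { r.Header.Set("Authorization", "Bearer secret") },
		func(r *http.Request) { r.Header.Set("x-api-key", "secret") },
		func(r *http.Request) { r.Header.Set("xi-api-key", "secret") },
		func(r *http.Request) { r.AddCookie(&http.Cookie{Name: "token", Value: "secret"}) },
	} {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		set(req)
		c := e.NewContext(req, httptest.NewRecorder())

		key, err := extractKeyFromMultipleSources(c)
		if err != nil || key != "secret" {
			t.Fatalf("expected key %q, got %q (err: %v)", "secret", key, err)
		}
	}
}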
core/http/middleware/baseurl.go (new file, 48 lines)
@@ -0,0 +1,48 @@
package middleware

import (
	"strings"

	"github.com/labstack/echo/v4"
)

// BaseURL returns the base URL for the given HTTP request context.
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
// The returned URL is guaranteed to end with `/`.
// The method should be used in conjunction with the StripPathPrefix middleware.
func BaseURL(c echo.Context) string {
	path := c.Path()
	origPath := c.Request().URL.Path

	// Check if StripPathPrefix middleware stored the original path
	if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" {
		origPath = storedPath
	}

	// Check X-Forwarded-Proto for scheme
	scheme := "http"
	if c.Request().Header.Get("X-Forwarded-Proto") == "https" {
		scheme = "https"
	} else if c.Request().TLS != nil {
		scheme = "https"
	}

	// Check X-Forwarded-Host for host
	host := c.Request().Host
	if forwardedHost := c.Request().Header.Get("X-Forwarded-Host"); forwardedHost != "" {
		host = forwardedHost
	}

	if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 {
		prefixLen := len(origPath) - len(path)
		if prefixLen > 0 && prefixLen <= len(origPath) {
			pathPrefix := origPath[:prefixLen]
			if !strings.HasSuffix(pathPrefix, "/") {
				pathPrefix += "/"
			}
			return scheme + "://" + host + pathPrefix
		}
	}

	return scheme + "://" + host + "/"
}
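For orientation, a hypothetical usage sketch of BaseURL behind a reverse proxy (the route, handler body, and header values are invented; only the BaseURL signature and the "_original_path" contract come from the file above):

package main

import (
	"net/http"

	"github.com/labstack/echo/v4"
	mw "github.com/mudler/LocalAI/core/http/middleware"
)

func main() {
	e := echo.New()
	// StripPathPrefix stores "_original_path" so BaseURL can recover the proxy prefix.
	e.Pre(mw.StripPathPrefix())
	e.GET("/hello", func(c echo.Context) error {
		// With a proxy sending X-Forwarded-Proto: https, X-Forwarded-Host: example.org
		// and X-Forwarded-Prefix: /localai/, base would be "https://example.org/localai/".
		base := mw.BaseURL(c)
		return c.String(http.StatusOK, "docs: "+base+"swagger/index.html")
	})
	e.Logger.Fatal(e.Start(":8080"))
}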
core/http/middleware/baseurl_test.go (new file, 58 lines)
@@ -0,0 +1,58 @@
package middleware

import (
	"net/http/httptest"

	"github.com/labstack/echo/v4"
	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

var _ = Describe("BaseURL", func() {
	Context("without prefix", func() {
		It("should return base URL without prefix", func() {
			app := echo.New()
			actualURL := ""

			// Register route - use the actual request path so routing works
			routePath := "/hello/world"
			app.GET(routePath, func(c echo.Context) error {
				actualURL = BaseURL(c)
				return nil
			})

			req := httptest.NewRequest("GET", "/hello/world", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(200), "response status code")
			Expect(actualURL).To(Equal("http://example.com/"), "base URL")
		})
	})

	Context("with prefix", func() {
		It("should return base URL with prefix", func() {
			app := echo.New()
			actualURL := ""

			// Register route with the stripped path (after middleware removes prefix)
			routePath := "/hello/world"
			app.GET(routePath, func(c echo.Context) error {
				// Simulate what StripPathPrefix middleware does - store original path
				c.Set("_original_path", "/myprefix/hello/world")
				// Modify the request path to simulate prefix stripping
				c.Request().URL.Path = "/hello/world"
				actualURL = BaseURL(c)
				return nil
			})

			// Make request with stripped path (middleware would have already processed it)
			req := httptest.NewRequest("GET", "/hello/world", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(200), "response status code")
			Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL")
		})
	})
})
@@ -1,174 +0,0 @@ (deleted file)
package middleware

import (
	"encoding/json"
	"strings"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/mudler/LocalAI/core/services"
	"github.com/rs/zerolog/log"
)

// MetricsMiddleware creates a middleware that tracks API usage metrics
// Note: Uses CONTEXT_LOCALS_KEY_MODEL_NAME constant defined in request.go
func MetricsMiddleware(metricsStore services.MetricsStore) fiber.Handler {
	return func(c *fiber.Ctx) error {
		path := c.Path()

		// Skip tracking for UI routes, static files, and non-API endpoints
		if shouldSkipMetrics(path) {
			return c.Next()
		}

		// Record start time
		start := time.Now()

		// Get endpoint category
		endpoint := categorizeEndpoint(path)

		// Continue with the request
		err := c.Next()

		// Record metrics after request completes
		duration := time.Since(start)
		success := err == nil && c.Response().StatusCode() < 400

		// Extract model name from context (set by RequestExtractor middleware)
		// Use the same constant as RequestExtractor
		model := "unknown"
		if modelVal, ok := c.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string); ok && modelVal != "" {
			model = modelVal
			log.Debug().Str("model", model).Str("endpoint", endpoint).Msg("Recording metrics for request")
		} else {
			// Fallback: try to extract from path params or query
			model = extractModelFromRequest(c)
			log.Debug().Str("model", model).Str("endpoint", endpoint).Msg("Recording metrics for request (fallback)")
		}

		// Extract backend from response headers if available
		backend := string(c.Response().Header.Peek("X-LocalAI-Backend"))

		// Record the request
		metricsStore.RecordRequest(endpoint, model, backend, success, duration)

		return err
	}
}

// shouldSkipMetrics determines if a request should be excluded from metrics
func shouldSkipMetrics(path string) bool {
	// Skip UI routes
	skipPrefixes := []string{
		"/views/",
		"/static/",
		"/browse/",
		"/chat/",
		"/text2image/",
		"/tts/",
		"/talk/",
		"/models/edit/",
		"/import-model",
		"/settings",
		"/api/models",     // UI API endpoints
		"/api/backends",   // UI API endpoints
		"/api/operations", // UI API endpoints
		"/api/p2p",        // UI API endpoints
		"/api/metrics",    // Metrics API itself
	}

	for _, prefix := range skipPrefixes {
		if strings.HasPrefix(path, prefix) {
			return true
		}
	}

	// Also skip root path and other UI pages
	if path == "/" || path == "/index" {
		return true
	}

	return false
}

// categorizeEndpoint maps request paths to friendly endpoint categories
func categorizeEndpoint(path string) string {
	// OpenAI-compatible endpoints
	if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/chat/completions") {
		return "chat"
	}
	if strings.HasPrefix(path, "/v1/completions") || strings.HasPrefix(path, "/completions") {
		return "completions"
	}
	if strings.HasPrefix(path, "/v1/embeddings") || strings.HasPrefix(path, "/embeddings") {
		return "embeddings"
	}
	if strings.HasPrefix(path, "/v1/images/generations") || strings.HasPrefix(path, "/images/generations") {
		return "image-generation"
	}
	if strings.HasPrefix(path, "/v1/audio/transcriptions") || strings.HasPrefix(path, "/audio/transcriptions") {
		return "transcriptions"
	}
	if strings.HasPrefix(path, "/v1/audio/speech") || strings.HasPrefix(path, "/audio/speech") {
		return "text-to-speech"
	}
	if strings.HasPrefix(path, "/v1/models") || strings.HasPrefix(path, "/models") {
		return "models"
	}

	// LocalAI-specific endpoints
	if strings.HasPrefix(path, "/v1/internal") {
		return "internal"
	}
	if strings.Contains(path, "/tts") {
		return "text-to-speech"
	}
	if strings.Contains(path, "/stt") || strings.Contains(path, "/whisper") {
		return "speech-to-text"
	}
	if strings.Contains(path, "/sound-generation") {
		return "sound-generation"
	}

	// Default to the first path segment
	parts := strings.Split(strings.Trim(path, "/"), "/")
	if len(parts) > 0 {
		return parts[0]
	}

	return "unknown"
}

// extractModelFromRequest attempts to extract the model name from the request
func extractModelFromRequest(c *fiber.Ctx) string {
	// Try query parameter first
	model := c.Query("model")
	if model != "" {
		return model
	}

	// Try to extract from JSON body for POST requests
	if c.Method() == fiber.MethodPost {
		// Read body
		bodyBytes := c.Body()
		if len(bodyBytes) > 0 {
			// Parse JSON
			var reqBody map[string]interface{}
			if err := json.Unmarshal(bodyBytes, &reqBody); err == nil {
				if modelVal, ok := reqBody["model"]; ok {
					if modelStr, ok := modelVal.(string); ok {
						return modelStr
					}
				}
			}
		}
	}

	// Try path parameter for endpoints like /models/:model
	model = c.Params("model")
	if model != "" {
		return model
	}

	return "unknown"
}
core/http/middleware/middleware_suite_test.go (new file, 13 lines)
@@ -0,0 +1,13 @@
package middleware_test

import (
	"testing"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

func TestMiddleware(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Middleware test suite")
}
@@ -1,462 +1,482 @@
 package middleware
 
 import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"net/http"
 	"strconv"
 	"strings"
 
 	"github.com/google/uuid"
+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/mudler/LocalAI/core/services"
 	"github.com/mudler/LocalAI/core/templates"
 	"github.com/mudler/LocalAI/pkg/functions"
 	"github.com/mudler/LocalAI/pkg/model"
 	"github.com/mudler/LocalAI/pkg/utils"
-
-	"github.com/gofiber/fiber/v2"
 	"github.com/rs/zerolog/log"
 )
 
 type correlationIDKeyType string
 
 // CorrelationIDKey to track request across process boundary
 const CorrelationIDKey correlationIDKeyType = "correlationID"
 
 type RequestExtractor struct {
 	modelConfigLoader *config.ModelConfigLoader
 	modelLoader       *model.ModelLoader
 	applicationConfig *config.ApplicationConfig
 }
 
 func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
 	return &RequestExtractor{
 		modelConfigLoader: modelConfigLoader,
 		modelLoader:       modelLoader,
 		applicationConfig: applicationConfig,
 	}
 }
 
 const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
 const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
 const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
 
 // TODO: Refactor to not return error if unchanged
-func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
-	model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
+	model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
 	if ok && model != "" {
 		return
 	}
-	model = ctx.Params("model")
+	model = c.Param("model")
 
-	if (model == "") && ctx.Query("model") != "" {
-		model = ctx.Query("model")
+	if model == "" {
+		model = c.QueryParam("model")
 	}
 
 	if model == "" {
 		// Set model from bearer token, if available
-		bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
-		if bearer != "" {
+		auth := c.Request().Header.Get("Authorization")
+		bearer := strings.TrimPrefix(auth, "Bearer ")
+		if bearer != "" && bearer != auth {
 			exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
 			if err == nil && exists {
 				model = bearer
 			}
 		}
 	}
 
-	ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
+	c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
 }
 
-func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
-	return func(ctx *fiber.Ctx) error {
-		re.setModelNameFromRequest(ctx)
-		localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
-		if !ok || localModelName == "" {
-			ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
-			log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
-		}
-		return ctx.Next()
-	}
-}
+func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
+	return func(next echo.HandlerFunc) echo.HandlerFunc {
+		return func(c echo.Context) error {
+			re.setModelNameFromRequest(c)
+			localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+			if !ok || localModelName == "" {
+				c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
+				log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
+			}
+			return next(c)
+		}
+	}
+}
 
-func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
-	return func(ctx *fiber.Ctx) error {
-		re.setModelNameFromRequest(ctx)
-		localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
-		if localModelName != "" { // Don't overwrite existing values
-			return ctx.Next()
-		}
-
-		modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
-		if err != nil {
-			log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
-			return ctx.Next()
-		}
-
-		if len(modelNames) == 0 {
-			log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
-			// This is non-fatal - making it so was breaking the case of direct installation of raw models
-			// return errors.New("this endpoint requires at least one model to be installed")
-			return ctx.Next()
-		}
-
-		ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
-		log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
-		return ctx.Next()
-	}
-}
+func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc {
+	return func(next echo.HandlerFunc) echo.HandlerFunc {
+		return func(c echo.Context) error {
+			re.setModelNameFromRequest(c)
+			localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+			if localModelName != "" { // Don't overwrite existing values
+				return next(c)
+			}
+
+			modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
+			if err != nil {
+				log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
+				return next(c)
+			}
+
+			if len(modelNames) == 0 {
+				log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
+				// This is non-fatal - making it so was breaking the case of direct installation of raw models
+				// return errors.New("this endpoint requires at least one model to be installed")
+				return next(c)
+			}
+
+			c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
+			log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
+			return next(c)
+		}
+	}
+}
 
 // TODO: If context and cancel above belong on all methods, move that part of above into here!
 // Otherwise, it's in its own method below for now
-func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
-	return func(ctx *fiber.Ctx) error {
-		input := initializer()
-		if input == nil {
-			return fmt.Errorf("unable to initialize body")
-		}
-		if err := ctx.BodyParser(input); err != nil {
-			return fmt.Errorf("failed parsing request body: %w", err)
-		}
-
-		// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
-		if input.ModelName(nil) == "" {
-			localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
-			if ok && localModelName != "" {
-				log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
-				input.ModelName(&localModelName)
-			}
-		} else {
-			// Update context locals with the model name from the request body
-			// This ensures downstream middleware (like metrics) can access it
-			ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, input.ModelName(nil))
-		}
-
-		cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
-
-		if err != nil {
-			log.Err(err)
-			log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
-		} else if cfg.Model == "" && input.ModelName(nil) != "" {
-			log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
-			cfg.Model = input.ModelName(nil)
-		}
-
-		ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
-		ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
-
-		return ctx.Next()
-	}
-}
+func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc {
+	return func(next echo.HandlerFunc) echo.HandlerFunc {
+		return func(c echo.Context) error {
+			input := initializer()
+			if input == nil {
+				return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body")
+			}
+			if err := c.Bind(input); err != nil {
+				return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err))
+			}
+
+			// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
+			if input.ModelName(nil) == "" {
+				localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
+				if ok && localModelName != "" {
+					log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
+					input.ModelName(&localModelName)
+				}
+			}
+
+			cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
+
+			if err != nil {
+				log.Err(err)
+				log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
+			} else if cfg.Model == "" && input.ModelName(nil) != "" {
+				log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
+				cfg.Model = input.ModelName(nil)
+			}
+
+			c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
+			c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
+
+			return next(c)
+		}
+	}
+}
 
-func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
-	input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
+func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
+	input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
 	if !ok || input.Model == "" {
-		return fiber.ErrBadRequest
+		return echo.ErrBadRequest
 	}
 
-	cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
+	cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
 	if !ok || cfg == nil {
-		return fiber.ErrBadRequest
+		return echo.ErrBadRequest
 	}
 
 	// Extract or generate the correlation ID
-	correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
-	ctx.Set("X-Correlation-ID", correlationID)
+	correlationID := c.Request().Header.Get("X-Correlation-ID")
+	if correlationID == "" {
+		correlationID = uuid.New().String()
+	}
+	c.Response().Header().Set("X-Correlation-ID", correlationID)
 
+	// Use the request context directly - Echo properly supports context cancellation!
+	// No need for workarounds like handleConnectionCancellation
+	reqCtx := c.Request().Context()
 	c1, cancel := context.WithCancel(re.applicationConfig.Context)
+
+	// Cancel when request context is cancelled (client disconnects)
+	go func() {
+		select {
+		case <-reqCtx.Done():
+			cancel()
+		case <-c1.Done():
+			// Already cancelled
+		}
+	}()
+
 	// Add the correlation ID to the new context
 	ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
 
 	input.Context = ctxWithCorrelationID
 	input.Cancel = cancel
 
 	err := mergeOpenAIRequestAndModelConfig(cfg, input)
 	if err != nil {
 		return err
 	}
 
 	if cfg.Model == "" {
 		log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
 		cfg.Model = input.Model
 	}
 
-	ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
-	ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
+	c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
+	c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
 
-	return ctx.Next()
+	return nil
 }
 
 func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
 	if input.Echo {
 		config.Echo = input.Echo
 	}
 	if input.TopK != nil {
 		config.TopK = input.TopK
 	}
 	if input.TopP != nil {
 		config.TopP = input.TopP
 	}
 
 	if input.Backend != "" {
 		config.Backend = input.Backend
 	}
 
 	if input.ClipSkip != 0 {
 		config.Diffusers.ClipSkip = input.ClipSkip
 	}
 
 	if input.NegativePromptScale != 0 {
 		config.NegativePromptScale = input.NegativePromptScale
 	}
 
 	if input.NegativePrompt != "" {
 		config.NegativePrompt = input.NegativePrompt
 	}
 
 	if input.RopeFreqBase != 0 {
 		config.RopeFreqBase = input.RopeFreqBase
 	}
 
 	if input.RopeFreqScale != 0 {
 		config.RopeFreqScale = input.RopeFreqScale
 	}
 
 	if input.Grammar != "" {
 		config.Grammar = input.Grammar
 	}
 
 	if input.Temperature != nil {
 		config.Temperature = input.Temperature
 	}
 
 	if input.Maxtokens != nil {
 		config.Maxtokens = input.Maxtokens
 	}
 
 	if input.ResponseFormat != nil {
 		switch responseFormat := input.ResponseFormat.(type) {
 		case string:
 			config.ResponseFormat = responseFormat
 		case map[string]interface{}:
 			config.ResponseFormatMap = responseFormat
 		}
 	}
 
 	switch stop := input.Stop.(type) {
 	case string:
 		if stop != "" {
 			config.StopWords = append(config.StopWords, stop)
 		}
 	case []interface{}:
 		for _, pp := range stop {
 			if s, ok := pp.(string); ok {
 				config.StopWords = append(config.StopWords, s)
 			}
 		}
 	}
 
 	if len(input.Tools) > 0 {
 		for _, tool := range input.Tools {
 			input.Functions = append(input.Functions, tool.Function)
 		}
 	}
 
 	if input.ToolsChoice != nil {
 		var toolChoice functions.Tool
 
 		switch content := input.ToolsChoice.(type) {
 		case string:
 			_ = json.Unmarshal([]byte(content), &toolChoice)
 		case map[string]interface{}:
 			dat, _ := json.Marshal(content)
 			_ = json.Unmarshal(dat, &toolChoice)
 		}
 		input.FunctionCall = map[string]interface{}{
 			"name": toolChoice.Function.Name,
 		}
 	}
 
 	// Decode each request's message content
 	imgIndex, vidIndex, audioIndex := 0, 0, 0
 	for i, m := range input.Messages {
 		nrOfImgsInMessage := 0
 		nrOfVideosInMessage := 0
 		nrOfAudiosInMessage := 0
 
 		switch content := m.Content.(type) {
 		case string:
 			input.Messages[i].StringContent = content
 		case []interface{}:
 			dat, _ := json.Marshal(content)
 			c := []schema.Content{}
 			json.Unmarshal(dat, &c)
 
 			textContent := ""
 			// we will template this at the end
 
 		CONTENT:
 			for _, pp := range c {
 				switch pp.Type {
 				case "text":
 					textContent += pp.Text
 					//input.Messages[i].StringContent = pp.Text
 				case "video", "video_url":
 					// Decode content as base64 either if it's an URL or base64 text
 					base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
 					if err != nil {
 						log.Error().Msgf("Failed encoding video: %s", err)
 						continue CONTENT
 					}
 					input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
 					vidIndex++
 					nrOfVideosInMessage++
 				case "audio_url", "audio":
 					// Decode content as base64 either if it's an URL or base64 text
 					base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
 					if err != nil {
 						log.Error().Msgf("Failed encoding audio: %s", err)
 						continue CONTENT
 					}
 					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
 					audioIndex++
 					nrOfAudiosInMessage++
 				case "input_audio":
 					// TODO: make sure that we only return base64 stuff
 					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
 					audioIndex++
 					nrOfAudiosInMessage++
 				case "image_url", "image":
 					// Decode content as base64 either if it's an URL or base64 text
 					base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
 					if err != nil {
 						log.Error().Msgf("Failed encoding image: %s", err)
 						continue CONTENT
 					}
 
 					input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
 
 					imgIndex++
 					nrOfImgsInMessage++
 				}
 			}
 
 			input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
 				TotalImages:     imgIndex,
 				TotalVideos:     vidIndex,
 				TotalAudios:     audioIndex,
 				ImagesInMessage: nrOfImgsInMessage,
 				VideosInMessage: nrOfVideosInMessage,
 				AudiosInMessage: nrOfAudiosInMessage,
 			}, textContent)
 		}
 	}
 
 	if input.RepeatPenalty != 0 {
 		config.RepeatPenalty = input.RepeatPenalty
 	}
 
 	if input.FrequencyPenalty != 0 {
 		config.FrequencyPenalty = input.FrequencyPenalty
 	}
 
 	if input.PresencePenalty != 0 {
 		config.PresencePenalty = input.PresencePenalty
 	}
 
 	if input.Keep != 0 {
 		config.Keep = input.Keep
 	}
 
 	if input.Batch != 0 {
 		config.Batch = input.Batch
 	}
 
 	if input.IgnoreEOS {
 		config.IgnoreEOS = input.IgnoreEOS
 	}
 
 	if input.Seed != nil {
 		config.Seed = input.Seed
 	}
 
 	if input.TypicalP != nil {
 		config.TypicalP = input.TypicalP
 	}
 
 	log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
 
 	switch inputs := input.Input.(type) {
 	case string:
 		if inputs != "" {
 			config.InputStrings = append(config.InputStrings, inputs)
 		}
 	case []any:
 		for _, pp := range inputs {
 			switch i := pp.(type) {
 			case string:
 				config.InputStrings = append(config.InputStrings, i)
 			case []any:
 				tokens := []int{}
 				inputStrings := []string{}
 				for _, ii := range i {
 					switch ii := ii.(type) {
 					case int:
 						tokens = append(tokens, ii)
 					case float64:
 						tokens = append(tokens, int(ii))
 					case string:
 						inputStrings = append(inputStrings, ii)
 					default:
 						log.Error().Msgf("Unknown input type: %T", ii)
 					}
 				}
 				config.InputToken = append(config.InputToken, tokens)
 				config.InputStrings = append(config.InputStrings, inputStrings...)
 			}
 		}
 	}
 
 	// Can be either a string or an object
 	switch fnc := input.FunctionCall.(type) {
 	case string:
 		if fnc != "" {
 			config.SetFunctionCallString(fnc)
 		}
 	case map[string]interface{}:
 		var name string
 		n, exists := fnc["name"]
 		if exists {
 			nn, e := n.(string)
 			if e {
 				name = nn
 			}
 		}
 		config.SetFunctionCallNameString(name)
 	}
 
 	switch p := input.Prompt.(type) {
 	case string:
 		config.PromptStrings = append(config.PromptStrings, p)
 	case []interface{}:
 		for _, pp := range p {
 			if s, ok := pp.(string); ok {
 				config.PromptStrings = append(config.PromptStrings, s)
 			}
 		}
 	}
 
 	// If a quality was defined as number, convert it to step
 	if input.Quality != "" {
 		q, err := strconv.Atoi(input.Quality)
 		if err == nil {
 			config.Step = q
 		}
 	}
 
 	if config.Validate() {
 		return nil
 	}
 	return fmt.Errorf("unable to validate configuration after merging")
 }
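The new SetOpenAIRequest bridges the short-lived request context and the long-lived application context with a watcher goroutine, so a client disconnect cancels backend work without leaking the goroutine. A minimal standalone sketch of that same pattern, with invented names (not code from this diff):

package main

import (
	"context"
	"fmt"
)

// bridgeCancel derives a cancellable child of appCtx that is also cancelled
// whenever reqCtx ends (for example, the HTTP client disconnects).
func bridgeCancel(appCtx, reqCtx context.Context) (context.Context, context.CancelFunc) {
	child, cancel := context.WithCancel(appCtx)
	go func() {
		select {
		case <-reqCtx.Done():
			cancel() // client went away: stop backend work
		case <-child.Done():
			// already cancelled elsewhere; the goroutine exits, so nothing leaks
		}
	}()
	return child, cancel
}

func main() {
	appCtx := context.Background()
	reqCtx, endRequest := context.WithCancel(context.Background())

	work, cancel := bridgeCancel(appCtx, reqCtx)
	defer cancel()

	endRequest() // simulate a client disconnect
	<-work.Done()
	fmt.Println("work cancelled:", work.Err())
}

The second select arm is what keeps the watcher from outliving the request when cancellation comes from the application side instead of the client.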
@@ -3,34 +3,55 @@ package middleware
 import (
 	"strings"
 
-	"github.com/gofiber/fiber/v2"
+	"github.com/labstack/echo/v4"
 )
 
-// StripPathPrefix returns a middleware that strips a path prefix from the request path.
+// StripPathPrefix returns middleware that strips a path prefix from the request path.
 // The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
-func StripPathPrefix() fiber.Handler {
-	return func(c *fiber.Ctx) error {
-		for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
-			if prefix != "" {
-				path := c.Path()
-				pos := len(prefix)
-
-				if prefix[pos-1] == '/' {
-					pos--
-				} else {
-					prefix += "/"
-				}
-
-				if strings.HasPrefix(path, prefix) {
-					c.Path(path[pos:])
-					break
-				} else if prefix[:pos] == path {
-					c.Redirect(prefix)
-					return nil
-				}
-			}
-		}
-
-		return c.Next()
-	}
-}
+// This must be registered as Pre middleware (using e.Pre()) to modify the path before routing.
+func StripPathPrefix() echo.MiddlewareFunc {
+	return func(next echo.HandlerFunc) echo.HandlerFunc {
+		return func(c echo.Context) error {
+			prefixes := c.Request().Header.Values("X-Forwarded-Prefix")
+			originalPath := c.Request().URL.Path
+
+			for _, prefix := range prefixes {
+				if prefix != "" {
+					normalizedPrefix := prefix
+					if !strings.HasSuffix(prefix, "/") {
+						normalizedPrefix = prefix + "/"
+					}
+
+					if strings.HasPrefix(originalPath, normalizedPrefix) {
+						// Update the request path by stripping the normalized prefix
+						newPath := originalPath[len(normalizedPrefix):]
+						if newPath == "" {
+							newPath = "/"
+						}
+						// Ensure path starts with / for proper routing
+						if !strings.HasPrefix(newPath, "/") {
+							newPath = "/" + newPath
+						}
+						// Update the URL path - Echo's router uses URL.Path for routing
+						c.Request().URL.Path = newPath
+						c.Request().URL.RawPath = ""
+						// Update RequestURI to match the new path (needed for proper routing)
+						if c.Request().URL.RawQuery != "" {
+							c.Request().RequestURI = newPath + "?" + c.Request().URL.RawQuery
+						} else {
+							c.Request().RequestURI = newPath
+						}
+						// Store original path for BaseURL utility
+						c.Set("_original_path", originalPath)
+						break
+					} else if originalPath == prefix || originalPath == prefix+"/" {
+						// Redirect to prefix with trailing slash (use 302 to match test expectations)
+						return c.Redirect(302, normalizedPrefix)
+					}
+				}
+			}
+
+			return next(c)
+		}
+	}
+}
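The comment above stresses registration with e.Pre(): Pre middleware runs before the router matches, so the stripped path is what routing sees, whereas e.Use() middleware would run too late. A small hypothetical sketch of that registration (paths and port are made up):

package main

import (
	"net/http"

	"github.com/labstack/echo/v4"
	mw "github.com/mudler/LocalAI/core/http/middleware"
)

func main() {
	e := echo.New()
	// Pre middleware rewrites the path before routing happens.
	e.Pre(mw.StripPathPrefix())
	e.GET("/hello/world", func(c echo.Context) error {
		return c.String(http.StatusOK, c.Request().URL.Path)
	})
	// A request such as
	//   curl -H 'X-Forwarded-Prefix: /myprefix/' http://localhost:8080/myprefix/hello/world
	// is then routed as /hello/world.
	e.Logger.Fatal(e.Start(":8080"))
}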
@@ -2,120 +2,133 @@ package middleware
 
 import (
 	"net/http/httptest"
-	"testing"
 
-	"github.com/gofiber/fiber/v2"
-	"github.com/stretchr/testify/require"
+	"github.com/labstack/echo/v4"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
 )
 
-func TestStripPathPrefix(t *testing.T) {
-	app := fiber.New()
-
-	app.Use(StripPathPrefix())
-
-	app.Get("/hello/world", func(c *fiber.Ctx) error {
-		actualPath = c.Path()
-		return nil
-	})
-
-	app.Get("/", func(c *fiber.Ctx) error {
-		actualPath = c.Path()
-		return nil
-	})
-
-	for _, tc := range []struct {
-		name         string
-		path         string
-		prefixHeader []string
-		expectStatus int
-		expectPath   string
-	}{
-		{
-			name:         "without prefix and header",
-			path:         "/hello/world",
-			expectStatus: 200,
-			expectPath:   "/hello/world",
-		},
-		{
-			name:         "without prefix and headers on root path",
-			path:         "/",
-			expectStatus: 200,
-			expectPath:   "/",
-		},
-		{
-			name:         "without prefix but header",
-			path:         "/hello/world",
-			prefixHeader: []string{"/otherprefix/"},
-			expectStatus: 200,
-			expectPath:   "/hello/world",
-		},
-		{
-			name:         "with prefix but non-matching header",
-			path:         "/prefix/hello/world",
-			prefixHeader: []string{"/otherprefix/"},
-			expectStatus: 404,
-		},
-		{
-			name:         "with prefix and matching header",
-			path:         "/myprefix/hello/world",
-			prefixHeader: []string{"/myprefix/"},
-			expectStatus: 200,
-			expectPath:   "/hello/world",
-		},
-		{
-			name:         "with prefix and 1st header matching",
-			path:         "/myprefix/hello/world",
-			prefixHeader: []string{"/myprefix/", "/otherprefix/"},
-			expectStatus: 200,
-			expectPath:   "/hello/world",
-		},
-		{
-			name:         "with prefix and 2nd header matching",
-			path:         "/myprefix/hello/world",
-			prefixHeader: []string{"/otherprefix/", "/myprefix/"},
-			expectStatus: 200,
-			expectPath:   "/hello/world",
-		},
-		{
-			name:         "with prefix and header not ending with slash",
-			path:         "/myprefix/hello/world",
-			prefixHeader: []string{"/myprefix"},
-			expectStatus: 200,
-			expectPath:   "/hello/world",
-		},
-		{
-			name:         "with prefix and non-matching header not ending with slash",
-			path:         "/myprefix-suffix/hello/world",
-			prefixHeader: []string{"/myprefix"},
-			expectStatus: 404,
-		},
-		{
-			name:         "redirect when prefix does not end with a slash",
-			path:         "/myprefix",
-			prefixHeader: []string{"/myprefix"},
-			expectStatus: 302,
-			expectPath:   "/myprefix/",
-		},
-	} {
-		t.Run(tc.name, func(t *testing.T) {
-			actualPath = ""
-			req := httptest.NewRequest("GET", tc.path, nil)
-			if tc.prefixHeader != nil {
-				req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
-			}
-
-			resp, err := app.Test(req, -1)
-
-			require.NoError(t, err)
-			require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")
-
-			if tc.expectStatus == 200 {
-				require.Equal(t, tc.expectPath, actualPath, "rewritten path")
-			} else if tc.expectStatus == 302 {
-				require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
-			}
-		})
-	}
-}
+var _ = Describe("StripPathPrefix", func() {
+	var app *echo.Echo
+	var actualPath string
+	var appInitialized bool
+
+	BeforeEach(func() {
+		actualPath = ""
+		if !appInitialized {
+			app = echo.New()
+			app.Pre(StripPathPrefix())
+
+			app.GET("/hello/world", func(c echo.Context) error {
+				actualPath = c.Request().URL.Path
+				return nil
+			})
+
+			app.GET("/", func(c echo.Context) error {
+				actualPath = c.Request().URL.Path
+				return nil
+			})
+			appInitialized = true
+		}
+	})
+
+	Context("without prefix", func() {
+		It("should not modify path when no header is present", func() {
+			req := httptest.NewRequest("GET", "/hello/world", nil)
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(200), "response status code")
+			Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+		})
+
+		It("should not modify root path when no header is present", func() {
+			req := httptest.NewRequest("GET", "/", nil)
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(200), "response status code")
+			Expect(actualPath).To(Equal("/"), "rewritten path")
+		})
+
+		It("should not modify path when header does not match", func() {
+			req := httptest.NewRequest("GET", "/hello/world", nil)
+			req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(200), "response status code")
+			Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+		})
+	})
+
+	Context("with prefix", func() {
+		It("should return 404 when prefix does not match header", func() {
+			req := httptest.NewRequest("GET", "/prefix/hello/world", nil)
+			req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(404), "response status code")
+		})
+
+		It("should strip matching prefix from path", func() {
+			req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
+			req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"}
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(200), "response status code")
+			Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+		})
+
+		It("should strip prefix when it matches the first header value", func() {
+			req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
+			req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"}
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(200), "response status code")
+			Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+		})
+
+		It("should strip prefix when it matches the second header value", func() {
+			req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
+			req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/", "/myprefix/"}
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(200), "response status code")
+			Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+		})
+
+		It("should strip prefix when header does not end with slash", func() {
+			req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
+			req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(200), "response status code")
+			Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
+		})
+
+		It("should return 404 when prefix does not match header without trailing slash", func() {
+			req := httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil)
+			req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(404), "response status code")
+		})
+
+		It("should redirect when prefix does not end with a slash", func() {
+			req := httptest.NewRequest("GET", "/myprefix", nil)
+			req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
+			rec := httptest.NewRecorder()
+			app.ServeHTTP(rec, req)
+
+			Expect(rec.Code).To(Equal(302), "response status code")
+			Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location")
+		})
+	})
+})
@@ -17,7 +17,7 @@ import (
 	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
 	"fmt"
 	. "github.com/mudler/LocalAI/core/http"
-	"github.com/gofiber/fiber/v2"
+	"github.com/labstack/echo/v4"
 
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
@@ -62,7 +62,7 @@ func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResp
 var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
 	var tmpdir string
 	var appServer *application.Application
-	var app *fiber.App
+	var app *echo.Echo
 	var ctx context.Context
 	var cancel context.CancelFunc
 
@@ -97,7 +97,9 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
 	AfterEach(func() {
 		cancel()
 		if app != nil {
-			_ = app.Shutdown()
+			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+			defer cancel()
+			_ = app.Shutdown(ctx)
 		}
 		_ = os.RemoveAll(tmpdir)
 	})
@@ -106,7 +108,11 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
 		var err error
 		app, err = API(appServer)
 		Expect(err).ToNot(HaveOccurred())
-		go app.Listen("127.0.0.1:9091")
+		go func() {
+			if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed {
+				// Log error if needed
+			}
+		}()
 
 		// wait for server
 		client := &http.Client{Timeout: 5 * time.Second}
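The test changes above follow the usual Echo server lifecycle: Start blocks until Shutdown is called, and a clean stop surfaces as http.ErrServerClosed. A minimal sketch of that pattern outside the test suite (address and timings are arbitrary):

package main

import (
	"context"
	"net/http"
	"time"

	"github.com/labstack/echo/v4"
)

func main() {
	app := echo.New()

	// Start blocks, so it runs in a goroutine; http.ErrServerClosed signals a clean stop.
	go func() {
		if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed {
			app.Logger.Error(err)
		}
	}()

	time.Sleep(100 * time.Millisecond) // crude wait; the test above polls with an http.Client instead

	// Shutdown drains in-flight requests, bounded by the context deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	_ = app.Shutdown(ctx)
}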
@@ -4,13 +4,15 @@ import (
 	"embed"
 	"fmt"
 	"html/template"
+	"io"
+	"io/fs"
 	"net/http"
 	"strings"
 
 	"github.com/Masterminds/sprig/v3"
-	"github.com/gofiber/fiber/v2"
-	fiberhtml "github.com/gofiber/template/html/v2"
+	"github.com/labstack/echo/v4"
 	"github.com/microcosm-cc/bluemonday"
-	"github.com/mudler/LocalAI/core/http/utils"
+	"github.com/mudler/LocalAI/core/http/middleware"
 	"github.com/mudler/LocalAI/core/schema"
 	"github.com/russross/blackfriday"
 )
@@ -18,26 +20,67 @@
 //go:embed views/*
 var viewsfs embed.FS
 
-func notFoundHandler(c *fiber.Ctx) error {
+// TemplateRenderer is a custom template renderer for Echo
+type TemplateRenderer struct {
+	templates *template.Template
+}
+
+// Render renders a template document
+func (t *TemplateRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
+	return t.templates.ExecuteTemplate(w, name, data)
+}
+
+func notFoundHandler(c echo.Context) error {
 	// Check if the request accepts JSON
-	if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
+	contentType := c.Request().Header.Get("Content-Type")
+	accept := c.Request().Header.Get("Accept")
+	if strings.Contains(contentType, "application/json") || !strings.Contains(accept, "text/html") {
 		// The client expects a JSON response
-		return c.Status(fiber.StatusNotFound).JSON(schema.ErrorResponse{
-			Error: &schema.APIError{Message: "Resource not found", Code: fiber.StatusNotFound},
+		return c.JSON(http.StatusNotFound, schema.ErrorResponse{
+			Error: &schema.APIError{Message: "Resource not found", Code: http.StatusNotFound},
 		})
 	} else {
 		// The client expects an HTML response
-		return c.Status(fiber.StatusNotFound).Render("views/404", fiber.Map{
-			"BaseURL": utils.BaseURL(c),
+		return c.Render(http.StatusNotFound, "views/404", map[string]interface{}{
+			"BaseURL": middleware.BaseURL(c),
 		})
 	}
 }
 
-func renderEngine() *fiberhtml.Engine {
-	engine := fiberhtml.NewFileSystem(http.FS(viewsfs), ".html")
-	engine.AddFuncMap(sprig.FuncMap())
-	engine.AddFunc("MDToHTML", markDowner)
-	return engine
-}
+func renderEngine() *TemplateRenderer {
+	// Parse all templates from embedded filesystem
+	tmpl := template.New("").Funcs(sprig.FuncMap())
+	tmpl = tmpl.Funcs(template.FuncMap{
+		"MDToHTML": markDowner,
+	})
+
+	// Recursively walk through embedded filesystem and parse all HTML templates
+	err := fs.WalkDir(viewsfs, "views", func(path string, d fs.DirEntry, err error) error {
+		if err != nil {
+			return err
+		}
+		if !d.IsDir() && strings.HasSuffix(path, ".html") {
+			data, err := viewsfs.ReadFile(path)
+			if err == nil {
+				// Remove .html extension to get template name (e.g., "views/index.html" -> "views/index")
+				templateName := strings.TrimSuffix(path, ".html")
+				_, err := tmpl.New(templateName).Parse(string(data))
+				if err != nil {
+					// If parsing fails, try parsing without explicit name (for templates with {{define}})
+					tmpl.Parse(string(data))
+				}
+			}
+		}
+		return nil
+	})
+	if err != nil {
+		// Log error but continue - templates might still work
+		fmt.Printf("Error walking views directory: %v\n", err)
+	}
+
+	return &TemplateRenderer{
+		templates: tmpl,
+	}
+}
 
 func markDowner(args ...interface{}) template.HTML {
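Echo only consults a renderer once one has been assigned to e.Renderer; c.Render then dispatches by template name through the echo.Renderer interface implemented above. A hypothetical wiring sketch, with an inlined template standing in for the embedded views (not code from this diff):

package main

import (
	"html/template"
	"io"
	"net/http"

	"github.com/labstack/echo/v4"
)

// renderer has the same shape as the TemplateRenderer above.
type renderer struct{ templates *template.Template }

func (r *renderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
	return r.templates.ExecuteTemplate(w, name, data)
}

func main() {
	e := echo.New()
	// In LocalAI this would presumably be e.Renderer = renderEngine(); here a template is inlined.
	e.Renderer = &renderer{
		templates: template.Must(template.New("views/hello").Parse("Hello {{.Name}}")),
	}

	e.GET("/", func(c echo.Context) error {
		return c.Render(http.StatusOK, "views/hello", map[string]interface{}{"Name": "LocalAI"})
	})
	e.Logger.Fatal(e.Start(":8080"))
}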
@@ -1,7 +1,7 @@
 package routes
 
 import (
-	"github.com/gofiber/fiber/v2"
+	"github.com/labstack/echo/v4"
 	"github.com/mudler/LocalAI/core/config"
 	"github.com/mudler/LocalAI/core/http/endpoints/elevenlabs"
 	"github.com/mudler/LocalAI/core/http/middleware"
@@ -9,21 +9,23 @@ import (
 	"github.com/mudler/LocalAI/pkg/model"
 )
 
-func RegisterElevenLabsRoutes(app *fiber.App,
+func RegisterElevenLabsRoutes(app *echo.Echo,
 	re *middleware.RequestExtractor,
 	cl *config.ModelConfigLoader,
 	ml *model.ModelLoader,
 	appConfig *config.ApplicationConfig) {
 
 	// Elevenlabs
-	app.Post("/v1/text-to-speech/:voice-id",
+	ttsHandler := elevenlabs.TTSEndpoint(cl, ml, appConfig)
+	app.POST("/v1/text-to-speech/:voice-id",
+		ttsHandler,
 		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)),
-		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }),
-		elevenlabs.TTSEndpoint(cl, ml, appConfig))
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }))
 
-	app.Post("/v1/sound-generation",
+	soundGenHandler := elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)
+	app.POST("/v1/sound-generation",
+		soundGenHandler,
 		re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)),
-		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }),
-		elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig))
+		re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }))
 
 }
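Note the argument order: Fiber chains middleware before the terminal handler, while Echo's route methods take the handler first and route-level middleware variadically after it (the middleware still runs before the handler at request time). A small invented example of that signature:

package main

import (
	"net/http"

	"github.com/labstack/echo/v4"
)

func main() {
	e := echo.New()

	handler := func(c echo.Context) error { return c.String(http.StatusOK, "ok") }

	// Route-level middleware is listed after the handler but wraps it at request time.
	logging := func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			c.Logger().Info("before handler")
			return next(c)
		}
	}

	e.POST("/v1/sound-generation", handler, logging)
	e.Logger.Fatal(e.Start(":8080"))
}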
Some files were not shown because too many files have changed in this diff.