Compare commits

..

1 Commits

Author SHA1 Message Date
Ettore Di Giacinto
eebda7204e feat(ui): add front-page stats
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-10-28 15:58:00 +01:00
207 changed files with 6733 additions and 16642 deletions

View File

@@ -1,8 +0,0 @@
# .air.toml
[build]
cmd = "make build"
bin = "./local-ai"
args_bin = [ "--debug" ]
include_ext = ["go", "html", "yaml", "toml", "json", "txt", "md"]
exclude_dir = ["pkg/grpc/proto"]
delay = 1000

View File

@@ -7,8 +7,8 @@ import (
"slices" "slices"
"strings" "strings"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" "github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
cogito "github.com/mudler/cogito" "github.com/mudler/cogito"
"github.com/mudler/cogito/structures" "github.com/mudler/cogito/structures"
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"

39
.github/gallery-agent/go.mod vendored Normal file
View File

@@ -0,0 +1,39 @@
module github.com/go-skynet/LocalAI/.github/gallery-agent
go 1.24.1
require (
github.com/mudler/cogito v0.3.0
github.com/onsi/ginkgo/v2 v2.25.3
github.com/onsi/gomega v1.38.2
github.com/sashabaranov/go-openai v1.41.2
github.com/tmc/langchaingo v0.1.13
gopkg.in/yaml.v3 v3.0.1
)
require (
dario.cat/mergo v1.0.1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.4.0 // indirect
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/modelcontextprotocol/go-sdk v1.0.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.41.0 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/tools v0.36.0 // indirect
)

168
.github/gallery-agent/go.sum vendored Normal file
View File

@@ -0,0 +1,168 @@
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs=
github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.2.2+incompatible h1:CjwRSksz8Yo4+RmQ339Dp/D2tGO5JxwYeqtMOEe0LDw=
github.com/docker/docker v28.2.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ=
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modelcontextprotocol/go-sdk v1.0.0 h1:Z4MSjLi38bTgLrd/LjSmofqRqyBiVKRyQSJgw8q8V74=
github.com/modelcontextprotocol/go-sdk v1.0.0/go.mod h1:nYtYQroQ2KQiM0/SbyEPUWQ6xs4B95gJjEalc9AQyOs=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/mudler/cogito v0.3.0 h1:NbVAO3bLkK5oGSY0xq87jlz8C9OIsLW55s+8Hfzeu9s=
github.com/mudler/cogito v0.3.0/go.mod h1:abMwl+CUjCp87IufA2quZdZt0bbLaHHN79o17HbUKxU=
github.com/onsi/ginkgo/v2 v2.25.3 h1:Ty8+Yi/ayDAGtk4XxmmfUy4GabvM+MegeB4cDLRi6nw=
github.com/onsi/ginkgo/v2 v2.25.3/go.mod h1:43uiyQC4Ed2tkOzLsEYm7hnrb7UJTWHYNsuy3bG/snE=
github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A=
github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM=
github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/shirou/gopsutil/v4 v4.25.5 h1:rtd9piuSMGeU8g1RMXjZs9y9luK5BwtnG7dZaQUJAsc=
github.com/shirou/gopsutil/v4 v4.25.5/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/testcontainers/testcontainers-go v0.38.0 h1:d7uEapLcv2P8AvH8ahLqDMMxda2W9gQN1nRbHS28HBw=
github.com/testcontainers/testcontainers-go v0.38.0/go.mod h1:C52c9MoHpWO+C4aqmgSU+hxlR5jlEayWtgYrb8Pzz1w=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA=
github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -51,7 +51,6 @@ type ModelFile struct {
Size int64 Size int64
SHA256 string SHA256 string
IsReadme bool IsReadme bool
URL string
} }
// ModelDetails represents detailed information about a model // ModelDetails represents detailed information about a model
@@ -216,7 +215,6 @@ func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) {
} }
// Process each file // Process each file
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
for _, file := range files { for _, file := range files {
fileName := filepath.Base(file.Path) fileName := filepath.Base(file.Path)
isReadme := strings.Contains(strings.ToLower(fileName), "readme") isReadme := strings.Contains(strings.ToLower(fileName), "readme")
@@ -229,16 +227,11 @@ func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) {
sha256 = file.Oid sha256 = file.Oid
} }
// Construct the full URL for the file
// Use /resolve/main/ for downloading files (handles LFS properly)
fileURL := fmt.Sprintf("%s/%s/resolve/main/%s", baseURL, repoID, file.Path)
modelFile := ModelFile{ modelFile := ModelFile{
Path: file.Path, Path: file.Path,
Size: file.Size, Size: file.Size,
SHA256: sha256, SHA256: sha256,
IsReadme: isReadme, IsReadme: isReadme,
URL: fileURL,
} }
details.Files = append(details.Files, modelFile) details.Files = append(details.Files, modelFile)

View File

@@ -1,7 +1,6 @@
package hfapi_test package hfapi_test
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@@ -9,7 +8,7 @@ import (
. "github.com/onsi/ginkgo/v2" . "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" "github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
) )
var _ = Describe("HuggingFace API Client", func() { var _ = Describe("HuggingFace API Client", func() {
@@ -271,15 +270,6 @@ var _ = Describe("HuggingFace API Client", func() {
}) })
}) })
Context("when getting file SHA on remote model", func() {
It("should get file SHA successfully", func() {
sha, err := client.GetFileSHA(
"mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF", "localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf")
Expect(err).ToNot(HaveOccurred())
Expect(sha).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"))
})
})
Context("when listing files", func() { Context("when listing files", func() {
BeforeEach(func() { BeforeEach(func() {
mockFilesResponse := `[ mockFilesResponse := `[
@@ -339,25 +329,23 @@ var _ = Describe("HuggingFace API Client", func() {
Context("when getting file SHA", func() { Context("when getting file SHA", func() {
BeforeEach(func() { BeforeEach(func() {
mockFilesResponse := `[ mockFileInfoResponse := `{
{
"type": "file",
"path": "model-Q4_K_M.gguf", "path": "model-Q4_K_M.gguf",
"size": 1000000, "size": 1000000,
"oid": "abc123", "oid": "abc123",
"lfs": { "lfs": {
"oid": "def456789", "oid": "sha256:def456",
"size": 1000000, "size": 1000000,
"pointerSize": 135 "pointer": "version https://git-lfs.github.com/spec/v1",
"sha256": "def456789"
} }
} }`
]`
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/tree/main") { if strings.Contains(r.URL.Path, "/paths-info") {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(mockFilesResponse)) w.Write([]byte(mockFileInfoResponse))
} else { } else {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
} }
@@ -375,29 +363,18 @@ var _ = Describe("HuggingFace API Client", func() {
It("should handle missing SHA gracefully", func() { It("should handle missing SHA gracefully", func() {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/tree/main") {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(`[ w.Write([]byte(`{"path": "file.txt", "size": 100}`))
{
"type": "file",
"path": "file.txt",
"size": 100,
"oid": "file123"
}
]`))
} else {
w.WriteHeader(http.StatusNotFound)
}
})) }))
client.SetBaseURL(server.URL) client.SetBaseURL(server.URL)
sha, err := client.GetFileSHA("test/model", "file.txt") sha, err := client.GetFileSHA("test/model", "file.txt")
Expect(err).ToNot(HaveOccurred()) Expect(err).To(HaveOccurred())
// When there's no LFS, it should return the OID Expect(err.Error()).To(ContainSubstring("no SHA256 found"))
Expect(sha).To(Equal("file123")) Expect(sha).To(Equal(""))
}) })
}) })
@@ -462,13 +439,6 @@ var _ = Describe("HuggingFace API Client", func() {
Expect(details.ReadmeFile).ToNot(BeNil()) Expect(details.ReadmeFile).ToNot(BeNil())
Expect(details.ReadmeFile.Path).To(Equal("README.md")) Expect(details.ReadmeFile.Path).To(Equal("README.md"))
Expect(details.ReadmeFile.IsReadme).To(BeTrue()) Expect(details.ReadmeFile.IsReadme).To(BeTrue())
// Verify URLs are set for all files
baseURL := strings.TrimSuffix(server.URL, "/api/models")
for _, file := range details.Files {
expectedURL := fmt.Sprintf("%s/test/model/resolve/main/%s", baseURL, file.Path)
Expect(file.URL).To(Equal(expectedURL))
}
}) })
}) })

View File

@@ -11,5 +11,3 @@ func TestHfapi(t *testing.T) {
RegisterFailHandler(Fail) RegisterFailHandler(Fail)
RunSpecs(t, "HuggingFace API Suite") RunSpecs(t, "HuggingFace API Suite")
} }

View File

@@ -9,7 +9,7 @@ import (
"strings" "strings"
"time" "time"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" "github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
) )
// ProcessedModelFile represents a processed model file with additional metadata // ProcessedModelFile represents a processed model file with additional metadata

View File

@@ -3,9 +3,9 @@ package main
import ( import (
"fmt" "fmt"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" "github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
openai "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
jsonschema "github.com/sashabaranov/go-openai/jsonschema" "github.com/tmc/langchaingo/jsonschema"
) )
// Get repository README from HF // Get repository README from HF
@@ -13,7 +13,7 @@ type HFReadmeTool struct {
client *hfapi.Client client *hfapi.Client
} }
func (s *HFReadmeTool) Execute(args map[string]any) (string, error) { func (s *HFReadmeTool) Run(args map[string]any) (string, error) {
q, ok := args["repository"].(string) q, ok := args["repository"].(string)
if !ok { if !ok {
return "", fmt.Errorf("no query") return "", fmt.Errorf("no query")

View File

@@ -1,10 +1,10 @@
name: Bump Backend dependencies name: Bump dependencies
on: on:
schedule: schedule:
- cron: 0 20 * * * - cron: 0 20 * * *
workflow_dispatch: workflow_dispatch:
jobs: jobs:
bump-backends: bump:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:

View File

@@ -1,10 +1,10 @@
name: Bump Documentation name: Bump dependencies
on: on:
schedule: schedule:
- cron: 0 20 * * * - cron: 0 20 * * *
workflow_dispatch: workflow_dispatch:
jobs: jobs:
bump-docs: bump:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:

View File

@@ -33,7 +33,7 @@ jobs:
run: | run: |
CGO_ENABLED=0 make build CGO_ENABLED=0 make build
- name: rm - name: rm
uses: appleboy/ssh-action@v1.2.3 uses: appleboy/ssh-action@v1.2.2
with: with:
host: ${{ secrets.EXPLORER_SSH_HOST }} host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }} username: ${{ secrets.EXPLORER_SSH_USERNAME }}
@@ -53,7 +53,7 @@ jobs:
rm: true rm: true
target: ./local-ai target: ./local-ai
- name: restarting - name: restarting
uses: appleboy/ssh-action@v1.2.3 uses: appleboy/ssh-action@v1.2.2
with: with:
host: ${{ secrets.EXPLORER_SSH_HOST }} host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }} username: ${{ secrets.EXPLORER_SSH_USERNAME }}

View File

@@ -2,7 +2,7 @@ name: Gallery Agent
on: on:
schedule: schedule:
- cron: '0 */3 * * *' # Run every 4 hours - cron: '0 */1 * * *' # Run every 4 hours
workflow_dispatch: workflow_dispatch:
inputs: inputs:
search_term: search_term:
@@ -39,6 +39,11 @@ jobs:
with: with:
go-version: '1.21' go-version: '1.21'
- name: Build gallery agent
run: |
cd .github/gallery-agent
go mod download
go build -o gallery-agent .
- name: Run gallery agent - name: Run gallery agent
env: env:
@@ -51,7 +56,9 @@ jobs:
MAX_MODELS: ${{ github.event.inputs.max_models || '1' }} MAX_MODELS: ${{ github.event.inputs.max_models || '1' }}
run: | run: |
export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml
go run .github/gallery-agent cd .github/gallery-agent
./gallery-agent
rm -rf gallery-agent
- name: Check for changes - name: Check for changes
id: check_changes id: check_changes

View File

@@ -30,7 +30,6 @@ Thank you for your interest in contributing to LocalAI! We appreciate your time
3. Install the required dependencies ( see https://localai.io/basics/build/#build-localai-locally ) 3. Install the required dependencies ( see https://localai.io/basics/build/#build-localai-locally )
4. Build LocalAI: `make build` 4. Build LocalAI: `make build`
5. Run LocalAI: `./local-ai` 5. Run LocalAI: `./local-ai`
6. To Build and live reload: `make build-dev`
## Contributing ## Contributing

View File

@@ -103,10 +103,6 @@ build-launcher: ## Build the launcher application
build-all: build build-launcher ## Build both server and launcher build-all: build build-launcher ## Build both server and launcher
build-dev: ## Run LocalAI in dev mode with live reload
@command -v air >/dev/null 2>&1 || go install github.com/air-verse/air@latest
air -c .air.toml
dev-dist: dev-dist:
$(GORELEASER) build --snapshot --clean $(GORELEASER) build --snapshot --clean

View File

@@ -43,7 +43,7 @@
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/) > :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
> >
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on > [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white)](https://t.me/localaiofficial_bot) [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white)](https://t.me/localaiofficial_bot)
[![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai) [![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai)
@@ -116,8 +116,6 @@ For more installation options, see [Installer Options](https://localai.io/docs/a
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/> <img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
</a> </a>
> Note: the DMGs are not signed by Apple as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
Or run with docker: Or run with docker:
> **💡 Docker Run vs Docker Start** > **💡 Docker Run vs Docker Start**
@@ -202,7 +200,7 @@ local-ai run oci://localai/phi-2:latest
> ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection). > ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).
For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html), if you are interested in our roadmap items and future enhancements, you can see the [Issues labeled as Roadmap here](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html)
## 📰 Latest project news ## 📰 Latest project news

View File

@@ -154,10 +154,6 @@ message PredictOptions {
repeated string Videos = 45; repeated string Videos = 45;
repeated string Audios = 46; repeated string Audios = 46;
string CorrelationId = 47; string CorrelationId = 47;
string Tools = 48; // JSON array of available tools/functions for tool calling
string ToolChoice = 49; // JSON string or object specifying tool choice behavior
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
} }
// The response message containing the result // The response message containing the result
@@ -168,7 +164,6 @@ message Reply {
double timing_prompt_processing = 4; double timing_prompt_processing = 4;
double timing_token_generation = 5; double timing_token_generation = 5;
bytes audio = 6; bytes audio = 6;
bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
} }
message GrammarTrigger { message GrammarTrigger {
@@ -387,11 +382,6 @@ message StatusResponse {
message Message { message Message {
string role = 1; string role = 1;
string content = 2; string content = 2;
// Optional fields for OpenAI-compatible message format
string name = 3; // Tool name (for tool messages)
string tool_call_id = 4; // Tool call ID (for tool messages)
string reasoning_content = 5; // Reasoning content (for thinking models)
string tool_calls = 6; // Tool calls as JSON string (for assistant messages with tool calls)
} }
message DetectOptions { message DetectOptions {

View File

@@ -1,5 +1,5 @@
LLAMA_VERSION?=80deff3648b93727422461c41c7279ef1dac7452 LLAMA_VERSION?=5a4ff43e7dd049e35942bc3d12361dab2f155544
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?= CMAKE_ARGS?=

View File

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# whisper.cpp version # whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=d9b7613b34a343848af572cc14467fc5e82fc788 WHISPER_CPP_VERSION?=f16c12f3f55f5bd3d6ac8cf2f31ab90a42c884d5
SO_TARGET?=libgowhisper.so SO_TARGET?=libgowhisper.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -2,7 +2,6 @@
accelerate accelerate
torch torch
torchaudio torchaudio
numpy>=1.24.0,<1.26.0
transformers transformers
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster

View File

@@ -2,7 +2,6 @@
torch==2.6.0+cu118 torch==2.6.0+cu118
torchaudio==2.6.0+cu118 torchaudio==2.6.0+cu118
transformers==4.46.3 transformers==4.46.3
numpy>=1.24.0,<1.26.0
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate accelerate

View File

@@ -1,7 +1,6 @@
torch torch
torchaudio torchaudio
transformers transformers
numpy>=1.24.0,<1.26.0
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate accelerate

View File

@@ -2,7 +2,6 @@
torch==2.6.0+rocm6.1 torch==2.6.0+rocm6.1
torchaudio==2.6.0+rocm6.1 torchaudio==2.6.0+rocm6.1
transformers transformers
numpy>=1.24.0,<1.26.0
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate accelerate

View File

@@ -3,7 +3,6 @@ intel-extension-for-pytorch==2.3.110+xpu
torch==2.3.1+cxx11.abi torch==2.3.1+cxx11.abi
torchaudio==2.3.1+cxx11.abi torchaudio==2.3.1+cxx11.abi
transformers transformers
numpy>=1.24.0,<1.26.0
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate accelerate

View File

@@ -2,6 +2,5 @@
torch torch
torchaudio torchaudio
transformers transformers
numpy>=1.24.0,<1.26.0
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate accelerate

View File

@@ -75,13 +75,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
documents.append(doc) documents.append(doc)
ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents)))) ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
# Prepare results to return # Prepare results to return
cropped_results = ranked_results.top_k(request.top_n) if request.top_n > 0 else ranked_results
results = [ results = [
backend_pb2.DocumentResult( backend_pb2.DocumentResult(
index=res.doc_id, index=res.doc_id,
text=res.text, text=res.text,
relevance_score=res.score relevance_score=res.score
) for res in (cropped_results) ) for res in ranked_results.results
] ]
# Calculate the usage and total tokens # Calculate the usage and total tokens

View File

@@ -88,59 +88,3 @@ class TestBackendServicer(unittest.TestCase):
self.fail("Reranker service failed") self.fail("Reranker service failed")
finally: finally:
self.tearDown() self.tearDown()
def test_rerank_omit_top_n(self):
"""
This method tests if the embeddings are generated successfully even top_n is omitted
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
request = backend_pb2.RerankRequest(
query="I love you",
documents=["I hate you", "I really like you"],
top_n=0 #
)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
self.assertTrue(response.success)
rerank_response = stub.Rerank(request)
print(rerank_response.results[0])
self.assertIsNotNone(rerank_response.results)
self.assertEqual(len(rerank_response.results), 2)
self.assertEqual(rerank_response.results[0].text, "I really like you")
self.assertEqual(rerank_response.results[1].text, "I hate you")
except Exception as err:
print(err)
self.fail("Reranker service failed")
finally:
self.tearDown()
def test_rerank_crop(self):
"""
This method tests top_n cropping
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
request = backend_pb2.RerankRequest(
query="I love you",
documents=["I hate you", "I really like you", "I hate ignoring top_n"],
top_n=2
)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
self.assertTrue(response.success)
rerank_response = stub.Rerank(request)
print(rerank_response.results[0])
self.assertIsNotNone(rerank_response.results)
self.assertEqual(len(rerank_response.results), 2)
self.assertEqual(rerank_response.results[0].text, "I really like you")
self.assertEqual(rerank_response.results[1].text, "I hate you")
except Exception as err:
print(err)
self.fail("Reranker service failed")
finally:
self.tearDown()

View File

@@ -22,15 +22,9 @@ func New(opts ...config.AppOption) (*Application, error) {
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.SystemState.Model.ModelsPath) log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.SystemState.Model.ModelsPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
if err := application.start(); err != nil {
return nil, err
}
caps, err := xsysinfo.CPUCapabilities() caps, err := xsysinfo.CPUCapabilities()
if err == nil { if err == nil {
log.Debug().Msgf("CPU capabilities: %v", caps) log.Debug().Msgf("CPU capabilities: %v", caps)
} }
gpus, err := xsysinfo.GPUs() gpus, err := xsysinfo.GPUs()
if err == nil { if err == nil {
@@ -62,12 +56,12 @@ func New(opts ...config.AppOption) (*Application, error) {
} }
} }
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil { if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
log.Error().Err(err).Msg("error installing models") log.Error().Err(err).Msg("error installing models")
} }
for _, backend := range options.ExternalBackends { for _, backend := range options.ExternalBackends {
if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil { if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
log.Error().Err(err).Msg("error installing external backend") log.Error().Err(err).Msg("error installing external backend")
} }
} }
@@ -158,6 +152,10 @@ func New(opts ...config.AppOption) (*Application, error) {
// Watch the configuration directory // Watch the configuration directory
startWatcher(options) startWatcher(options)
if err := application.start(); err != nil {
return nil, err
}
log.Info().Msg("core/startup process completed!") log.Info().Msg("core/startup process completed!")
return application, nil return application, nil
} }

View File

@@ -3,6 +3,7 @@ package backend
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"regexp" "regexp"
"slices" "slices"
"strings" "strings"
@@ -25,7 +26,6 @@ type LLMResponse struct {
Response string // should this be []byte? Response string // should this be []byte?
Usage TokenUsage Usage TokenUsage
AudioOutput string AudioOutput string
Logprobs *schema.Logprobs // Logprobs from the backend response
} }
type TokenUsage struct { type TokenUsage struct {
@@ -35,7 +35,7 @@ type TokenUsage struct {
TimingTokenGeneration float64 TimingTokenGeneration float64
} }
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) { func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model modelFile := c.Model
// Check if the modelFile exists, if it doesn't try to load it from the gallery // Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -47,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
if !slices.Contains(modelNames, c.Name) { if !slices.Contains(modelNames, c.Name) {
utils.ResetDownloadTimers() utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it // if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries) err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil { if err != nil {
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile) log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
//return nil, err //return nil, err
@@ -65,8 +65,29 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
var protoMessages []*proto.Message var protoMessages []*proto.Message
// if we are using the tokenizer template, we need to convert the messages to proto messages // if we are using the tokenizer template, we need to convert the messages to proto messages
// unless the prompt has already been tokenized (non-chat endpoints + functions) // unless the prompt has already been tokenized (non-chat endpoints + functions)
if c.TemplateConfig.UseTokenizerTemplate && len(messages) > 0 { if c.TemplateConfig.UseTokenizerTemplate && s == "" {
protoMessages = messages.ToProto() protoMessages = make([]*proto.Message, len(messages), len(messages))
for i, message := range messages {
protoMessages[i] = &proto.Message{
Role: message.Role,
}
switch ct := message.Content.(type) {
case string:
protoMessages[i].Content = ct
case []interface{}:
// If using the tokenizer template, in case of multimodal we want to keep the multimodal content as and return only strings here
data, _ := json.Marshal(ct)
resultData := []struct {
Text string `json:"text"`
}{}
json.Unmarshal(data, &resultData)
for _, r := range resultData {
protoMessages[i].Content += r.Text
}
default:
return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
}
}
} }
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
@@ -78,21 +99,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
opts.Images = images opts.Images = images
opts.Videos = videos opts.Videos = videos
opts.Audios = audios opts.Audios = audios
opts.Tools = tools
opts.ToolChoice = toolChoice
if logprobs != nil {
opts.Logprobs = int32(*logprobs)
}
if topLogprobs != nil {
opts.TopLogprobs = int32(*topLogprobs)
}
if len(logitBias) > 0 {
// Serialize logit_bias map to JSON string for proto
logitBiasJSON, err := json.Marshal(logitBias)
if err == nil {
opts.LogitBias = string(logitBiasJSON)
}
}
tokenUsage := TokenUsage{} tokenUsage := TokenUsage{}
@@ -124,7 +130,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
} }
ss := "" ss := ""
var logprobs *schema.Logprobs
var partialRune []byte var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) { err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
@@ -136,14 +141,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
// Parse logprobs from reply if present (collect from last chunk that has them)
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}
// Process complete runes and accumulate them // Process complete runes and accumulate them
var completeRunes []byte var completeRunes []byte
for len(partialRune) > 0 { for len(partialRune) > 0 {
@@ -169,7 +166,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
return LLMResponse{ return LLMResponse{
Response: ss, Response: ss,
Usage: tokenUsage, Usage: tokenUsage,
Logprobs: logprobs,
}, err }, err
} else { } else {
// TODO: Is the chicken bit the only way to get here? is that acceptable? // TODO: Is the chicken bit the only way to get here? is that acceptable?
@@ -192,19 +188,9 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
response = c.TemplateConfig.ReplyPrefix + response response = c.TemplateConfig.ReplyPrefix + response
} }
// Parse logprobs from reply if present
var logprobs *schema.Logprobs
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}
return LLMResponse{ return LLMResponse{
Response: response, Response: response,
Usage: tokenUsage, Usage: tokenUsage,
Logprobs: logprobs,
}, err }, err
} }
} }

View File

@@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
} }
} }
pbOpts := &pb.PredictOptions{ return &pb.PredictOptions{
Temperature: float32(*c.Temperature), Temperature: float32(*c.Temperature),
TopP: float32(*c.TopP), TopP: float32(*c.TopP),
NDraft: c.NDraft, NDraft: c.NDraft,
@@ -249,6 +249,4 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
TailFreeSamplingZ: float32(*c.TFZ), TailFreeSamplingZ: float32(*c.TFZ),
TypicalP: float32(*c.TypicalP), TypicalP: float32(*c.TypicalP),
} }
// Logprobs and TopLogprobs are set by the caller if provided
return pbOpts
} }

View File

@@ -1,7 +1,6 @@
package cli package cli
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -103,7 +102,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
} }
modelLoader := model.NewModelLoader(systemState, true) modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias) err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -48,12 +48,10 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
appHTTP := http.Explorer(db) appHTTP := http.Explorer(db)
signals.RegisterGracefulTerminationHandler(func() { signals.RegisterGracefulTerminationHandler(func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) if err := appHTTP.Shutdown(); err != nil {
defer cancel()
if err := appHTTP.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("error during shutdown") log.Error().Err(err).Msg("error during shutdown")
} }
}) })
return appHTTP.Start(e.Address) return appHTTP.Listen(e.Address)
} }

View File

@@ -1,14 +1,12 @@
package cli package cli
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
cliContext "github.com/mudler/LocalAI/core/cli/context" cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/core/startup"
@@ -80,12 +78,6 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
return err return err
} }
galleryService := services.NewGalleryService(&config.ApplicationConfig{}, model.NewModelLoader(systemState, true))
err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)
if err != nil {
return err
}
var galleries []config.Gallery var galleries []config.Gallery
if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil { if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil {
log.Error().Err(err).Msg("unable to load galleries") log.Error().Err(err).Msg("unable to load galleries")
@@ -135,7 +127,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
} }
modelLoader := model.NewModelLoader(systemState, true) modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName) err = startup.InstallModels(galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -232,5 +232,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
} }
}) })
return appHTTP.Start(r.Address) return appHTTP.Listen(r.Address)
} }

View File

@@ -1,7 +1,6 @@
package worker package worker
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -43,7 +42,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
log.Error().Err(err).Msg("failed loading galleries") log.Error().Err(err).Msg("failed loading galleries")
return "", err return "", err
} }
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true) err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
if err != nil { if err != nil {
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it") log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
return "", err return "", err

View File

@@ -1,17 +1,151 @@
package config package config
import ( import (
"strings"
"github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/mudler/LocalAI/pkg/xsysinfo"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
gguf "github.com/gpustack/gguf-parser-go" gguf "github.com/gpustack/gguf-parser-go"
) )
type familyType uint8
const (
Unknown familyType = iota
LLaMa3
CommandR
Phi3
ChatML
Mistral03
Gemma
DeepSeek2
)
const ( const (
defaultContextSize = 1024 defaultContextSize = 1024
defaultNGPULayers = 99999999 defaultNGPULayers = 99999999
) )
type settingsConfig struct {
StopWords []string
TemplateConfig TemplateConfig
RepeatPenalty float64
}
// default settings to adopt with a given model family
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
Gemma: {
RepeatPenalty: 1.0,
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
TemplateConfig: TemplateConfig{
Chat: "{{.Input }}\n<start_of_turn>model\n",
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
Completion: "{{.Input}}",
},
},
DeepSeek2: {
StopWords: []string{"<end▁of▁sentence>"},
TemplateConfig: TemplateConfig{
ChatMessage: `{{if eq .RoleName "user" -}}User: {{.Content }}
{{ end -}}
{{if eq .RoleName "assistant" -}}Assistant: {{.Content}}<end▁of▁sentence>{{end}}
{{if eq .RoleName "system" -}}{{.Content}}
{{end -}}`,
Chat: "{{.Input -}}\nAssistant: ",
},
},
LLaMa3: {
StopWords: []string{"<|eot_id|>"},
TemplateConfig: TemplateConfig{
Chat: "<|begin_of_text|>{{.Input }}\n<|start_header_id|>assistant<|end_header_id|>",
ChatMessage: "<|start_header_id|>{{ .RoleName }}<|end_header_id|>\n\n{{.Content }}<|eot_id|>",
},
},
CommandR: {
TemplateConfig: TemplateConfig{
Chat: "{{.Input -}}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
Functions: `<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
You are a function calling AI model, you can call the following functions:
## Available Tools
{{range .Functions}}
- {"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }}
{{end}}
When using a tool, reply with JSON, for instance {"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}}
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Input -}}`,
ChatMessage: `{{if eq .RoleName "user" -}}
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if eq .RoleName "system" -}}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if eq .RoleName "assistant" -}}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if eq .RoleName "tool" -}}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if .FunctionCall -}}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{toJson .FunctionCall}}}<|END_OF_TURN_TOKEN|>
{{- end -}}`,
},
StopWords: []string{"<|END_OF_TURN_TOKEN|>"},
},
Phi3: {
TemplateConfig: TemplateConfig{
Chat: "{{.Input}}\n<|assistant|>",
ChatMessage: "<|{{ .RoleName }}|>\n{{.Content}}<|end|>",
Completion: "{{.Input}}",
},
StopWords: []string{"<|end|>", "<|endoftext|>"},
},
ChatML: {
TemplateConfig: TemplateConfig{
Chat: "{{.Input -}}\n<|im_start|>assistant",
Functions: `<|im_start|>system
You are a function calling AI model. You are provided with functions to execute. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:
{{range .Functions}}
{'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }}
{{end}}
For each function call return a json object with function name and arguments
<|im_end|>
{{.Input -}}
<|im_start|>assistant`,
ChatMessage: `<|im_start|>{{ .RoleName }}
{{ if .FunctionCall -}}
Function call:
{{ else if eq .RoleName "tool" -}}
Function response:
{{ end -}}
{{ if .Content -}}
{{.Content }}
{{ end -}}
{{ if .FunctionCall -}}
{{toJson .FunctionCall}}
{{ end -}}<|im_end|>`,
},
StopWords: []string{"<|im_end|>", "<dummy32000>", "</s>"},
},
Mistral03: {
TemplateConfig: TemplateConfig{
Chat: "{{.Input -}}",
Functions: `[AVAILABLE_TOOLS] [{{range .Functions}}{"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }}{{end}} ] [/AVAILABLE_TOOLS]{{.Input }}`,
ChatMessage: `{{if eq .RoleName "user" -}}
[INST] {{.Content }} [/INST]
{{- else if .FunctionCall -}}
[TOOL_CALLS] {{toJson .FunctionCall}} [/TOOL_CALLS]
{{- else if eq .RoleName "tool" -}}
[TOOL_RESULTS] {{.Content}} [/TOOL_RESULTS]
{{- else -}}
{{ .Content -}}
{{ end -}}`,
},
StopWords: []string{"<|im_end|>", "<dummy32000>", "</tool_call>", "<|eot_id|>", "<|end_of_text|>", "</s>", "[/TOOL_CALLS]", "[/ACTIONS]"},
},
}
// this maps well known template used in HF to model families defined above
var knownTemplates = map[string]familyType{
`{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\n' + content + '<|im_end|>\n<|im_start|>assistant\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\n' }}{% endif %}{% endfor %}`: ChatML,
`{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}`: Mistral03,
}
func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) { func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
if defaultCtx == 0 && cfg.ContextSize == nil { if defaultCtx == 0 && cfg.ContextSize == nil {
@@ -82,9 +216,81 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
cfg.Name = f.Metadata().Name cfg.Name = f.Metadata().Name
} }
// Instruct to use template from llama.cpp family := identifyFamily(f)
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.FunctionsConfig.GrammarConfig.NoGrammar = true if family == Unknown {
cfg.Options = append(cfg.Options, "use_jinja:true") log.Debug().Msgf("guessDefaultsFromFile: %s", "family not identified")
cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT") return
}
// identify template
settings, ok := defaultsSettings[family]
if ok {
cfg.TemplateConfig = settings.TemplateConfig
log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: guessed template %+v", cfg.TemplateConfig)
if len(cfg.StopWords) == 0 {
cfg.StopWords = settings.StopWords
}
if cfg.RepeatPenalty == 0.0 {
cfg.RepeatPenalty = settings.RepeatPenalty
}
} else {
log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
}
if cfg.HasTemplate() {
return
}
// identify from well known templates first, otherwise use the raw jinja template
chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
if found {
// try to use the jinja template
cfg.TemplateConfig.JinjaTemplate = true
cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
}
}
func identifyFamily(f *gguf.GGUFFile) familyType {
// identify from well known templates first
chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
if found && chatTemplate.ValueString() != "" {
if family, ok := knownTemplates[chatTemplate.ValueString()]; ok {
return family
}
}
// otherwise try to identify from the model properties
arch := f.Architecture().Architecture
eosTokenID := f.Tokenizer().EOSTokenID
bosTokenID := f.Tokenizer().BOSTokenID
isYI := arch == "llama" && bosTokenID == 1 && eosTokenID == 2
// WTF! Mistral0.3 and isYi have same bosTokenID and eosTokenID
llama3 := arch == "llama" && eosTokenID == 128009
commandR := arch == "command-r" && eosTokenID == 255001
qwen2 := arch == "qwen2"
phi3 := arch == "phi-3"
gemma := strings.HasPrefix(arch, "gemma") || strings.Contains(strings.ToLower(f.Metadata().Name), "gemma")
deepseek2 := arch == "deepseek2"
switch {
case deepseek2:
return DeepSeek2
case gemma:
return Gemma
case llama3:
return LLaMa3
case commandR:
return CommandR
case phi3:
return Phi3
case qwen2, isYI:
return ChatML
default:
return Unknown
}
} }

View File

@@ -9,7 +9,6 @@ import (
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/cogito"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -17,31 +16,30 @@ const (
RAND_SEED = -1 RAND_SEED = -1
) )
// @Description TTS configuration
type TTSConfig struct { type TTSConfig struct {
// Voice wav path or id // Voice wav path or id
Voice string `yaml:"voice,omitempty" json:"voice,omitempty"` Voice string `yaml:"voice" json:"voice"`
AudioPath string `yaml:"audio_path,omitempty" json:"audio_path,omitempty"` AudioPath string `yaml:"audio_path" json:"audio_path"`
} }
// @Description ModelConfig represents a model configuration // ModelConfig represents a model configuration
type ModelConfig struct { type ModelConfig struct {
modelConfigFile string `yaml:"-" json:"-"` modelConfigFile string `yaml:"-" json:"-"`
schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"` schema.PredictionOptions `yaml:"parameters" json:"parameters"`
Name string `yaml:"name,omitempty" json:"name,omitempty"` Name string `yaml:"name" json:"name"`
F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"` F16 *bool `yaml:"f16" json:"f16"`
Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"` Threads *int `yaml:"threads" json:"threads"`
Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"` Debug *bool `yaml:"debug" json:"debug"`
Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"` Roles map[string]string `yaml:"roles" json:"roles"`
Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"` Embeddings *bool `yaml:"embeddings" json:"embeddings"`
Backend string `yaml:"backend,omitempty" json:"backend,omitempty"` Backend string `yaml:"backend" json:"backend"`
TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"` TemplateConfig TemplateConfig `yaml:"template" json:"template"`
KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"` KnownUsecaseStrings []string `yaml:"known_usecases" json:"known_usecases"`
KnownUsecases *ModelConfigUsecases `yaml:"-" json:"-"` KnownUsecases *ModelConfigUsecases `yaml:"-" json:"-"`
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"` Pipeline Pipeline `yaml:"pipeline" json:"pipeline"`
PromptStrings, InputStrings []string `yaml:"-" json:"-"` PromptStrings, InputStrings []string `yaml:"-" json:"-"`
InputToken [][]int `yaml:"-" json:"-"` InputToken [][]int `yaml:"-" json:"-"`
@@ -49,101 +47,96 @@ type ModelConfig struct {
ResponseFormat string `yaml:"-" json:"-"` ResponseFormat string `yaml:"-" json:"-"`
ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"` ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"` FunctionsConfig functions.FunctionsConfig `yaml:"function" json:"function"`
FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early. FeatureFlag FeatureFlag `yaml:"feature_flags" json:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...) // LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig `yaml:",inline" json:",inline"` LLMConfig `yaml:",inline" json:",inline"`
// Diffusers // Diffusers
Diffusers Diffusers `yaml:"diffusers,omitempty" json:"diffusers,omitempty"` Diffusers Diffusers `yaml:"diffusers" json:"diffusers"`
Step int `yaml:"step,omitempty" json:"step,omitempty"` Step int `yaml:"step" json:"step"`
// GRPC Options // GRPC Options
GRPC GRPC `yaml:"grpc,omitempty" json:"grpc,omitempty"` GRPC GRPC `yaml:"grpc" json:"grpc"`
// TTS specifics // TTS specifics
TTSConfig `yaml:"tts,omitempty" json:"tts,omitempty"` TTSConfig `yaml:"tts" json:"tts"`
// CUDA // CUDA
// Explicitly enable CUDA or not (some backends might need it) // Explicitly enable CUDA or not (some backends might need it)
CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"` CUDA bool `yaml:"cuda" json:"cuda"`
DownloadFiles []File `yaml:"download_files,omitempty" json:"download_files,omitempty"` DownloadFiles []File `yaml:"download_files" json:"download_files"`
Description string `yaml:"description,omitempty" json:"description,omitempty"` Description string `yaml:"description" json:"description"`
Usage string `yaml:"usage,omitempty" json:"usage,omitempty"` Usage string `yaml:"usage" json:"usage"`
Options []string `yaml:"options,omitempty" json:"options,omitempty"` Options []string `yaml:"options" json:"options"`
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"` Overrides []string `yaml:"overrides" json:"overrides"`
MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"` MCP MCPConfig `yaml:"mcp" json:"mcp"`
Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"` Agent AgentConfig `yaml:"agent" json:"agent"`
} }
// @Description MCP configuration
type MCPConfig struct { type MCPConfig struct {
Servers string `yaml:"remote,omitempty" json:"remote,omitempty"` Servers string `yaml:"remote" json:"remote"`
Stdio string `yaml:"stdio,omitempty" json:"stdio,omitempty"` Stdio string `yaml:"stdio" json:"stdio"`
} }
// @Description Agent configuration
type AgentConfig struct { type AgentConfig struct {
MaxAttempts int `yaml:"max_attempts,omitempty" json:"max_attempts,omitempty"` MaxAttempts int `yaml:"max_attempts" json:"max_attempts"`
MaxIterations int `yaml:"max_iterations,omitempty" json:"max_iterations,omitempty"` MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
EnableReasoning bool `yaml:"enable_reasoning,omitempty" json:"enable_reasoning,omitempty"` EnableReasoning bool `yaml:"enable_reasoning" json:"enable_reasoning"`
EnablePlanning bool `yaml:"enable_planning,omitempty" json:"enable_planning,omitempty"` EnablePlanning bool `yaml:"enable_planning" json:"enable_planning"`
EnableMCPPrompts bool `yaml:"enable_mcp_prompts,omitempty" json:"enable_mcp_prompts,omitempty"` EnableMCPPrompts bool `yaml:"enable_mcp_prompts" json:"enable_mcp_prompts"`
EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator,omitempty" json:"enable_plan_re_evaluator,omitempty"` EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator" json:"enable_plan_re_evaluator"`
} }
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) { func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers]) {
var remote MCPGenericConfig[MCPRemoteServers] var remote MCPGenericConfig[MCPRemoteServers]
var stdio MCPGenericConfig[MCPSTDIOServers] var stdio MCPGenericConfig[MCPSTDIOServers]
if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil { if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
return remote, stdio, err return remote, stdio
} }
if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil { if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
return remote, stdio, err return remote, stdio
} }
return remote, stdio, nil
return remote, stdio
} }
// @Description MCP generic configuration
type MCPGenericConfig[T any] struct { type MCPGenericConfig[T any] struct {
Servers T `yaml:"mcpServers,omitempty" json:"mcpServers,omitempty"` Servers T `yaml:"mcpServers" json:"mcpServers"`
} }
type MCPRemoteServers map[string]MCPRemoteServer type MCPRemoteServers map[string]MCPRemoteServer
type MCPSTDIOServers map[string]MCPSTDIOServer type MCPSTDIOServers map[string]MCPSTDIOServer
// @Description MCP remote server configuration
type MCPRemoteServer struct { type MCPRemoteServer struct {
URL string `json:"url,omitempty"` URL string `json:"url"`
Token string `json:"token,omitempty"` Token string `json:"token"`
} }
// @Description MCP STDIO server configuration
type MCPSTDIOServer struct { type MCPSTDIOServer struct {
Args []string `json:"args,omitempty"` Args []string `json:"args"`
Env map[string]string `json:"env,omitempty"` Env map[string]string `json:"env"`
Command string `json:"command,omitempty"` Command string `json:"command"`
} }
// @Description Pipeline defines other models to use for audio-to-audio // Pipeline defines other models to use for audio-to-audio
type Pipeline struct { type Pipeline struct {
TTS string `yaml:"tts,omitempty" json:"tts,omitempty"` TTS string `yaml:"tts" json:"tts"`
LLM string `yaml:"llm,omitempty" json:"llm,omitempty"` LLM string `yaml:"llm" json:"llm"`
Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"` Transcription string `yaml:"transcription" json:"transcription"`
VAD string `yaml:"vad,omitempty" json:"vad,omitempty"` VAD string `yaml:"vad" json:"vad"`
} }
// @Description File configuration for model downloads
type File struct { type File struct {
Filename string `yaml:"filename,omitempty" json:"filename,omitempty"` Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256,omitempty" json:"sha256,omitempty"` SHA256 string `yaml:"sha256" json:"sha256"`
URI downloader.URI `yaml:"uri,omitempty" json:"uri,omitempty"` URI downloader.URI `yaml:"uri" json:"uri"`
} }
type FeatureFlag map[string]*bool type FeatureFlag map[string]*bool
@@ -155,136 +148,126 @@ func (ff FeatureFlag) Enabled(s string) bool {
return false return false
} }
// @Description GRPC configuration
type GRPC struct { type GRPC struct {
Attempts int `yaml:"attempts,omitempty" json:"attempts,omitempty"` Attempts int `yaml:"attempts" json:"attempts"`
AttemptsSleepTime int `yaml:"attempts_sleep_time,omitempty" json:"attempts_sleep_time,omitempty"` AttemptsSleepTime int `yaml:"attempts_sleep_time" json:"attempts_sleep_time"`
} }
// @Description Diffusers configuration
type Diffusers struct { type Diffusers struct {
CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"` CUDA bool `yaml:"cuda" json:"cuda"`
PipelineType string `yaml:"pipeline_type,omitempty" json:"pipeline_type,omitempty"` PipelineType string `yaml:"pipeline_type" json:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type,omitempty" json:"scheduler_type,omitempty"` SchedulerType string `yaml:"scheduler_type" json:"scheduler_type"`
EnableParameters string `yaml:"enable_parameters,omitempty" json:"enable_parameters,omitempty"` // A list of comma separated parameters to specify EnableParameters string `yaml:"enable_parameters" json:"enable_parameters"` // A list of comma separated parameters to specify
IMG2IMG bool `yaml:"img2img,omitempty" json:"img2img,omitempty"` // Image to Image Diffuser IMG2IMG bool `yaml:"img2img" json:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip,omitempty" json:"clip_skip,omitempty"` // Skip every N frames ClipSkip int `yaml:"clip_skip" json:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model,omitempty" json:"clip_model,omitempty"` // Clip model to use ClipModel string `yaml:"clip_model" json:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder,omitempty" json:"clip_subfolder,omitempty"` // Subfolder to use for clip model ClipSubFolder string `yaml:"clip_subfolder" json:"clip_subfolder"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net,omitempty" json:"control_net,omitempty"` ControlNet string `yaml:"control_net" json:"control_net"`
} }
// @Description LLMConfig is a struct that holds the configuration that are generic for most of the LLM backends. // LLMConfig is a struct that holds the configuration that are
// generic for most of the LLM backends.
type LLMConfig struct { type LLMConfig struct {
SystemPrompt string `yaml:"system_prompt,omitempty" json:"system_prompt,omitempty"` SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
TensorSplit string `yaml:"tensor_split,omitempty" json:"tensor_split,omitempty"` TensorSplit string `yaml:"tensor_split" json:"tensor_split"`
MainGPU string `yaml:"main_gpu,omitempty" json:"main_gpu,omitempty"` MainGPU string `yaml:"main_gpu" json:"main_gpu"`
RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"` RMSNormEps float32 `yaml:"rms_norm_eps" json:"rms_norm_eps"`
NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"` NGQA int32 `yaml:"ngqa" json:"ngqa"`
PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"` PromptCachePath string `yaml:"prompt_cache_path" json:"prompt_cache_path"`
PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"` PromptCacheAll bool `yaml:"prompt_cache_all" json:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"` PromptCacheRO bool `yaml:"prompt_cache_ro" json:"prompt_cache_ro"`
MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"` MirostatETA *float64 `yaml:"mirostat_eta" json:"mirostat_eta"`
MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"` MirostatTAU *float64 `yaml:"mirostat_tau" json:"mirostat_tau"`
Mirostat *int `yaml:"mirostat,omitempty" json:"mirostat,omitempty"` Mirostat *int `yaml:"mirostat" json:"mirostat"`
NGPULayers *int `yaml:"gpu_layers,omitempty" json:"gpu_layers,omitempty"` NGPULayers *int `yaml:"gpu_layers" json:"gpu_layers"`
MMap *bool `yaml:"mmap,omitempty" json:"mmap,omitempty"` MMap *bool `yaml:"mmap" json:"mmap"`
MMlock *bool `yaml:"mmlock,omitempty" json:"mmlock,omitempty"` MMlock *bool `yaml:"mmlock" json:"mmlock"`
LowVRAM *bool `yaml:"low_vram,omitempty" json:"low_vram,omitempty"` LowVRAM *bool `yaml:"low_vram" json:"low_vram"`
Reranking *bool `yaml:"reranking,omitempty" json:"reranking,omitempty"` Reranking *bool `yaml:"reranking" json:"reranking"`
Grammar string `yaml:"grammar,omitempty" json:"grammar,omitempty"` Grammar string `yaml:"grammar" json:"grammar"`
StopWords []string `yaml:"stopwords,omitempty" json:"stopwords,omitempty"` StopWords []string `yaml:"stopwords" json:"stopwords"`
Cutstrings []string `yaml:"cutstrings,omitempty" json:"cutstrings,omitempty"` Cutstrings []string `yaml:"cutstrings" json:"cutstrings"`
ExtractRegex []string `yaml:"extract_regex,omitempty" json:"extract_regex,omitempty"` ExtractRegex []string `yaml:"extract_regex" json:"extract_regex"`
TrimSpace []string `yaml:"trimspace,omitempty" json:"trimspace,omitempty"` TrimSpace []string `yaml:"trimspace" json:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix,omitempty" json:"trimsuffix,omitempty"` TrimSuffix []string `yaml:"trimsuffix" json:"trimsuffix"`
ContextSize *int `yaml:"context_size,omitempty" json:"context_size,omitempty"` ContextSize *int `yaml:"context_size" json:"context_size"`
NUMA bool `yaml:"numa,omitempty" json:"numa,omitempty"` NUMA bool `yaml:"numa" json:"numa"`
LoraAdapter string `yaml:"lora_adapter,omitempty" json:"lora_adapter,omitempty"` LoraAdapter string `yaml:"lora_adapter" json:"lora_adapter"`
LoraBase string `yaml:"lora_base,omitempty" json:"lora_base,omitempty"` LoraBase string `yaml:"lora_base" json:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters,omitempty" json:"lora_adapters,omitempty"` LoraAdapters []string `yaml:"lora_adapters" json:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales,omitempty" json:"lora_scales,omitempty"` LoraScales []float32 `yaml:"lora_scales" json:"lora_scales"`
LoraScale float32 `yaml:"lora_scale,omitempty" json:"lora_scale,omitempty"` LoraScale float32 `yaml:"lora_scale" json:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq,omitempty" json:"no_mulmatq,omitempty"` NoMulMatQ bool `yaml:"no_mulmatq" json:"no_mulmatq"`
DraftModel string `yaml:"draft_model,omitempty" json:"draft_model,omitempty"` DraftModel string `yaml:"draft_model" json:"draft_model"`
NDraft int32 `yaml:"n_draft,omitempty" json:"n_draft,omitempty"` NDraft int32 `yaml:"n_draft" json:"n_draft"`
Quantization string `yaml:"quantization,omitempty" json:"quantization,omitempty"` Quantization string `yaml:"quantization" json:"quantization"`
LoadFormat string `yaml:"load_format,omitempty" json:"load_format,omitempty"` LoadFormat string `yaml:"load_format" json:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization,omitempty" json:"gpu_memory_utilization,omitempty"` // vLLM GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization" json:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code,omitempty" json:"trust_remote_code,omitempty"` // vLLM TrustRemoteCode bool `yaml:"trust_remote_code" json:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager,omitempty" json:"enforce_eager,omitempty"` // vLLM EnforceEager bool `yaml:"enforce_eager" json:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space,omitempty" json:"swap_space,omitempty"` // vLLM SwapSpace int `yaml:"swap_space" json:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len,omitempty" json:"max_model_len,omitempty"` // vLLM MaxModelLen int `yaml:"max_model_len" json:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size,omitempty" json:"tensor_parallel_size,omitempty"` // vLLM TensorParallelSize int `yaml:"tensor_parallel_size" json:"tensor_parallel_size"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM DisableLogStatus bool `yaml:"disable_log_stats" json:"disable_log_stats"` // vLLM
DType string `yaml:"dtype,omitempty" json:"dtype,omitempty"` // vLLM DType string `yaml:"dtype" json:"dtype"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"` MMProj string `yaml:"mmproj" json:"mmproj"`
FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"` FlashAttention *string `yaml:"flash_attention" json:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"` NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
CacheTypeK string `yaml:"cache_type_k,omitempty" json:"cache_type_k,omitempty"` CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
CacheTypeV string `yaml:"cache_type_v,omitempty" json:"cache_type_v,omitempty"` CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
RopeScaling string `yaml:"rope_scaling,omitempty" json:"rope_scaling,omitempty"` RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"`
ModelType string `yaml:"type,omitempty" json:"type,omitempty"` ModelType string `yaml:"type" json:"type"`
YarnExtFactor float32 `yaml:"yarn_ext_factor,omitempty" json:"yarn_ext_factor,omitempty"` YarnExtFactor float32 `yaml:"yarn_ext_factor" json:"yarn_ext_factor"`
YarnAttnFactor float32 `yaml:"yarn_attn_factor,omitempty" json:"yarn_attn_factor,omitempty"` YarnAttnFactor float32 `yaml:"yarn_attn_factor" json:"yarn_attn_factor"`
YarnBetaFast float32 `yaml:"yarn_beta_fast,omitempty" json:"yarn_beta_fast,omitempty"` YarnBetaFast float32 `yaml:"yarn_beta_fast" json:"yarn_beta_fast"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow,omitempty" json:"yarn_beta_slow,omitempty"` YarnBetaSlow float32 `yaml:"yarn_beta_slow" json:"yarn_beta_slow"`
CFGScale float32 `yaml:"cfg_scale,omitempty" json:"cfg_scale,omitempty"` // Classifier-Free Guidance Scale CFGScale float32 `yaml:"cfg_scale" json:"cfg_scale"` // Classifier-Free Guidance Scale
} }
// @Description LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM // LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
type LimitMMPerPrompt struct { type LimitMMPerPrompt struct {
LimitImagePerPrompt int `yaml:"image,omitempty" json:"image,omitempty"` LimitImagePerPrompt int `yaml:"image" json:"image"`
LimitVideoPerPrompt int `yaml:"video,omitempty" json:"video,omitempty"` LimitVideoPerPrompt int `yaml:"video" json:"video"`
LimitAudioPerPrompt int `yaml:"audio,omitempty" json:"audio,omitempty"` LimitAudioPerPrompt int `yaml:"audio" json:"audio"`
} }
// @Description TemplateConfig is a struct that holds the configuration of the templating system // TemplateConfig is a struct that holds the configuration of the templating system
type TemplateConfig struct { type TemplateConfig struct {
// Chat is the template used in the chat completion endpoint // Chat is the template used in the chat completion endpoint
Chat string `yaml:"chat,omitempty" json:"chat,omitempty"` Chat string `yaml:"chat" json:"chat"`
// ChatMessage is the template used for chat messages // ChatMessage is the template used for chat messages
ChatMessage string `yaml:"chat_message,omitempty" json:"chat_message,omitempty"` ChatMessage string `yaml:"chat_message" json:"chat_message"`
// Completion is the template used for completion requests // Completion is the template used for completion requests
Completion string `yaml:"completion,omitempty" json:"completion,omitempty"` Completion string `yaml:"completion" json:"completion"`
// Edit is the template used for edit completion requests // Edit is the template used for edit completion requests
Edit string `yaml:"edit,omitempty" json:"edit,omitempty"` Edit string `yaml:"edit" json:"edit"`
// Functions is the template used when tools are present in the client requests // Functions is the template used when tools are present in the client requests
Functions string `yaml:"function,omitempty" json:"function,omitempty"` Functions string `yaml:"function" json:"function"`
// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used. // UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
// Note: this is mostly consumed for backends such as vllm and transformers // Note: this is mostly consumed for backends such as vllm and transformers
// that can use the tokenizers specified in the JSON config files of the models // that can use the tokenizers specified in the JSON config files of the models
UseTokenizerTemplate bool `yaml:"use_tokenizer_template,omitempty" json:"use_tokenizer_template,omitempty"` UseTokenizerTemplate bool `yaml:"use_tokenizer_template" json:"use_tokenizer_template"`
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together. // JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
// It defaults to \n // It defaults to \n
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character,omitempty" json:"join_chat_messages_by_character,omitempty"` JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character" json:"join_chat_messages_by_character"`
Multimodal string `yaml:"multimodal,omitempty" json:"multimodal,omitempty"` Multimodal string `yaml:"multimodal" json:"multimodal"`
ReplyPrefix string `yaml:"reply_prefix,omitempty" json:"reply_prefix,omitempty"` JinjaTemplate bool `yaml:"jinja_template" json:"jinja_template"`
}
func (c *ModelConfig) syncKnownUsecasesFromString() { ReplyPrefix string `yaml:"reply_prefix" json:"reply_prefix"`
c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
// Make sure the usecases are valid, we rewrite with what we identified
c.KnownUsecaseStrings = []string{}
for k, usecase := range GetAllModelConfigUsecases() {
if c.HasUsecases(usecase) {
c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
}
}
} }
func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error { func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
@@ -295,7 +278,14 @@ func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
} }
*c = ModelConfig(aux) *c = ModelConfig(aux)
c.syncKnownUsecasesFromString() c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
// Make sure the usecases are valid, we rewrite with what we identified
c.KnownUsecaseStrings = []string{}
for k, usecase := range GetAllModelConfigUsecases() {
if c.HasUsecases(usecase) {
c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
}
}
return nil return nil
} }
@@ -472,7 +462,6 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
} }
guessDefaultsFromFile(cfg, lo.modelPath, ctx) guessDefaultsFromFile(cfg, lo.modelPath, ctx)
cfg.syncKnownUsecasesFromString()
} }
func (c *ModelConfig) Validate() bool { func (c *ModelConfig) Validate() bool {
@@ -503,7 +492,7 @@ func (c *ModelConfig) Validate() bool {
} }
func (c *ModelConfig) HasTemplate() bool { func (c *ModelConfig) HasTemplate() bool {
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
} }
func (c *ModelConfig) GetModelConfigFile() string { func (c *ModelConfig) GetModelConfigFile() string {
@@ -584,7 +573,7 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecases) bool {
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently. // This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool { func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
if (u & FLAG_CHAT) == FLAG_CHAT { if (u & FLAG_CHAT) == FLAG_CHAT {
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate { if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
return false return false
} }
} }
@@ -669,40 +658,3 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
return true return true
} }
// BuildCogitoOptions generates cogito options from the model configuration
// It accepts a context, MCP sessions, and optional callback functions for status, reasoning, tool calls, and tool results
func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
cogitoOpts := []cogito.Option{
cogito.WithIterations(3), // default to 3 iterations
cogito.WithMaxAttempts(3), // default to 3 attempts
cogito.WithForceReasoning(),
}
// Apply agent configuration options
if c.Agent.EnableReasoning {
cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
}
if c.Agent.EnablePlanning {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
}
if c.Agent.EnableMCPPrompts {
cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
}
if c.Agent.EnablePlanReEvaluator {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
}
if c.Agent.MaxIterations != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
}
if c.Agent.MaxAttempts != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
}
return cogitoOpts
}

View File

@@ -3,9 +3,7 @@
package gallery package gallery
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -70,7 +68,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
} }
// InstallBackendFromGallery installs a backend from the gallery. // InstallBackendFromGallery installs a backend from the gallery.
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error { func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
if !force { if !force {
// check if we already have the backend installed // check if we already have the backend installed
backends, err := ListSystemBackends(systemState) backends, err := ListSystemBackends(systemState)
@@ -110,7 +108,7 @@ func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery,
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend") log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
// Then, let's install the best backend // Then, let's install the best backend
if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil { if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
return err return err
} }
@@ -135,10 +133,10 @@ func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery,
return nil return nil
} }
return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus) return InstallBackend(systemState, modelLoader, backend, downloadStatus)
} }
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error { func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist // Create base path if it doesn't exist
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750) err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
if err != nil { if err != nil {
@@ -165,17 +163,11 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
} }
} else { } else {
uri := downloader.URI(config.URI) uri := downloader.URI(config.URI)
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil { if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
success := false success := false
// Try to download from mirrors // Try to download from mirrors
for _, mirror := range config.Mirrors { for _, mirror := range config.Mirrors {
// Check for cancellation before trying next mirror if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
success = true success = true
break break
} }
@@ -318,10 +310,8 @@ func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error)
} }
} }
} }
} else if !errors.Is(err, os.ErrNotExist) { } else {
log.Warn().Err(err).Msg("Failed to read system backends, proceeding with user-managed backends") log.Warn().Err(err).Msg("Failed to read system backends, proceeding with user-managed backends")
} else if errors.Is(err, os.ErrNotExist) {
log.Debug().Msg("No system backends found")
} }
// User-managed backends and alias collection // User-managed backends and alias collection

View File

@@ -1,7 +1,6 @@
package gallery package gallery
import ( import (
"context"
"encoding/json" "encoding/json"
"os" "os"
"path/filepath" "path/filepath"
@@ -117,13 +116,13 @@ var _ = Describe("Gallery Backends", func() {
Describe("InstallBackendFromGallery", func() { Describe("InstallBackendFromGallery", func() {
It("should return error when backend is not found", func() { It("should return error when backend is not found", func() {
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true) err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\"")) Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
}) })
It("should install backend from gallery", func() { It("should install backend from gallery", func() {
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true) err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile()) Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
}) })
@@ -299,7 +298,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000, VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir}, Backend: system.Backend{BackendsPath: tempDir},
} }
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true) err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend") metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -379,7 +378,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000, VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir}, Backend: system.Backend{BackendsPath: tempDir},
} }
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true) err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend") metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -463,7 +462,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000, VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir}, Backend: system.Backend{BackendsPath: tempDir},
} }
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true) err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend") metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -562,7 +561,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(newPath), system.WithBackendPath(newPath),
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) err = InstallBackend(systemState, ml, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory()) Expect(newPath).To(BeADirectory())
}) })
@@ -594,7 +593,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir), system.WithBackendPath(tempDir),
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) err = InstallBackend(systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile()) Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json")) dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
@@ -627,7 +626,7 @@ var _ = Describe("Gallery Backends", func() {
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile()) Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) err = InstallBackend(systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile()) Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
}) })
@@ -648,7 +647,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir), system.WithBackendPath(tempDir),
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) err = InstallBackend(systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile()) Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())

View File

@@ -1,7 +1,6 @@
package gallery package gallery
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -29,19 +28,6 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
return config, nil return config, nil
} }
func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
var config T
uri := downloader.URI(url)
err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
return config, err
}
return config, nil
}
func ReadConfigFile[T any](filePath string) (*T, error) { func ReadConfigFile[T any](filePath string) (*T, error) {
// Read the YAML file // Read the YAML file
yamlFile, err := os.ReadFile(filePath) yamlFile, err := os.ReadFile(filePath)
@@ -75,15 +61,12 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] {
term = strings.ToLower(term) term = strings.ToLower(term)
for _, m := range gm { for _, m := range gm {
if fuzzy.Match(term, strings.ToLower(m.GetName())) || if fuzzy.Match(term, strings.ToLower(m.GetName())) ||
fuzzy.Match(term, strings.ToLower(m.GetDescription())) ||
fuzzy.Match(term, strings.ToLower(m.GetGallery().Name)) || fuzzy.Match(term, strings.ToLower(m.GetGallery().Name)) ||
strings.Contains(strings.ToLower(m.GetName()), term) ||
strings.Contains(strings.ToLower(m.GetDescription()), term) ||
strings.Contains(strings.ToLower(m.GetGallery().Name), term) ||
strings.Contains(strings.ToLower(strings.Join(m.GetTags(), ",")), term) { strings.Contains(strings.ToLower(strings.Join(m.GetTags(), ",")), term) {
filteredModels = append(filteredModels, m) filteredModels = append(filteredModels, m)
} }
} }
return filteredModels return filteredModels
} }

View File

@@ -1,67 +0,0 @@
package importers
import (
"encoding/json"
"strings"
"github.com/rs/zerolog/log"
"github.com/mudler/LocalAI/core/gallery"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)
var defaultImporters = []Importer{
&LlamaCPPImporter{},
&MLXImporter{},
&VLLMImporter{},
&TransformersImporter{},
}
type Details struct {
HuggingFace *hfapi.ModelDetails
URI string
Preferences json.RawMessage
}
type Importer interface {
Match(details Details) bool
Import(details Details) (gallery.ModelConfig, error)
}
func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) {
var err error
var modelConfig gallery.ModelConfig
hf := hfapi.NewClient()
hfrepoID := strings.ReplaceAll(uri, "huggingface://", "")
hfrepoID = strings.ReplaceAll(hfrepoID, "hf://", "")
hfrepoID = strings.ReplaceAll(hfrepoID, "https://huggingface.co/", "")
hfDetails, err := hf.GetModelDetails(hfrepoID)
if err != nil {
// maybe not a HF repository
// TODO: maybe we can check if the URI is a valid HF repository
log.Debug().Str("uri", uri).Msg("Failed to get model details, maybe not a HF repository")
} else {
log.Debug().Str("uri", uri).Msg("Got model details")
log.Debug().Any("details", hfDetails).Msg("Model details")
}
details := Details{
HuggingFace: hfDetails,
URI: uri,
Preferences: preferences,
}
for _, importer := range defaultImporters {
if importer.Match(details) {
modelConfig, err = importer.Import(details)
if err != nil {
continue
}
break
}
}
return modelConfig, err
}

View File

@@ -1,13 +0,0 @@
package importers_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestImporters(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Importers test suite")
}

View File

@@ -1,215 +0,0 @@
package importers_test
import (
"encoding/json"
"fmt"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("DiscoverModelConfig", func() {
Context("With only a repository URI", func() {
It("should discover and import using LlamaCPPImporter", func() {
uri := "https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
Expect(modelConfig.Name).To(Equal("LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/resolve/main/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].SHA256).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"), fmt.Sprintf("Model config: %+v", modelConfig))
})
It("should discover and import using LlamaCPPImporter", func() {
uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].Filename).To(Equal("mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
})
It("should discover and import using LlamaCPPImporter", func() {
uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
preferences := json.RawMessage(`{ "quantizations": "Q8_0", "mmproj_quantizations": "f16" }`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].Filename).To(Equal("mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
})
})
Context("with .gguf URI", func() {
It("should discover and import using LlamaCPPImporter", func() {
uri := "https://example.com/my-model.gguf"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model.gguf"))
Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
})
It("should use custom preferences when provided", func() {
uri := "https://example.com/my-model.gguf"
preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-name"))
Expect(modelConfig.Description).To(Equal("Custom description"))
})
})
Context("with mlx-community URI", func() {
It("should discover and import using MLXImporter", func() {
uri := "https://huggingface.co/mlx-community/test-model"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
})
It("should use custom preferences when provided", func() {
uri := "https://huggingface.co/mlx-community/test-model"
preferences := json.RawMessage(`{"name": "custom-mlx", "description": "Custom MLX description"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-mlx"))
Expect(modelConfig.Description).To(Equal("Custom MLX description"))
})
})
Context("with backend preference", func() {
It("should use llama-cpp backend when specified", func() {
uri := "https://example.com/model"
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
})
It("should use mlx backend when specified", func() {
uri := "https://example.com/model"
preferences := json.RawMessage(`{"backend": "mlx"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
})
It("should use mlx-vlm backend when specified", func() {
uri := "https://example.com/model"
preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
})
})
Context("with HuggingFace URI formats", func() {
It("should handle huggingface:// prefix", func() {
uri := "huggingface://mlx-community/test-model"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
})
It("should handle hf:// prefix", func() {
uri := "hf://mlx-community/test-model"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
})
It("should handle https://huggingface.co/ prefix", func() {
uri := "https://huggingface.co/mlx-community/test-model"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
})
})
Context("with invalid or non-matching URI", func() {
It("should return error when no importer matches", func() {
uri := "https://example.com/unknown-model.bin"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
// When no importer matches, the function returns empty config and error
// The exact behavior depends on implementation, but typically an error is returned
Expect(modelConfig.Name).To(BeEmpty())
Expect(err).To(HaveOccurred())
})
})
Context("with invalid JSON preferences", func() {
It("should return error when JSON is invalid even if URI matches", func() {
uri := "https://example.com/model.gguf"
preferences := json.RawMessage(`invalid json`)
// Even though Match() returns true for .gguf extension,
// Import() will fail when trying to unmarshal invalid JSON preferences
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).To(HaveOccurred())
Expect(modelConfig.Name).To(BeEmpty())
})
})
})

View File

@@ -1,209 +0,0 @@
package importers
import (
"encoding/json"
"path/filepath"
"slices"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"go.yaml.in/yaml/v2"
)
var _ Importer = &LlamaCPPImporter{}
type LlamaCPPImporter struct{}
func (i *LlamaCPPImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
if preferencesMap["backend"] == "llama-cpp" {
return true
}
if strings.HasSuffix(details.URI, ".gguf") {
return true
}
if details.HuggingFace != nil {
for _, file := range details.HuggingFace.Files {
if strings.HasSuffix(file.Path, ".gguf") {
return true
}
}
}
return false
}
func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
preferedQuantizations, _ := preferencesMap["quantizations"].(string)
quants := []string{"q4_k_m"}
if preferedQuantizations != "" {
quants = strings.Split(preferedQuantizations, ",")
}
mmprojQuants, _ := preferencesMap["mmproj_quantizations"].(string)
mmprojQuantsList := []string{"fp16"}
if mmprojQuants != "" {
mmprojQuantsList = strings.Split(mmprojQuants, ",")
}
embeddings, _ := preferencesMap["embeddings"].(string)
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Options: []string{"use_jinja:true"},
Backend: "llama-cpp",
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
FunctionsConfig: functions.FunctionsConfig{
GrammarConfig: functions.GrammarConfig{
NoGrammar: true,
},
},
}
if embeddings != "" && strings.ToLower(embeddings) == "true" || strings.ToLower(embeddings) == "yes" {
trueV := true
modelConfig.Embeddings = &trueV
}
cfg := gallery.ModelConfig{
Name: name,
Description: description,
}
if strings.HasSuffix(details.URI, ".gguf") {
cfg.Files = append(cfg.Files, gallery.File{
URI: details.URI,
Filename: filepath.Base(details.URI),
})
modelConfig.PredictionOptions = schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: filepath.Base(details.URI),
},
}
} else if details.HuggingFace != nil {
// We want to:
// Get first the chosen quants that match filenames
// OR the first mmproj/gguf file found
var lastMMProjFile *gallery.File
var lastGGUFFile *gallery.File
foundPreferedQuant := false
foundPreferedMMprojQuant := false
for _, file := range details.HuggingFace.Files {
// Get the mmproj prefered quants
if strings.Contains(strings.ToLower(file.Path), "mmproj") {
lastMMProjFile = &gallery.File{
URI: file.URL,
Filename: filepath.Join("mmproj", filepath.Base(file.Path)),
SHA256: file.SHA256,
}
if slices.ContainsFunc(mmprojQuantsList, func(quant string) bool {
return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
}) {
cfg.Files = append(cfg.Files, *lastMMProjFile)
foundPreferedMMprojQuant = true
}
} else if strings.HasSuffix(strings.ToLower(file.Path), "gguf") {
lastGGUFFile = &gallery.File{
URI: file.URL,
Filename: filepath.Base(file.Path),
SHA256: file.SHA256,
}
// get the files of the prefered quants
if slices.ContainsFunc(quants, func(quant string) bool {
return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
}) {
foundPreferedQuant = true
cfg.Files = append(cfg.Files, *lastGGUFFile)
}
}
}
// Make sure to add at least one file if not already present (which is the latest one)
if lastMMProjFile != nil && !foundPreferedMMprojQuant {
if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
return f.Filename == lastMMProjFile.Filename
}) {
cfg.Files = append(cfg.Files, *lastMMProjFile)
}
}
if lastGGUFFile != nil && !foundPreferedQuant {
if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
return f.Filename == lastGGUFFile.Filename
}) {
cfg.Files = append(cfg.Files, *lastGGUFFile)
}
}
// Find first mmproj file and configure it in the config file
for _, file := range cfg.Files {
if !strings.Contains(strings.ToLower(file.Filename), "mmproj") {
continue
}
modelConfig.MMProj = file.Filename
break
}
// Find first non-mmproj file and configure it in the config file
for _, file := range cfg.Files {
if strings.Contains(strings.ToLower(file.Filename), "mmproj") {
continue
}
modelConfig.PredictionOptions = schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: file.Filename,
},
}
break
}
}
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
cfg.ConfigFile = string(data)
return cfg, nil
}

View File

@@ -1,132 +0,0 @@
package importers_test
import (
"encoding/json"
"fmt"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("LlamaCPPImporter", func() {
var importer *LlamaCPPImporter
BeforeEach(func() {
importer = &LlamaCPPImporter{}
})
Context("Match", func() {
It("should match when URI ends with .gguf", func() {
details := Details{
URI: "https://example.com/model.gguf",
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when backend preference is llama-cpp", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI does not end with .gguf and no backend preference", func() {
details := Details{
URI: "https://example.com/model.bin",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "mlx"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/model.gguf",
Preferences: preferences,
}
// Invalid JSON causes Match to return false early
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := Details{
URI: "https://example.com/my-model.gguf",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model.gguf"))
Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := Details{
URI: "https://example.com/my-model.gguf",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-model"))
Expect(modelConfig.Description).To(Equal("Custom description"))
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/my-model.gguf",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://example.com/path/to/model.gguf",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/path/to/model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
})
})
})

View File

@@ -1,94 +0,0 @@
package importers
import (
"encoding/json"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"go.yaml.in/yaml/v2"
)
var _ Importer = &MLXImporter{}
type MLXImporter struct{}
func (i *MLXImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
b, ok := preferencesMap["backend"].(string)
if ok && b == "mlx" || b == "mlx-vlm" {
return true
}
// All https://huggingface.co/mlx-community/*
if strings.Contains(details.URI, "mlx-community/") {
return true
}
return false
}
func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
backend := "mlx"
b, ok := preferencesMap["backend"].(string)
if ok {
backend = b
}
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Backend: backend,
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: details.URI,
},
},
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
}
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
return gallery.ModelConfig{
Name: name,
Description: description,
ConfigFile: string(data),
}, nil
}

View File

@@ -1,147 +0,0 @@
package importers_test
import (
"encoding/json"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("MLXImporter", func() {
var importer *importers.MLXImporter
BeforeEach(func() {
importer = &importers.MLXImporter{}
})
Context("Match", func() {
It("should match when URI contains mlx-community/", func() {
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when backend preference is mlx", func() {
preferences := json.RawMessage(`{"backend": "mlx"}`)
details := importers.Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when backend preference is mlx-vlm", func() {
preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
details := importers.Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI does not contain mlx-community/ and no backend preference", func() {
details := importers.Details{
URI: "https://huggingface.co/other-org/test-model",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := importers.Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
Preferences: preferences,
}
// Invalid JSON causes Match to return false early
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/mlx-community/test-model"))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-mlx-model", "description": "Custom MLX description"}`)
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-mlx-model"))
Expect(modelConfig.Description).To(Equal("Custom MLX description"))
})
It("should use custom backend from preferences", func() {
preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://huggingface.co/mlx-community/path/to/model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("model"))
})
})
})

View File

@@ -1,110 +0,0 @@
package importers
import (
"encoding/json"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"go.yaml.in/yaml/v2"
)
var _ Importer = &TransformersImporter{}
type TransformersImporter struct{}
func (i *TransformersImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
b, ok := preferencesMap["backend"].(string)
if ok && b == "transformers" {
return true
}
if details.HuggingFace != nil {
for _, file := range details.HuggingFace.Files {
if strings.Contains(file.Path, "tokenizer.json") ||
strings.Contains(file.Path, "tokenizer_config.json") {
return true
}
}
}
return false
}
func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
backend := "transformers"
b, ok := preferencesMap["backend"].(string)
if ok {
backend = b
}
modelType, ok := preferencesMap["type"].(string)
if !ok {
modelType = "AutoModelForCausalLM"
}
quantization, ok := preferencesMap["quantization"].(string)
if !ok {
quantization = ""
}
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Backend: backend,
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: details.URI,
},
},
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
}
modelConfig.ModelType = modelType
modelConfig.Quantization = quantization
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
return gallery.ModelConfig{
Name: name,
Description: description,
ConfigFile: string(data),
}, nil
}

View File

@@ -1,219 +0,0 @@
package importers_test
import (
"encoding/json"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("TransformersImporter", func() {
var importer *TransformersImporter
BeforeEach(func() {
importer = &TransformersImporter{}
})
Context("Match", func() {
It("should match when backend preference is transformers", func() {
preferences := json.RawMessage(`{"backend": "transformers"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer_config.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer_config.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI has no tokenizer files and no backend preference", func() {
details := Details{
URI: "https://example.com/model.bin",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-model"))
Expect(modelConfig.Description).To(Equal("Custom description"))
})
It("should use custom model type from preferences", func() {
preferences := json.RawMessage(`{"type": "SentenceTransformer"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: SentenceTransformer"))
})
It("should use default model type when not specified", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
})
It("should use custom backend from preferences", func() {
preferences := json.RawMessage(`{"backend": "transformers"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
})
It("should use quantization from preferences", func() {
preferences := json.RawMessage(`{"quantization": "int8"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("quantization: int8"))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://huggingface.co/test/path/to/model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("model"))
})
It("should include use_tokenizer_template in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
})
It("should include known_usecases in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
})
})
})

View File

@@ -1,98 +0,0 @@
package importers
import (
"encoding/json"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"go.yaml.in/yaml/v2"
)
var _ Importer = &VLLMImporter{}
type VLLMImporter struct{}
func (i *VLLMImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
b, ok := preferencesMap["backend"].(string)
if ok && b == "vllm" {
return true
}
if details.HuggingFace != nil {
for _, file := range details.HuggingFace.Files {
if strings.Contains(file.Path, "tokenizer.json") ||
strings.Contains(file.Path, "tokenizer_config.json") {
return true
}
}
}
return false
}
func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
backend := "vllm"
b, ok := preferencesMap["backend"].(string)
if ok {
backend = b
}
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Backend: backend,
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: details.URI,
},
},
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
}
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
return gallery.ModelConfig{
Name: name,
Description: description,
ConfigFile: string(data),
}, nil
}

View File

@@ -1,181 +0,0 @@
package importers_test
import (
"encoding/json"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("VLLMImporter", func() {
var importer *VLLMImporter
BeforeEach(func() {
importer = &VLLMImporter{}
})
Context("Match", func() {
It("should match when backend preference is vllm", func() {
preferences := json.RawMessage(`{"backend": "vllm"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer_config.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer_config.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI has no tokenizer files and no backend preference", func() {
details := Details{
URI: "https://example.com/model.bin",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-model"))
Expect(modelConfig.Description).To(Equal("Custom description"))
})
It("should use custom backend from preferences", func() {
preferences := json.RawMessage(`{"backend": "vllm"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://huggingface.co/test/path/to/model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("model"))
})
It("should include use_tokenizer_template in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
})
It("should include known_usecases in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
})
})
})

View File

@@ -1,7 +1,6 @@
package gallery package gallery
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"os" "os"
@@ -73,7 +72,6 @@ type PromptTemplate struct {
// Installs a model from the gallery // Installs a model from the gallery
func InstallModelFromGallery( func InstallModelFromGallery(
ctx context.Context,
modelGalleries, backendGalleries []config.Gallery, modelGalleries, backendGalleries []config.Gallery,
systemState *system.SystemState, systemState *system.SystemState,
modelLoader *model.ModelLoader, modelLoader *model.ModelLoader,
@@ -86,7 +84,7 @@ func InstallModelFromGallery(
if len(model.URL) > 0 { if len(model.URL) > 0 {
var err error var err error
config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath) config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
if err != nil { if err != nil {
return err return err
} }
@@ -127,7 +125,7 @@ func InstallModelFromGallery(
return err return err
} }
installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan) installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
if err != nil { if err != nil {
return err return err
} }
@@ -135,7 +133,7 @@ func InstallModelFromGallery(
if automaticallyInstallBackend && installedModel.Backend != "" { if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend) log.Debug().Msgf("Installing backend %q", installedModel.Backend)
if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil { if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
return err return err
} }
} }
@@ -156,7 +154,7 @@ func InstallModelFromGallery(
return applyModel(model) return applyModel(model)
} }
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) { func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
basePath := systemState.Model.ModelsPath basePath := systemState.Model.ModelsPath
// Create base path if it doesn't exist // Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750) err := os.MkdirAll(basePath, 0750)
@@ -170,13 +168,6 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
// Download files and verify their SHA // Download files and verify their SHA
for i, file := range config.Files { for i, file := range config.Files {
// Check for cancellation before each file
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, basePath); err != nil { if err := utils.VerifyPath(file.Filename, basePath); err != nil {
@@ -194,7 +185,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
} }
} }
uri := downloader.URI(file.URI) uri := downloader.URI(file.URI)
if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil { if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
return nil, err return nil, err
} }
} }

View File

@@ -1,7 +1,6 @@
package gallery_test package gallery_test
import ( import (
"context"
"errors" "errors"
"os" "os"
"path/filepath" "path/filepath"
@@ -35,7 +34,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir), system.WithModelPath(tempdir),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) _, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -89,7 +88,7 @@ var _ = Describe("Model test", func() {
Expect(models[0].URL).To(Equal(bertEmbeddingsURL)) Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
Expect(models[0].Installed).To(BeFalse()) Expect(models[0].Installed).To(BeFalse())
err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true) err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
@@ -130,7 +129,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir), system.WithModelPath(tempdir),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) _, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -150,7 +149,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir), system.WithModelPath(tempdir),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true) _, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -180,7 +179,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir), system.WithModelPath(tempdir),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) _, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
}) })
}) })

View File

@@ -4,23 +4,30 @@ import (
"embed" "embed"
"errors" "errors"
"fmt" "fmt"
"io/fs"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"github.com/labstack/echo/v4" "github.com/dave-gray101/v2keyauth"
"github.com/labstack/echo/v4/middleware" "github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/endpoints/localai"
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes" "github.com/mudler/LocalAI/core/http/routes"
"github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/csrf"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"
// swagger handler
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -42,85 +49,86 @@ var embedDirStatic embed.FS
// @in header // @in header
// @name Authorization // @name Authorization
func API(application *application.Application) (*echo.Echo, error) { func API(application *application.Application) (*fiber.App, error) {
e := echo.New()
// Set body limit fiberCfg := fiber.Config{
if application.ApplicationConfig().UploadLimitMB > 0 { Views: renderEngine(),
e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB))) BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: true,
// Override default error handler
} }
// Set error handler
if !application.ApplicationConfig().OpaqueErrors { if !application.ApplicationConfig().OpaqueErrors {
e.HTTPErrorHandler = func(err error, c echo.Context) { // Normally, return errors as JSON responses
code := http.StatusInternalServerError fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
var he *echo.HTTPError // Status code defaults to 500
if errors.As(err, &he) { code := fiber.StatusInternalServerError
code = he.Code
}
// Handle 404 errors with HTML rendering when appropriate // Retrieve the custom status code if it's a *fiber.Error
if code == http.StatusNotFound { var e *fiber.Error
notFoundHandler(c) if errors.As(err, &e) {
return code = e.Code
} }
// Send custom error page // Send custom error page
c.JSON(code, schema.ErrorResponse{ return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code}, Error: &schema.APIError{Message: err.Error(), Code: code},
}) },
)
} }
} else { } else {
e.HTTPErrorHandler = func(err error, c echo.Context) { // If OpaqueErrors are required, replace everything with a blank 500.
code := http.StatusInternalServerError fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error {
var he *echo.HTTPError return ctx.Status(500).SendString("")
if errors.As(err, &he) {
code = he.Code
}
c.NoContent(code)
} }
} }
// Set renderer router := fiber.New(fiberCfg)
e.Renderer = renderEngine()
// Hide banner router.Use(middleware.StripPathPrefix())
e.HideBanner = true
// Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
e.Pre(httpMiddleware.StripPathPrefix())
if application.ApplicationConfig().MachineTag != "" { if application.ApplicationConfig().MachineTag != "" {
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { router.Use(func(c *fiber.Ctx) error {
return func(c echo.Context) error { c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
return next(c) return c.Next()
}
}) })
} }
// Custom logger middleware using zerolog router.Use("/v1/realtime", func(c *fiber.Ctx) error {
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { if websocket.IsWebSocketUpgrade(c) {
return func(c echo.Context) error { // Returns true if the client requested upgrade to the WebSocket protocol
req := c.Request() return c.Next()
res := c.Response()
start := log.Logger.Info()
err := next(c)
start.
Str("method", req.Method).
Str("path", req.URL.Path).
Int("status", res.Status).
Msg("HTTP request")
return err
} }
return nil
}) })
// Recover middleware router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
scheme = "https"
}
log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
return nil
})
// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
logger := log.Logger
router.Use(fiberzerolog.New(fiberzerolog.Config{
Logger: &logger,
}))
// Default middleware config
if !application.ApplicationConfig().Debug { if !application.ApplicationConfig().Debug {
e.Use(middleware.Recover()) router.Use(recover.New())
} }
// Metrics middleware // OpenTelemetry metrics for Prometheus export
if !application.ApplicationConfig().DisableMetrics { if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService() metricsService, err := services.NewLocalAIMetricsService()
if err != nil { if err != nil {
@@ -128,40 +136,35 @@ func API(application *application.Application) (*echo.Echo, error) {
} }
if metricsService != nil { if metricsService != nil {
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
e.Server.RegisterOnShutdown(func() { router.Hooks().OnShutdown(func() error {
metricsService.Shutdown() return metricsService.Shutdown()
}) })
} }
} }
// Health Checks should always be exempt from auth, so register these first // Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(e) routes.HealthRoutes(router)
// Get key auth middleware kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig()) if err != nil || kaConfig == nil {
if err != nil {
return nil, fmt.Errorf("failed to create key auth config: %w", err) return nil, fmt.Errorf("failed to create key auth config: %w", err)
} }
// Favicon handler httpFS := http.FS(embedDirStatic)
e.GET("/favicon.svg", func(c echo.Context) error {
data, err := embedDirStatic.ReadFile("static/favicon.svg")
if err != nil {
return c.NoContent(http.StatusNotFound)
}
c.Response().Header().Set("Content-Type", "image/svg+xml")
return c.Blob(http.StatusOK, "image/svg+xml", data)
})
// Static files - use fs.Sub to create a filesystem rooted at "static" router.Use(favicon.New(favicon.Config{
staticFS, err := fs.Sub(embedDirStatic, "static") URL: "/favicon.svg",
if err != nil { FileSystem: httpFS,
return nil, fmt.Errorf("failed to create static filesystem: %w", err) File: "static/favicon.svg",
} }))
e.StaticFS("/static", staticFS)
router.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Generated content directories
if application.ApplicationConfig().GeneratedContentDir != "" { if application.ApplicationConfig().GeneratedContentDir != "" {
os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750) os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio") audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
@@ -172,53 +175,62 @@ func API(application *application.Application) (*echo.Echo, error) {
os.MkdirAll(imagePath, 0750) os.MkdirAll(imagePath, 0750)
os.MkdirAll(videoPath, 0750) os.MkdirAll(videoPath, 0750)
e.Static("/generated-audio", audioPath) router.Static("/generated-audio", audioPath)
e.Static("/generated-images", imagePath) router.Static("/generated-images", imagePath)
e.Static("/generated-videos", videoPath) router.Static("/generated-videos", videoPath)
} }
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration // Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
e.Use(keyAuthMiddleware) router.Use(v2keyauth.New(*kaConfig))
// CORS middleware
if application.ApplicationConfig().CORS { if application.ApplicationConfig().CORS {
corsConfig := middleware.CORSConfig{} var c func(ctx *fiber.Ctx) error
if application.ApplicationConfig().CORSAllowOrigins != "" { if application.ApplicationConfig().CORSAllowOrigins == "" {
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",") c = cors.New()
} } else {
e.Use(middleware.CORSWithConfig(corsConfig)) c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
}
router.Use(c)
} }
// CSRF middleware
if application.ApplicationConfig().CSRF { if application.ApplicationConfig().CSRF {
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests") log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
e.Use(middleware.CSRF()) router.Use(csrf.New())
} }
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
var opcache *services.OpCache
if !application.ApplicationConfig().DisableWebUI { if !application.ApplicationConfig().DisableWebUI {
opcache = services.NewOpCache(application.GalleryService())
}
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator()) // Create metrics store for tracking usage (before API routes registration)
routes.RegisterOpenAIRoutes(e, requestExtractor, application) metricsStore := services.NewInMemoryMetricsStore()
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
}
routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route // Add metrics middleware BEFORE API routes so it can intercept them
router.Use(middleware.MetricsMiddleware(metricsStore))
// Log startup message // Register cleanup on shutdown
e.Server.RegisterOnShutdown(func() { router.Hooks().OnShutdown(func() error {
log.Info().Msg("LocalAI API server shutting down") metricsStore.Stop()
log.Info().Msg("Metrics store stopped")
return nil
}) })
return e, nil // Create opcache for tracking UI operations
opcache := services.NewOpCache(application.GalleryService())
routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, metricsStore)
routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
}
routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Define a custom 404 handler
// Note: keep this at the bottom!
router.Use(notFoundHandler)
return router, nil
} }

View File

@@ -10,14 +10,13 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"time"
"github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http" . "github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/system"
@@ -26,7 +25,6 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
openaigo "github.com/otiai10/openaigo" openaigo "github.com/otiai10/openaigo"
"github.com/rs/zerolog/log"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema" "github.com/sashabaranov/go-openai/jsonschema"
) )
@@ -87,7 +85,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{} response := []gallery.GalleryModel{}
uri := downloader.URI(url) uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix? // TODO: No tests currently seem to exercise file:// urls. Fix?
err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error { err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct // Unmarshal YAML data into a struct
return json.Unmarshal(i, &response) return json.Unmarshal(i, &response)
}) })
@@ -268,7 +266,7 @@ const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b8
var _ = Describe("API test", func() { var _ = Describe("API test", func() {
var app *echo.Echo var app *fiber.App
var client *openai.Client var client *openai.Client
var client2 *openaigo.Client var client2 *openaigo.Client
var c context.Context var c context.Context
@@ -341,11 +339,7 @@ var _ = Describe("API test", func() {
app, err = API(application) app, err = API(application)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go func() { go app.Listen("127.0.0.1:9090")
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig(apiKey) defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -364,9 +358,7 @@ var _ = Describe("API test", func() {
AfterEach(func(sc SpecContext) { AfterEach(func(sc SpecContext) {
cancel() cancel()
if app != nil { if app != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) err := app.Shutdown()
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
err := os.RemoveAll(tmpdir) err := os.RemoveAll(tmpdir)
@@ -555,11 +547,7 @@ var _ = Describe("API test", func() {
app, err = API(application) app, err = API(application)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go func() { go app.Listen("127.0.0.1:9090")
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -578,9 +566,7 @@ var _ = Describe("API test", func() {
AfterEach(func() { AfterEach(func() {
cancel() cancel()
if app != nil { if app != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) err := app.Shutdown()
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
err := os.RemoveAll(tmpdir) err := os.RemoveAll(tmpdir)
@@ -769,11 +755,7 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
app, err = API(application) app, err = API(application)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go func() { go app.Listen("127.0.0.1:9090")
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -791,9 +773,7 @@ var _ = Describe("API test", func() {
AfterEach(func() { AfterEach(func() {
cancel() cancel()
if app != nil { if app != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) err := app.Shutdown()
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
}) })
@@ -816,83 +796,6 @@ var _ = Describe("API test", func() {
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
}) })
It("returns logprobs in chat completions when requested", func() {
topLogprobsVal := 3
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
LogProbs: true,
TopLogProbs: topLogprobsVal,
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
// Verify logprobs are present and have correct structure
Expect(response.Choices[0].LogProbs).ToNot(BeNil())
Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())
Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))
foundatLeastToken := ""
foundAtLeastBytes := []byte{}
foundAtLeastTopLogprobBytes := []byte{}
foundatLeastTopLogprob := ""
// Verify logprobs content structure matches OpenAI format
for _, logprobContent := range response.Choices[0].LogProbs.Content {
// Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
if len(logprobContent.Bytes) > 0 {
foundAtLeastBytes = logprobContent.Bytes
}
if len(logprobContent.Token) > 0 {
foundatLeastToken = logprobContent.Token
}
Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))
// If top_logprobs is requested, verify top_logprobs array respects the limit
if len(logprobContent.TopLogProbs) > 0 {
// Should respect top_logprobs limit (3 in this test)
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
for _, topLogprob := range logprobContent.TopLogProbs {
if len(topLogprob.Bytes) > 0 {
foundAtLeastTopLogprobBytes = topLogprob.Bytes
}
if len(topLogprob.Token) > 0 {
foundatLeastTopLogprob = topLogprob.Token
}
Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
}
}
}
Expect(foundAtLeastBytes).ToNot(BeEmpty())
Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
Expect(foundatLeastToken).ToNot(BeEmpty())
Expect(foundatLeastTopLogprob).ToNot(BeEmpty())
})
It("applies logit_bias to chat completions when requested", func() {
// logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
// According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
logitBias := map[string]int{
"15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
}
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
LogitBias: logitBias,
})
Expect(err).ToNot(HaveOccurred())
Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
// If logit_bias is applied, the response should be generated successfully
// We can't easily verify the bias effect without knowing the actual token IDs for the model,
// but the fact that the request succeeds confirms the API accepts and processes logit_bias
})
It("returns errors", func() { It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt}) _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
@@ -1103,11 +1006,7 @@ var _ = Describe("API test", func() {
app, err = API(application) app, err = API(application)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go func() { go app.Listen("127.0.0.1:9090")
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -1123,9 +1022,7 @@ var _ = Describe("API test", func() {
AfterEach(func() { AfterEach(func() {
cancel() cancel()
if app != nil { if app != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) err := app.Shutdown()
defer cancel()
err := app.Shutdown(ctx)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
} }
}) })

View File

@@ -1,9 +1,7 @@
package elevenlabs package elevenlabs
import ( import (
"path/filepath" "github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
@@ -17,17 +15,17 @@ import (
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params" // @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
// @Success 200 {string} binary "Response" // @Success 200 {string} binary "Response"
// @Router /v1/sound-generation [post] // @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
if !ok || input.ModelID == "" { if !ok || input.ModelID == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil { if !ok || cfg == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend") log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
@@ -37,7 +35,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
if err != nil { if err != nil {
return err return err
} }
return c.Attachment(filePath, filepath.Base(filePath)) return c.Download(filePath)
} }
} }

View File

@@ -1,14 +1,13 @@
package elevenlabs package elevenlabs
import ( import (
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -18,19 +17,19 @@ import (
// @Param request body schema.TTSRequest true "query params" // @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response" // @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post] // @Router /v1/text-to-speech/{voice-id} [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
voiceID := c.Param("voice-id") voiceID := c.Params("voice-id")
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
if !ok || input.ModelID == "" { if !ok || input.ModelID == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil { if !ok || cfg == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received") log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received")
@@ -39,6 +38,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
if err != nil { if err != nil {
return err return err
} }
return c.Attachment(filePath, filepath.Base(filePath)) return c.Download(filePath)
} }
} }

View File

@@ -2,32 +2,28 @@ package explorer
import ( import (
"encoding/base64" "encoding/base64"
"net/http"
"sort" "sort"
"strings"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/explorer" "github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/internal"
) )
func Dashboard() echo.HandlerFunc { func Dashboard() func(*fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
summary := map[string]interface{}{ summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(), "Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(), "Version": internal.PrintableVersion(),
"BaseURL": middleware.BaseURL(c), "BaseURL": utils.BaseURL(c),
} }
contentType := c.Request().Header.Get("Content-Type") if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
accept := c.Request().Header.Get("Accept")
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) {
// The client expects a JSON response // The client expects a JSON response
return c.JSON(http.StatusOK, summary) return c.Status(fiber.StatusOK).JSON(summary)
} else { } else {
// Render index // Render index
return c.Render(http.StatusOK, "views/explorer", summary) return c.Render("views/explorer", summary)
} }
} }
} }
@@ -43,8 +39,8 @@ type Network struct {
Token string `json:"token"` Token string `json:"token"`
} }
func ShowNetworks(db *explorer.Database) echo.HandlerFunc { func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
results := []Network{} results := []Network{}
for _, token := range db.TokenList() { for _, token := range db.TokenList() {
networkData, exists := db.Get(token) // get the token data networkData, exists := db.Get(token) // get the token data
@@ -65,44 +61,44 @@ func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
return len(results[i].Clusters) > len(results[j].Clusters) return len(results[i].Clusters) > len(results[j].Clusters)
}) })
return c.JSON(http.StatusOK, results) return c.JSON(results)
} }
} }
func AddNetwork(db *explorer.Database) echo.HandlerFunc { func AddNetwork(db *explorer.Database) func(*fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
request := new(AddNetworkRequest) request := new(AddNetworkRequest)
if err := c.Bind(request); err != nil { if err := c.BodyParser(request); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
} }
if request.Token == "" { if request.Token == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token is required"})
} }
if request.Name == "" { if request.Name == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Name is required"})
} }
if request.Description == "" { if request.Description == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Description is required"})
} }
// TODO: check if token is valid, otherwise reject // TODO: check if token is valid, otherwise reject
// try to decode the token from base64 // try to decode the token from base64
_, err := base64.StdEncoding.DecodeString(request.Token) _, err := base64.StdEncoding.DecodeString(request.Token)
if err != nil { if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid token"})
} }
if _, exists := db.Get(request.Token); exists { if _, exists := db.Get(request.Token); exists {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"}) return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token already exists"})
} }
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description}) err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
if err != nil { if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Cannot add token"})
} }
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"}) return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Token added"})
} }
} }

View File

@@ -1,12 +1,11 @@
package jina package jina
import ( import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
@@ -18,36 +17,24 @@ import (
// @Param request body schema.JINARerankRequest true "query params" // @Param request body schema.JINARerankRequest true "query params"
// @Success 200 {object} schema.JINARerankResponse "Response" // @Success 200 {object} schema.JINARerankResponse "Response"
// @Router /v1/rerank [post] // @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil { if !ok || cfg == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received") log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")
var requestTopN int32
docs := int32(len(input.Documents))
if input.TopN == nil { // omit top_n to get all
requestTopN = docs
} else {
requestTopN = int32(*input.TopN)
if requestTopN < 1 {
return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
}
if requestTopN > docs { // make it more obvious for backends
requestTopN = docs
}
}
request := &proto.RerankRequest{ request := &proto.RerankRequest{
Query: input.Query, Query: input.Query,
TopN: requestTopN, TopN: int32(input.TopN),
Documents: input.Documents, Documents: input.Documents,
} }
@@ -71,6 +58,6 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
response.Usage.TotalTokens = int(results.Usage.TotalTokens) response.Usage.TotalTokens = int(results.Usage.TotalTokens)
response.Usage.PromptTokens = int(results.Usage.PromptTokens) response.Usage.PromptTokens = int(results.Usage.PromptTokens)
return c.JSON(http.StatusOK, response) return c.Status(fiber.StatusOK).JSON(response)
} }
} }

View File

@@ -4,11 +4,11 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/system"
@@ -39,13 +39,13 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
// @Summary Returns the job status // @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response" // @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /backends/jobs/{uuid} [get] // @Router /backends/jobs/{uuid} [get]
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
status := mgs.backendApplier.GetStatus(c.Param("uuid")) status := mgs.backendApplier.GetStatus(c.Params("uuid"))
if status == nil { if status == nil {
return fmt.Errorf("could not find any status for ID") return fmt.Errorf("could not find any status for ID")
} }
return c.JSON(200, status) return c.JSON(status)
} }
} }
@@ -53,9 +53,9 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
// @Summary Returns all the jobs status progress // @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response" // @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /backends/jobs [get] // @Router /backends/jobs [get]
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc { func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
return c.JSON(200, mgs.backendApplier.GetAllStatus()) return c.JSON(mgs.backendApplier.GetAllStatus())
} }
} }
@@ -64,11 +64,11 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
// @Param request body GalleryBackend true "query params" // @Param request body GalleryBackend true "query params"
// @Success 200 {object} schema.BackendResponse "Response" // @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/apply [post] // @Router /backends/apply [post]
func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc { func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input := new(GalleryBackend) input := new(GalleryBackend)
// Get input data from the request body // Get input data from the request body
if err := c.Bind(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
@@ -76,13 +76,13 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
if err != nil { if err != nil {
return err return err
} }
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{ mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend]{
ID: uuid.String(), ID: uuid.String(),
GalleryElementName: input.ID, GalleryElementName: input.ID,
Galleries: mgs.galleries, Galleries: mgs.galleries,
} }
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())}) return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
} }
} }
@@ -91,11 +91,11 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
// @Param name path string true "Backend name" // @Param name path string true "Backend name"
// @Success 200 {object} schema.BackendResponse "Response" // @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/delete/{name} [post] // @Router /backends/delete/{name} [post]
func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc { func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
backendName := c.Param("name") backendName := c.Params("name")
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{ mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend]{
Delete: true, Delete: true,
GalleryElementName: backendName, GalleryElementName: backendName,
Galleries: mgs.galleries, Galleries: mgs.galleries,
@@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
return err return err
} }
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())}) return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
} }
} }
@@ -114,13 +114,13 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
// @Summary List all Backends // @Summary List all Backends
// @Success 200 {object} []gallery.GalleryBackend "Response" // @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends [get] // @Router /backends [get]
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc { func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
backends, err := gallery.ListSystemBackends(systemState) backends, err := gallery.ListSystemBackends(systemState)
if err != nil { if err != nil {
return err return err
} }
return c.JSON(200, backends.GetAll()) return c.JSON(backends.GetAll())
} }
} }
@@ -129,14 +129,14 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.Syst
// @Success 200 {object} []config.Gallery "Response" // @Success 200 {object} []config.Gallery "Response"
// @Router /backends/galleries [get] // @Router /backends/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! // NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc { func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries) log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries) dat, err := json.Marshal(mgs.galleries)
if err != nil { if err != nil {
return err return err
} }
return c.Blob(200, "application/json", dat) return c.Send(dat)
} }
} }
@@ -144,12 +144,12 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFu
// @Summary List all available Backends // @Summary List all available Backends
// @Success 200 {object} []gallery.GalleryBackend "Response" // @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends/available [get] // @Router /backends/available [get]
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc { func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
backends, err := gallery.AvailableBackends(mgs.galleries, systemState) backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
if err != nil { if err != nil {
return err return err
} }
return c.JSON(200, backends) return c.JSON(backends)
} }
} }

View File

@@ -1,7 +1,7 @@
package localai package localai
import ( import (
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
) )
@@ -11,12 +11,12 @@ import (
// @Param request body schema.BackendMonitorRequest true "Backend statistics request" // @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response" // @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get] // @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest) input := new(schema.BackendMonitorRequest)
// Get input data from the request body // Get input data from the request body
if err := c.Bind(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
@@ -24,7 +24,7 @@ func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc
if err != nil { if err != nil {
return err return err
} }
return c.JSON(200, resp) return c.JSON(resp)
} }
} }
@@ -32,11 +32,11 @@ func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc
// @Summary Backend monitor endpoint // @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request" // @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post] // @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest) input := new(schema.BackendMonitorRequest)
// Get input data from the request body // Get input data from the request body
if err := c.Bind(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }

View File

@@ -1,7 +1,7 @@
package localai package localai
import ( import (
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
@@ -16,17 +16,17 @@ import (
// @Param request body schema.DetectionRequest true "query params" // @Param request body schema.DetectionRequest true "query params"
// @Success 200 {object} schema.DetectionResponse "Response" // @Success 200 {object} schema.DetectionResponse "Response"
// @Router /v1/detection [post] // @Router /v1/detection [post]
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil { if !ok || cfg == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection") log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
@@ -54,6 +54,6 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
} }
} }
return c.JSON(200, response) return c.JSON(response)
} }
} }

View File

@@ -2,13 +2,11 @@ package localai
import ( import (
"fmt" "fmt"
"io"
"net/http"
"os" "os"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
httpUtils "github.com/mudler/LocalAI/core/http/middleware" httpUtils "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/utils"
@@ -16,15 +14,15 @@ import (
) )
// GetEditModelPage renders the edit model page with current configuration // GetEditModelPage renders the edit model page with current configuration
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
modelName := c.Param("name") modelName := c.Params("name")
if modelName == "" { if modelName == "" {
response := ModelResponse{ response := ModelResponse{
Success: false, Success: false,
Error: "Model name is required", Error: "Model name is required",
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
modelConfig, exists := cl.GetModelConfig(modelName) modelConfig, exists := cl.GetModelConfig(modelName)
@@ -33,7 +31,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false, Success: false,
Error: "Model configuration not found", Error: "Model configuration not found",
} }
return c.JSON(http.StatusNotFound, response) return c.Status(404).JSON(response)
} }
modelConfigFile := modelConfig.GetModelConfigFile() modelConfigFile := modelConfig.GetModelConfigFile()
@@ -42,7 +40,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false, Success: false,
Error: "Model configuration file not found", Error: "Model configuration file not found",
} }
return c.JSON(http.StatusNotFound, response) return c.Status(404).JSON(response)
} }
configData, err := os.ReadFile(modelConfigFile) configData, err := os.ReadFile(modelConfigFile)
if err != nil { if err != nil {
@@ -50,7 +48,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false, Success: false,
Error: "Failed to read configuration file: " + err.Error(), Error: "Failed to read configuration file: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Render the edit page with the current configuration // Render the edit page with the current configuration
@@ -71,20 +69,20 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Version: internal.PrintableVersion(), Version: internal.PrintableVersion(),
} }
return c.Render(http.StatusOK, "views/model-editor", templateData) return c.Render("views/model-editor", templateData)
} }
} }
// EditModelEndpoint handles updating existing model configurations // EditModelEndpoint handles updating existing model configurations
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
modelName := c.Param("name") modelName := c.Params("name")
if modelName == "" { if modelName == "" {
response := ModelResponse{ response := ModelResponse{
Success: false, Success: false,
Error: "Model name is required", Error: "Model name is required",
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
modelConfig, exists := cl.GetModelConfig(modelName) modelConfig, exists := cl.GetModelConfig(modelName)
@@ -93,24 +91,17 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false, Success: false,
Error: "Existing model configuration not found", Error: "Existing model configuration not found",
} }
return c.JSON(http.StatusNotFound, response) return c.Status(404).JSON(response)
} }
// Get the raw body // Get the raw body
body, err := io.ReadAll(c.Request().Body) body := c.Body()
if err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to read request body: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
}
if len(body) == 0 { if len(body) == 0 {
response := ModelResponse{ response := ModelResponse{
Success: false, Success: false,
Error: "Request body is empty", Error: "Request body is empty",
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
// Check content to see if it's a valid model config // Check content to see if it's a valid model config
@@ -122,7 +113,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false, Success: false,
Error: "Failed to parse YAML: " + err.Error(), Error: "Failed to parse YAML: " + err.Error(),
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
// Validate required fields // Validate required fields
@@ -131,7 +122,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false, Success: false,
Error: "Name is required", Error: "Name is required",
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
// Validate the configuration // Validate the configuration
@@ -141,7 +132,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Error: "Validation failed", Error: "Validation failed",
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."}, Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
// Load the existing configuration // Load the existing configuration
@@ -151,7 +142,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false, Success: false,
Error: "Model configuration not trusted: " + err.Error(), Error: "Model configuration not trusted: " + err.Error(),
} }
return c.JSON(http.StatusNotFound, response) return c.Status(404).JSON(response)
} }
// Write new content to file // Write new content to file
@@ -160,16 +151,16 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false, Success: false,
Error: "Failed to write configuration file: " + err.Error(), Error: "Failed to write configuration file: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Reload configurations // Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil { if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{ response := ModelResponse{
Success: false, Success: false,
Error: "Failed to reload configurations: " + err.Error(), Error: "Failed to reload configurations: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Preload the model // Preload the model
@@ -178,7 +169,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false, Success: false,
Error: "Failed to preload model: " + err.Error(), Error: "Failed to preload model: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Return success response // Return success response
@@ -188,20 +179,20 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Filename: configPath, Filename: configPath,
Config: req, Config: req,
} }
return c.JSON(200, response) return c.JSON(response)
} }
} }
// ReloadModelsEndpoint handles reloading model configurations from disk // ReloadModelsEndpoint handles reloading model configurations from disk
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
// Reload configurations // Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{ response := ModelResponse{
Success: false, Success: false,
Error: "Failed to reload configurations: " + err.Error(), Error: "Failed to reload configurations: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Preload the models // Preload the models
@@ -210,7 +201,7 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
Success: false, Success: false,
Error: "Failed to preload models: " + err.Error(), Error: "Failed to preload models: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Return success response // Return success response
@@ -218,6 +209,6 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
Success: true, Success: true,
Message: "Model configurations reloaded successfully", Message: "Model configurations reloaded successfully",
} }
return c.JSON(http.StatusOK, response) return c.Status(fiber.StatusOK).JSON(response)
} }
} }

View File

@@ -2,14 +2,12 @@ package localai_test
import ( import (
"bytes" "bytes"
"encoding/json"
"io" "io"
"net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http/endpoints/localai" . "github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/system"
@@ -17,14 +15,6 @@ import (
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
) )
// testRenderer is a simple renderer for tests that returns JSON
type testRenderer struct{}
func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
// For tests, just return the data as JSON
return json.NewEncoder(w).Encode(data)
}
var _ = Describe("Edit Model test", func() { var _ = Describe("Edit Model test", func() {
var tempDir string var tempDir string
@@ -50,35 +40,33 @@ var _ = Describe("Edit Model test", func() {
//modelLoader := model.NewModelLoader(systemState, true) //modelLoader := model.NewModelLoader(systemState, true)
modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath) modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)
// Define Echo app and register all routes upfront // Define Fiber app.
app := echo.New() app := fiber.New()
// Set up a simple renderer for the test app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
app.Renderer = &testRenderer{}
app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig))
requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`) requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
req := httptest.NewRequest("POST", "/import-model", requestBody) req := httptest.NewRequest("PUT", "/import-model", requestBody)
req.Header.Set("Content-Type", "application/json") resp, err := app.Test(req, 5000)
rec := httptest.NewRecorder() Expect(err).ToNot(HaveOccurred())
app.ServeHTTP(rec, req)
body, err := io.ReadAll(rec.Body) body, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(ContainSubstring("Model configuration created successfully")) Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
Expect(rec.Code).To(Equal(http.StatusOK)) Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
req = httptest.NewRequest("GET", "/edit-model/foo", nil) app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
rec = httptest.NewRecorder() requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
app.ServeHTTP(rec, req)
body, err = io.ReadAll(rec.Body) req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
resp, _ = app.Test(req, 1)
body, err = io.ReadAll(resp.Body)
defer resp.Body.Close()
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
// The response contains the model configuration with backend field Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
Expect(string(body)).To(ContainSubstring(`"backend":"foo"`)) Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
Expect(string(body)).To(ContainSubstring(`"name":"foo"`))
Expect(rec.Code).To(Equal(http.StatusOK))
}) })
}) })
}) })

View File

@@ -4,11 +4,11 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/system"
@@ -40,13 +40,13 @@ func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGaller
// @Summary Returns the job status // @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response" // @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /models/jobs/{uuid} [get] // @Router /models/jobs/{uuid} [get]
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.GetStatus(c.Param("uuid")) status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
if status == nil { if status == nil {
return fmt.Errorf("could not find any status for ID") return fmt.Errorf("could not find any status for ID")
} }
return c.JSON(200, status) return c.JSON(status)
} }
} }
@@ -54,9 +54,9 @@ func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
// @Summary Returns all the jobs status progress // @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response" // @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /models/jobs [get] // @Router /models/jobs [get]
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc { func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
return c.JSON(200, mgs.galleryApplier.GetAllStatus()) return c.JSON(mgs.galleryApplier.GetAllStatus())
} }
} }
@@ -65,11 +65,11 @@ func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc
// @Param request body GalleryModel true "query params" // @Param request body GalleryModel true "query params"
// @Success 200 {object} schema.GalleryResponse "Response" // @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/apply [post] // @Router /models/apply [post]
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc { func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input := new(GalleryModel) input := new(GalleryModel)
// Get input data from the request body // Get input data from the request body
if err := c.Bind(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
@@ -77,7 +77,7 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.Handler
if err != nil { if err != nil {
return err return err
} }
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel]{
Req: input.GalleryModel, Req: input.GalleryModel,
ID: uuid.String(), ID: uuid.String(),
GalleryElementName: input.ID, GalleryElementName: input.ID,
@@ -85,7 +85,7 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.Handler
BackendGalleries: mgs.backendGalleries, BackendGalleries: mgs.backendGalleries,
} }
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())}) return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
} }
} }
@@ -94,11 +94,11 @@ func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.Handler
// @Param name path string true "Model name" // @Param name path string true "Model name"
// @Success 200 {object} schema.GalleryResponse "Response" // @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/delete/{name} [post] // @Router /models/delete/{name} [post]
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc { func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
modelName := c.Param("name") modelName := c.Params("name")
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel]{
Delete: true, Delete: true,
GalleryElementName: modelName, GalleryElementName: modelName,
} }
@@ -108,7 +108,7 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.Handle
return err return err
} }
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())}) return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
} }
} }
@@ -116,8 +116,8 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.Handle
// @Summary List installable models. // @Summary List installable models.
// @Success 200 {object} []gallery.GalleryModel "Response" // @Success 200 {object} []gallery.GalleryModel "Response"
// @Router /models/available [get] // @Router /models/available [get]
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc { func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState) models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
if err != nil { if err != nil {
@@ -139,7 +139,7 @@ func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState
if err != nil { if err != nil {
return fmt.Errorf("could not marshal models: %w", err) return fmt.Errorf("could not marshal models: %w", err)
} }
return c.Blob(200, "application/json", dat) return c.Send(dat)
} }
} }
@@ -148,13 +148,13 @@ func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState
// @Success 200 {object} []config.Gallery "Response" // @Success 200 {object} []config.Gallery "Response"
// @Router /models/galleries [get] // @Router /models/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! // NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc { func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries) log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries) dat, err := json.Marshal(mgs.galleries)
if err != nil { if err != nil {
return err return err
} }
return c.Blob(200, "application/json", dat) return c.Send(dat)
} }
} }

View File

@@ -1,7 +1,7 @@
package localai package localai
import ( import (
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
@@ -21,17 +21,17 @@ import (
// @Success 200 {string} binary "generated audio/wav file" // @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/tokenMetrics [get] // @Router /v1/tokenMetrics [get]
// @Router /tokenMetrics [get] // @Router /tokenMetrics [get]
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input := new(schema.TokenMetricsRequest) input := new(schema.TokenMetricsRequest)
// Get input data from the request body // Get input data from the request body
if err := c.Bind(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string) modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || modelFile != "" { if !ok || modelFile != "" {
modelFile = input.Model modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model) log.Warn().Msgf("Model not found in context: %s", input.Model)
@@ -52,6 +52,6 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
if err != nil { if err != nil {
return err return err
} }
return c.JSON(200, response) return c.JSON(response)
} }
} }

View File

@@ -2,97 +2,33 @@ package localai
import ( import (
"encoding/json" "encoding/json"
"fmt"
"io"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/google/uuid" "github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/gallery/importers"
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/utils"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
// ImportModelURIEndpoint handles creating new model configurations from a URI
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.ImportModelRequest)
if err := c.Bind(input); err != nil {
return err
}
modelConfig, err := importers.DiscoverModelConfig(input.URI, input.Preferences)
if err != nil {
return fmt.Errorf("failed to discover model config: %w", err)
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
// Determine gallery ID for tracking - use model name if available, otherwise use URI
galleryID := input.URI
if modelConfig.Name != "" {
galleryID = modelConfig.Name
}
// Register operation in opcache if available (for UI progress tracking)
if opcache != nil {
opcache.Set(galleryID, uuid.String())
}
galleryService.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
Req: gallery.GalleryModel{
Overrides: map[string]interface{}{},
},
ID: uuid.String(),
GalleryElementName: galleryID,
GalleryElement: &modelConfig,
BackendGalleries: appConfig.BackendGalleries,
}
return c.JSON(200, schema.GalleryResponse{
ID: uuid.String(),
StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()),
})
}
}
// ImportModelEndpoint handles creating new model configurations // ImportModelEndpoint handles creating new model configurations
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
// Get the raw body // Get the raw body
body, err := io.ReadAll(c.Request().Body) body := c.Body()
if err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to read request body: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
}
if len(body) == 0 { if len(body) == 0 {
response := ModelResponse{ response := ModelResponse{
Success: false, Success: false,
Error: "Request body is empty", Error: "Request body is empty",
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
// Check content type to determine how to parse // Check content type to determine how to parse
contentType := c.Request().Header.Get("Content-Type") contentType := string(c.Context().Request.Header.ContentType())
var modelConfig config.ModelConfig var modelConfig config.ModelConfig
var err error
if strings.Contains(contentType, "application/json") { if strings.Contains(contentType, "application/json") {
// Parse JSON // Parse JSON
@@ -101,7 +37,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Failed to parse JSON: " + err.Error(), Error: "Failed to parse JSON: " + err.Error(),
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") { } else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
// Parse YAML // Parse YAML
@@ -110,18 +46,18 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Failed to parse YAML: " + err.Error(), Error: "Failed to parse YAML: " + err.Error(),
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
} else { } else {
// Try to auto-detect format // Try to auto-detect format
if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' { if strings.TrimSpace(string(body))[0] == '{' {
// Looks like JSON // Looks like JSON
if err := json.Unmarshal(body, &modelConfig); err != nil { if err := json.Unmarshal(body, &modelConfig); err != nil {
response := ModelResponse{ response := ModelResponse{
Success: false, Success: false,
Error: "Failed to parse JSON: " + err.Error(), Error: "Failed to parse JSON: " + err.Error(),
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
} else { } else {
// Assume YAML // Assume YAML
@@ -130,7 +66,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Failed to parse YAML: " + err.Error(), Error: "Failed to parse YAML: " + err.Error(),
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
} }
} }
@@ -141,7 +77,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Name is required", Error: "Name is required",
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
// Set defaults // Set defaults
@@ -153,7 +89,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Invalid configuration", Error: "Invalid configuration",
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
// Create the configuration file // Create the configuration file
@@ -163,7 +99,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Model path not trusted: " + err.Error(), Error: "Model path not trusted: " + err.Error(),
} }
return c.JSON(http.StatusBadRequest, response) return c.Status(400).JSON(response)
} }
// Marshal to YAML for storage // Marshal to YAML for storage
@@ -173,7 +109,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Failed to marshal configuration: " + err.Error(), Error: "Failed to marshal configuration: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Write the file // Write the file
@@ -182,7 +118,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Failed to write configuration file: " + err.Error(), Error: "Failed to write configuration file: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Reload configurations // Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil { if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
@@ -190,7 +126,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Failed to reload configurations: " + err.Error(), Error: "Failed to reload configurations: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Preload the model // Preload the model
@@ -199,7 +135,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false, Success: false,
Error: "Failed to preload model: " + err.Error(), Error: "Failed to preload model: " + err.Error(),
} }
return c.JSON(http.StatusInternalServerError, response) return c.Status(500).JSON(response)
} }
// Return success response // Return success response
response := ModelResponse{ response := ModelResponse{
@@ -207,6 +143,6 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Message: "Model configuration created successfully", Message: "Model configuration created successfully",
Filename: filepath.Base(configPath), Filename: filepath.Base(configPath),
} }
return c.JSON(200, response) return c.JSON(response)
} }
} }

View File

@@ -1,323 +0,0 @@
package localai
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/cogito"
"github.com/rs/zerolog/log"
)
// MCP SSE Event Types
type MCPReasoningEvent struct {
Type string `json:"type"`
Content string `json:"content"`
}
type MCPToolCallEvent struct {
Type string `json:"type"`
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
Reasoning string `json:"reasoning"`
}
type MCPToolResultEvent struct {
Type string `json:"type"`
Name string `json:"name"`
Result string `json:"result"`
}
type MCPStatusEvent struct {
Type string `json:"type"`
Message string `json:"message"`
}
type MCPAssistantEvent struct {
Type string `json:"type"`
Content string `json:"content"`
}
type MCPErrorEvent struct {
Type string `json:"type"`
Message string `json:"message"`
}
// MCPStreamEndpoint is the SSE streaming endpoint for MCP chat completions
// @Summary Stream MCP chat completions with reasoning, tool calls, and results
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/mcp/chat/completions [post]
func MCPStreamEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
created := int(time.Now().Unix())
// Handle Correlation
id := c.Request().Header.Get("X-Correlation-ID")
if id == "" {
id = fmt.Sprintf("mcp-%d", time.Now().UnixNano())
}
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return echo.ErrBadRequest
}
if config.MCP.Servers == "" && config.MCP.Stdio == "" {
return fmt.Errorf("no MCP servers configured")
}
// Get MCP config from model config
remote, stdio, err := config.MCP.MCPConfigFromYAML()
if err != nil {
return fmt.Errorf("failed to get MCP config: %w", err)
}
// Check if we have tools in cache, or we have to have an initial connection
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
if err != nil {
return fmt.Errorf("failed to get MCP sessions: %w", err)
}
if len(sessions) == 0 {
return fmt.Errorf("no working MCP servers found")
}
// Build fragment from messages
fragment := cogito.NewEmptyFragment()
for _, message := range input.Messages {
fragment = fragment.AddMessage(message.Role, message.StringContent)
}
port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:]
apiKey := ""
if len(appConfig.ApiKeys) > 0 {
apiKey = appConfig.ApiKeys[0]
}
ctxWithCancellation, cancel := context.WithCancel(ctx)
defer cancel()
// TODO: instead of connecting to the API, we should just wire this internally
// and act like completion.go.
// We can do this as cogito expects an interface and we can create one that
// we satisfy to just call internally ComputeChoices
defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
// Build cogito options using the consolidated method
cogitoOpts := config.BuildCogitoOptions()
cogitoOpts = append(
cogitoOpts,
cogito.WithContext(ctxWithCancellation),
cogito.WithMCPs(sessions...),
)
// Check if streaming is requested
toStream := input.Stream
if !toStream {
// Non-streaming mode: execute synchronously and return JSON response
cogitoOpts = append(
cogitoOpts,
cogito.WithStatusCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
}),
cogito.WithReasoningCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
}),
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("reasoning", t.Reasoning).Interface("arguments", t.Arguments).Msg("[model agent] Tool call")
return true
}),
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("result", t.Result).Interface("tool_arguments", t.ToolArguments).Msg("[model agent] Tool call result")
}),
)
f, err := cogito.ExecuteTools(
defaultLLM, fragment,
cogitoOpts...,
)
if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
return err
}
f, err = defaultLLM.Ask(ctxWithCancellation, f)
if err != nil {
return err
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Message: &schema.Message{Role: "assistant", Content: &f.LastMessage().Content}}},
Object: "chat.completion",
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(200, resp)
}
// Streaming mode: use SSE
// Set up SSE headers
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().Header().Set("X-Correlation-ID", id)
// Create channel for streaming events
events := make(chan interface{})
ended := make(chan error, 1)
// Set up callbacks for streaming
statusCallback := func(s string) {
events <- MCPStatusEvent{
Type: "status",
Message: s,
}
}
reasoningCallback := func(s string) {
events <- MCPReasoningEvent{
Type: "reasoning",
Content: s,
}
}
toolCallCallback := func(t *cogito.ToolChoice) bool {
events <- MCPToolCallEvent{
Type: "tool_call",
Name: t.Name,
Arguments: t.Arguments,
Reasoning: t.Reasoning,
}
return true
}
toolCallResultCallback := func(t cogito.ToolStatus) {
events <- MCPToolResultEvent{
Type: "tool_result",
Name: t.Name,
Result: t.Result,
}
}
cogitoOpts = append(cogitoOpts,
cogito.WithStatusCallback(statusCallback),
cogito.WithReasoningCallback(reasoningCallback),
cogito.WithToolCallBack(toolCallCallback),
cogito.WithToolCallResultCallback(toolCallResultCallback),
)
// Execute tools in a goroutine
go func() {
defer close(events)
f, err := cogito.ExecuteTools(
defaultLLM, fragment,
cogitoOpts...,
)
if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
events <- MCPErrorEvent{
Type: "error",
Message: fmt.Sprintf("Failed to execute tools: %v", err),
}
ended <- err
return
}
// Get final response
f, err = defaultLLM.Ask(ctxWithCancellation, f)
if err != nil {
events <- MCPErrorEvent{
Type: "error",
Message: fmt.Sprintf("Failed to get response: %v", err),
}
ended <- err
return
}
// Stream final assistant response
content := f.LastMessage().Content
events <- MCPAssistantEvent{
Type: "assistant",
Content: content,
}
ended <- nil
}()
// Stream events to client
LOOP:
for {
select {
case <-ctx.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
cancel()
break LOOP
case event := <-events:
if event == nil {
// Channel closed
break LOOP
}
eventData, err := json.Marshal(event)
if err != nil {
log.Debug().Msgf("Failed to marshal event: %v", err)
continue
}
log.Debug().Msgf("Sending event: %s", string(eventData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(eventData))
if err != nil {
log.Debug().Msgf("Sending event failed: %v", err)
cancel()
return err
}
c.Response().Flush()
case err := <-ended:
if err == nil {
// Send done signal
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
errorEvent := MCPErrorEvent{
Type: "error",
Message: err.Error(),
}
errorData, marshalErr := json.Marshal(errorEvent)
if marshalErr != nil {
fmt.Fprintf(c.Response().Writer, "data: {\"type\":\"error\",\"message\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
}
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
}
}
log.Debug().Msgf("Stream ended")
return nil
}
}

View File

@@ -3,7 +3,8 @@ package localai
import ( import (
"time" "time"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
) )
@@ -12,36 +13,34 @@ import (
// @Summary Prometheus metrics endpoint // @Summary Prometheus metrics endpoint
// @Param request body config.Gallery true "Gallery details" // @Param request body config.Gallery true "Gallery details"
// @Router /metrics [get] // @Router /metrics [get]
func LocalAIMetricsEndpoint() echo.HandlerFunc { func LocalAIMetricsEndpoint() fiber.Handler {
return echo.WrapHandler(promhttp.Handler()) return adaptor.HTTPHandler(promhttp.Handler())
} }
type apiMiddlewareConfig struct { type apiMiddlewareConfig struct {
Filter func(c echo.Context) bool Filter func(c *fiber.Ctx) bool
metricsService *services.LocalAIMetricsService metricsService *services.LocalAIMetricsService
} }
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc { func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
cfg := apiMiddlewareConfig{ cfg := apiMiddlewareConfig{
metricsService: metrics, metricsService: metrics,
Filter: func(c echo.Context) bool { Filter: func(c *fiber.Ctx) bool {
return c.Path() == "/metrics" return c.Path() == "/metrics"
}, },
} }
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *fiber.Ctx) error {
return func(c echo.Context) error {
if cfg.Filter != nil && cfg.Filter(c) { if cfg.Filter != nil && cfg.Filter(c) {
return next(c) return c.Next()
} }
path := c.Path() path := c.Path()
method := c.Request().Method method := c.Method()
start := time.Now() start := time.Now()
err := next(c) err := c.Next()
elapsed := float64(time.Since(start)) / float64(time.Second) elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metricsService.ObserveAPICall(method, path, elapsed) cfg.metricsService.ObserveAPICall(method, path, elapsed)
return err return err
} }
}
} }

View File

@@ -1,7 +1,7 @@
package localai package localai
import ( import (
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
@@ -11,10 +11,10 @@ import (
// @Summary Returns available P2P nodes // @Summary Returns available P2P nodes
// @Success 200 {object} []schema.P2PNodesResponse "Response" // @Success 200 {object} []schema.P2PNodesResponse "Response"
// @Router /api/p2p [get] // @Router /api/p2p [get]
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc { func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
// Render index // Render index
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
return c.JSON(200, schema.P2PNodesResponse{ return c.JSON(schema.P2PNodesResponse{
Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)), Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)), FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
}) })
@@ -25,6 +25,6 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
// @Summary Show the P2P token // @Summary Show the P2P token
// @Success 200 {string} string "Response" // @Success 200 {string} string "Response"
// @Router /api/p2p/token [get] // @Router /api/p2p/token [get]
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc { func ShowP2PToken(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) } return func(c *fiber.Ctx) error { return c.Send([]byte(appConfig.P2PToken)) }
} }

View File

@@ -0,0 +1,61 @@
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
)
// SettingsEndpoint handles the settings page which shows detailed model/backend management
func SettingsEndpoint(appConfig *config.ApplicationConfig,
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelConfigs := cl.GetAllModelsConfigs()
galleryConfigs := map[string]*gallery.ModelConfig{}
installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState)
if err != nil {
return err
}
for _, m := range modelConfigs {
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
if err != nil {
continue
}
galleryConfigs[m.Name] = cfg
}
loadedModels := ml.ListLoadedModels()
loadedModelsMap := map[string]bool{}
for _, m := range loadedModels {
loadedModelsMap[m.ID] = true
}
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
// Get model statuses to display in the UI the operation in progress
processingModels, taskTypes := opcache.GetStatus()
summary := fiber.Map{
"Title": "LocalAI - Settings & Management",
"Version": internal.PrintableVersion(),
"BaseURL": utils.BaseURL(c),
"Models": modelsWithoutConfig,
"ModelsConfig": modelConfigs,
"GalleryConfig": galleryConfigs,
"ApplicationConfig": appConfig,
"ProcessingModels": processingModels,
"TaskTypes": taskTypes,
"LoadedModels": loadedModelsMap,
"InstalledBackends": installedBackends,
}
// Render settings page
return c.Render("views/settings", summary)
}
}

View File

@@ -1,7 +1,7 @@
package localai package localai
import ( import (
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
@@ -9,11 +9,11 @@ import (
"github.com/mudler/LocalAI/pkg/store" "github.com/mudler/LocalAI/pkg/store"
) )
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input := new(schema.StoresSet) input := new(schema.StoresSet)
if err := c.Bind(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
@@ -28,20 +28,20 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
vals[i] = []byte(v) vals[i] = []byte(v)
} }
err = store.SetCols(c.Request().Context(), sb, input.Keys, vals) err = store.SetCols(c.Context(), sb, input.Keys, vals)
if err != nil { if err != nil {
return err return err
} }
return c.NoContent(200) return c.Send(nil)
} }
} }
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input := new(schema.StoresDelete) input := new(schema.StoresDelete)
if err := c.Bind(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
@@ -51,19 +51,19 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
} }
defer sl.Close() defer sl.Close()
if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil { if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
return err return err
} }
return c.NoContent(200) return c.Send(nil)
} }
} }
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input := new(schema.StoresGet) input := new(schema.StoresGet)
if err := c.Bind(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
@@ -73,7 +73,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
} }
defer sl.Close() defer sl.Close()
keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys) keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
if err != nil { if err != nil {
return err return err
} }
@@ -87,15 +87,15 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
res.Values[i] = string(v) res.Values[i] = string(v)
} }
return c.JSON(200, res) return c.JSON(res)
} }
} }
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input := new(schema.StoresFind) input := new(schema.StoresFind)
if err := c.Bind(input); err != nil { if err := c.BodyParser(input); err != nil {
return err return err
} }
@@ -105,7 +105,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
} }
defer sl.Close() defer sl.Close()
keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk) keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
if err != nil { if err != nil {
return err return err
} }
@@ -120,6 +120,6 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
res.Values[i] = string(v) res.Values[i] = string(v)
} }
return c.JSON(200, res) return c.JSON(res)
} }
} }

View File

@@ -1,7 +1,7 @@
package localai package localai
import ( import (
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
@@ -11,8 +11,8 @@ import (
// @Summary Show the LocalAI instance information // @Summary Show the LocalAI instance information
// @Success 200 {object} schema.SystemInformationResponse "Response" // @Success 200 {object} schema.SystemInformationResponse "Response"
// @Router /system [get] // @Router /system [get]
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
availableBackends := []string{} availableBackends := []string{}
loadedModels := ml.ListLoadedModels() loadedModels := ml.ListLoadedModels()
for b := range appConfig.ExternalGRPCBackends { for b := range appConfig.ExternalGRPCBackends {
@@ -26,7 +26,7 @@ func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConf
for _, m := range loadedModels { for _, m := range loadedModels {
sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID}) sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID})
} }
return c.JSON(200, return c.JSON(
schema.SystemInformationResponse{ schema.SystemInformationResponse{
Backends: availableBackends, Backends: availableBackends,
Models: sysmodels, Models: sysmodels,

View File

@@ -1,7 +1,7 @@
package localai package localai
import ( import (
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
@@ -14,22 +14,22 @@ import (
// @Param request body schema.TokenizeRequest true "Request" // @Param request body schema.TokenizeRequest true "Request"
// @Success 200 {object} schema.TokenizeResponse "Response" // @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post] // @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(ctx *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest) input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil { if !ok || cfg == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig) tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil { if err != nil {
return err return err
} }
return c.JSON(200, tokenResponse) return ctx.JSON(tokenResponse)
} }
} }

View File

@@ -1,14 +1,12 @@
package localai package localai
import ( import (
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@@ -24,16 +22,16 @@ import (
// @Success 200 {string} binary "generated audio/wav file" // @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/audio/speech [post] // @Router /v1/audio/speech [post]
// @Router /tts [post] // @Router /tts [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil { if !ok || cfg == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request received") log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request received")
@@ -61,6 +59,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
return err return err
} }
return c.Attachment(filePath, filepath.Base(filePath)) return c.Download(filePath)
} }
} }

View File

@@ -1,7 +1,7 @@
package localai package localai
import ( import (
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
@@ -16,26 +16,26 @@ import (
// @Param request body schema.VADRequest true "query params" // @Param request body schema.VADRequest true "query params"
// @Success 200 {object} proto.VADResponse "Response" // @Success 200 {object} proto.VADResponse "Response"
// @Router /vad [post] // @Router /vad [post]
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil { if !ok || cfg == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request received") log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request received")
resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg) resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
if err != nil { if err != nil {
return err return err
} }
return c.JSON(200, resp) return c.JSON(resp)
} }
} }

View File

@@ -7,20 +7,19 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/gofiber/fiber/v2"
model "github.com/mudler/LocalAI/pkg/model" model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -65,18 +64,18 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.VideoRequest true "query params" // @Param request body schema.VideoRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response" // @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /video [post] // @Router /video [post]
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
log.Error().Msg("Video Endpoint - Invalid Input") log.Error().Msg("Video Endpoint - Invalid Input")
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil { if !ok || config == nil {
log.Error().Msg("Video Endpoint - Invalid Config") log.Error().Msg("Video Endpoint - Invalid Config")
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
src := "" src := ""
@@ -165,7 +164,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
return err return err
} }
baseURL := middleware.BaseURL(c) baseURL := c.BaseURL()
fn, err := backend.VideoGeneration( fn, err := backend.VideoGeneration(
height, height,
@@ -202,10 +201,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
item.B64JSON = base64.StdEncoding.EncodeToString(data) item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else { } else {
base := filepath.Base(output) base := filepath.Base(output)
item.URL, err = url.JoinPath(baseURL, "generated-videos", base) item.URL = baseURL + "/generated-videos/" + base
if err != nil {
return err
}
} }
id := uuid.New().String() id := uuid.New().String()
@@ -220,6 +216,6 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
log.Debug().Msgf("Response: %s", jsonResult) log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body // Return the prediction in the response body
return c.JSON(200, resp) return c.JSON(resp)
} }
} }

View File

@@ -1,20 +1,18 @@
package localai package localai
import ( import (
"strings" "github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
) )
func WelcomeEndpoint(appConfig *config.ApplicationConfig, func WelcomeEndpoint(appConfig *config.ApplicationConfig,
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc { cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
modelConfigs := cl.GetAllModelsConfigs() modelConfigs := cl.GetAllModelsConfigs()
galleryConfigs := map[string]*gallery.ModelConfig{} galleryConfigs := map[string]*gallery.ModelConfig{}
@@ -42,10 +40,10 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
// Get model statuses to display in the UI the operation in progress // Get model statuses to display in the UI the operation in progress
processingModels, taskTypes := opcache.GetStatus() processingModels, taskTypes := opcache.GetStatus()
summary := map[string]interface{}{ summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(), "Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(), "Version": internal.PrintableVersion(),
"BaseURL": middleware.BaseURL(c), "BaseURL": utils.BaseURL(c),
"Models": modelsWithoutConfig, "Models": modelsWithoutConfig,
"ModelsConfig": modelConfigs, "ModelsConfig": modelConfigs,
"GalleryConfig": galleryConfigs, "GalleryConfig": galleryConfigs,
@@ -56,21 +54,12 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
"InstalledBackends": installedBackends, "InstalledBackends": installedBackends,
} }
contentType := c.Request().Header.Get("Content-Type") if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
accept := c.Request().Header.Get("Accept")
// Default to HTML if Accept header is empty (browser behavior)
// Only return JSON if explicitly requested or Content-Type is application/json
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) {
// The client expects a JSON response // The client expects a JSON response
return c.JSON(200, summary) return c.Status(fiber.StatusOK).JSON(summary)
} else { } else {
// Check if this is the manage route // Render index
templateName := "views/index" return c.Render("views/index", summary)
if strings.HasSuffix(c.Request().URL.Path, "/manage") || c.Request().URL.Path == "/manage" {
templateName = "views/manage"
}
// Render appropriate template
return c.Render(200, templateName, summary)
} }
} }
} }

View File

@@ -1,12 +1,14 @@
package openai package openai
import ( import (
"bufio"
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"time" "time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
@@ -17,6 +19,7 @@ import (
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
) )
// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create // ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
@@ -24,7 +27,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params" // @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response" // @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/chat/completions [post] // @Router /v1/chat/completions [post]
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc { func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
var id, textContentToReturn string var id, textContentToReturn string
var created int var created int
@@ -33,7 +36,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ID: id, ID: id,
Created: created, Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}}, Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
} }
responses <- initialMessage responses <- initialMessage
@@ -53,7 +56,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ID: id, ID: id,
Created: created, Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0, FinishReason: nil}}, Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Usage: usage, Usage: usage,
} }
@@ -87,7 +90,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ID: id, ID: id,
Created: created, Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}}, Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
} }
responses <- initialMessage responses <- initialMessage
@@ -111,7 +114,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ID: id, ID: id,
Created: created, Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0, FinishReason: nil}}, Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Usage: usage, Usage: usage,
} }
@@ -139,10 +142,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}, },
}, },
}, },
}, }}},
Index: 0,
FinishReason: nil,
}},
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
} }
responses <- initialMessage responses <- initialMessage
@@ -165,10 +165,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}, },
}, },
}, },
}, }}},
Index: 0,
FinishReason: nil,
}},
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
} }
} }
@@ -178,21 +175,21 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
return err return err
} }
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
textContentToReturn = "" textContentToReturn = ""
id = uuid.New().String() id = uuid.New().String()
created = int(time.Now().Unix()) created = int(time.Now().Unix())
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
extraUsage := c.Request().Header.Get("Extra-Usage") != "" extraUsage := c.Get("Extra-Usage", "") != ""
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil { if !ok || config == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
log.Debug().Msgf("Chat endpoint configuration read: %+v", config) log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
@@ -220,7 +217,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
noActionDescription = config.FunctionsConfig.NoActionDescriptionName noActionDescription = config.FunctionsConfig.NoActionDescriptionName
} }
// If we are using a response format, we need to generate a grammar for it
if config.ResponseFormatMap != nil { if config.ResponseFormatMap != nil {
d := schema.ChatCompletionResponseFormat{} d := schema.ChatCompletionResponseFormat{}
dat, err := json.Marshal(config.ResponseFormatMap) dat, err := json.Marshal(config.ResponseFormatMap)
@@ -264,7 +260,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
} }
switch { switch {
// Generates grammar with internal's LocalAI engine
case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn: case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn:
noActionGrammar := functions.Function{ noActionGrammar := functions.Function{
Name: noActionName, Name: noActionName,
@@ -288,7 +283,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
funcs = funcs.Select(config.FunctionToCall()) funcs = funcs.Select(config.FunctionToCall())
} }
// Update input grammar or json_schema based on use_llama_grammar option // Update input grammar
jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey) jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...) g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
if err == nil { if err == nil {
@@ -303,7 +298,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
} else { } else {
log.Error().Err(err).Msg("Failed generating grammar") log.Error().Err(err).Msg("Failed generating grammar")
} }
default: default:
// Force picking one of the functions by the request // Force picking one of the functions by the request
if config.FunctionToCall() != "" { if config.FunctionToCall() != "" {
@@ -322,7 +316,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
// If we are using the tokenizer template, we don't need to process the messages // If we are using the tokenizer template, we don't need to process the messages
// unless we are processing functions // unless we are processing functions
if !config.TemplateConfig.UseTokenizerTemplate { if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn) predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn)
log.Debug().Msgf("Prompt (after templating): %s", predInput) log.Debug().Msgf("Prompt (after templating): %s", predInput)
@@ -335,10 +329,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
case toStream: case toStream:
log.Debug().Msgf("Stream request received") log.Debug().Msgf("Stream request received")
c.Response().Header().Set("Content-Type", "text/event-stream") c.Context().SetContentType("text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache") //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
c.Response().Header().Set("Connection", "keep-alive") // c.Set("Content-Type", "text/event-stream")
c.Response().Header().Set("X-Correlation-ID", id) c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
c.Set("X-Correlation-ID", id)
responses := make(chan schema.OpenAIResponse) responses := make(chan schema.OpenAIResponse)
ended := make(chan error, 1) ended := make(chan error, 1)
@@ -351,17 +348,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
} }
}() }()
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{} usage := &schema.OpenAIUsage{}
toolsCalled := false toolsCalled := false
LOOP: LOOP:
for { for {
select { select {
case <-input.Context.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
input.Cancel()
break LOOP
case ev := <-responses: case ev := <-responses:
if len(ev.Choices) == 0 { if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping") log.Debug().Msgf("No choices in the response, skipping")
@@ -371,60 +364,50 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
if len(ev.Choices[0].Delta.ToolCalls) > 0 { if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true toolsCalled = true
} }
respData, err := json.Marshal(ev) var buf bytes.Buffer
if err != nil { enc := json.NewEncoder(&buf)
log.Debug().Msgf("Failed to marshal response: %v", err) enc.Encode(ev)
input.Cancel() log.Debug().Msgf("Sending chunk: %s", buf.String())
continue _, err := fmt.Fprintf(w, "data: %v\n", buf.String())
}
log.Debug().Msgf("Sending chunk: %s", string(respData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
if err != nil { if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err) log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel() input.Cancel()
return err
} }
c.Response().Flush() w.Flush()
case err := <-ended: case err := <-ended:
if err == nil { if err == nil {
break LOOP break LOOP
} }
log.Error().Msgf("Stream ended with error: %v", err) log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{ resp := &schema.OpenAIResponse{
ID: id, ID: id,
Created: created, Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{ Choices: []schema.Choice{
{ {
FinishReason: &stopReason, FinishReason: "stop",
Index: 0, Index: 0,
Delta: &schema.Message{Content: "Internal error: " + err.Error()}, Delta: &schema.Message{Content: "Internal error: " + err.Error()},
}}, }},
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Usage: *usage, Usage: *usage,
} }
respData, marshalErr := json.Marshal(resp) respData, _ := json.Marshal(resp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
}
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
return
} }
} }
finishReason := FinishReasonStop finishReason := "stop"
if toolsCalled && len(input.Tools) > 0 { if toolsCalled && len(input.Tools) > 0 {
finishReason = FinishReasonToolCalls finishReason = "tool_calls"
} else if toolsCalled { } else if toolsCalled {
finishReason = FinishReasonFunctionCall finishReason = "function_call"
} }
resp := &schema.OpenAIResponse{ resp := &schema.OpenAIResponse{
@@ -433,19 +416,21 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{ Choices: []schema.Choice{
{ {
FinishReason: &finishReason, FinishReason: finishReason,
Index: 0, Index: 0,
Delta: &schema.Message{}, Delta: &schema.Message{Content: &textContentToReturn},
}}, }},
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Usage: *usage, Usage: *usage,
} }
respData, _ := json.Marshal(resp) respData, _ := json.Marshal(resp)
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") w.WriteString("data: [DONE]\n\n")
c.Response().Flush() w.Flush()
log.Debug().Msgf("Stream ended") log.Debug().Msgf("Stream ended")
}))
return nil return nil
// no streaming mode // no streaming mode
@@ -454,8 +439,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
tokenCallback := func(s string, c *[]schema.Choice) { tokenCallback := func(s string, c *[]schema.Choice) {
if !shouldUseFn { if !shouldUseFn {
// no function is called, just reply and use stop as finish reason // no function is called, just reply and use stop as finish reason
stopReason := FinishReasonStop *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
*c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return return
} }
@@ -473,14 +457,12 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
return return
} }
stopReason := FinishReasonStop
*c = append(*c, schema.Choice{ *c = append(*c, schema.Choice{
FinishReason: &stopReason, FinishReason: "stop",
Message: &schema.Message{Role: "assistant", Content: &result}}) Message: &schema.Message{Role: "assistant", Content: &result}})
default: default:
toolCallsReason := FinishReasonToolCalls
toolChoice := schema.Choice{ toolChoice := schema.Choice{
FinishReason: &toolCallsReason, FinishReason: "tool_calls",
Message: &schema.Message{ Message: &schema.Message{
Role: "assistant", Role: "assistant",
}, },
@@ -504,9 +486,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
) )
} else { } else {
// otherwise we return more choices directly (deprecated) // otherwise we return more choices directly (deprecated)
functionCallReason := FinishReasonFunctionCall
*c = append(*c, schema.Choice{ *c = append(*c, schema.Choice{
FinishReason: &functionCallReason, FinishReason: "function_call",
Message: &schema.Message{ Message: &schema.Message{
Role: "assistant", Role: "assistant",
Content: &textContentToReturn, Content: &textContentToReturn,
@@ -527,9 +508,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
} }
// Echo properly supports context cancellation via c.Request().Context()
// No workaround needed!
result, tokenUsage, err := ComputeChoices( result, tokenUsage, err := ComputeChoices(
input, input,
predInput, predInput,
@@ -565,7 +543,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
log.Debug().Msgf("Response: %s", respData) log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body // Return the prediction in the response body
return c.JSON(200, resp) return c.JSON(resp)
} }
} }
} }
@@ -619,48 +597,7 @@ func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, in
audios = append(audios, m.StringAudios...) audios = append(audios, m.StringAudios...)
} }
// Serialize tools and tool_choice to JSON strings predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil)
toolsJSON := ""
if len(input.Tools) > 0 {
toolsBytes, err := json.Marshal(input.Tools)
if err == nil {
toolsJSON = string(toolsBytes)
}
}
toolChoiceJSON := ""
if input.ToolsChoice != nil {
toolChoiceBytes, err := json.Marshal(input.ToolsChoice)
if err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
}
// Extract logprobs from request
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
var logprobs *int
var topLogprobs *int
if input.Logprobs.IsEnabled() {
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
if input.TopLogprobs != nil {
topLogprobs = input.TopLogprobs
// For backend compatibility, set logprobs to the top_logprobs value
logprobs = input.TopLogprobs
} else {
// Default to 1 if logprobs is true but top_logprobs not specified
val := 1
logprobs = &val
topLogprobs = &val
}
}
// Extract logit_bias from request
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
var logitBias map[string]float64
if len(input.LogitBias) > 0 {
logitBias = input.LogitBias
}
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
if err != nil { if err != nil {
log.Error().Err(err).Msg("model inference failed") log.Error().Err(err).Msg("model inference failed")
return "", err return "", err

View File

@@ -1,22 +1,25 @@
package openai package openai
import ( import (
"bufio"
"bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"time" "time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
) )
// CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions // CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
@@ -24,7 +27,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params" // @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response" // @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post] // @Router /v1/completions [post]
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error { process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool { tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
created := int(time.Now().Unix()) created := int(time.Now().Unix())
@@ -46,7 +49,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
{ {
Index: 0, Index: 0,
Text: s, Text: s,
FinishReason: nil,
}, },
}, },
Object: "text_completion", Object: "text_completion",
@@ -62,25 +64,22 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
return err return err
} }
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
created := int(time.Now().Unix()) created := int(time.Now().Unix())
// Handle Correlation // Handle Correlation
id := c.Request().Header.Get("X-Correlation-ID") id := c.Get("X-Correlation-ID", uuid.New().String())
if id == "" { extraUsage := c.Get("Extra-Usage", "") != ""
id = uuid.New().String()
}
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil { if !ok || config == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
if config.ResponseFormatMap != nil { if config.ResponseFormatMap != nil {
@@ -98,10 +97,15 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
if input.Stream { if input.Stream {
log.Debug().Msgf("Stream request received") log.Debug().Msgf("Stream request received")
c.Response().Header().Set("Content-Type", "text/event-stream") c.Context().SetContentType("text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache") //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
c.Response().Header().Set("Connection", "keep-alive") //c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
if input.Stream {
if len(config.PromptStrings) > 1 { if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
} }
@@ -126,6 +130,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
ended <- process(id, predInput, input, config, ml, responses, extraUsage) ended <- process(id, predInput, input, config, ml, responses, extraUsage)
}() }()
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
LOOP: LOOP:
for { for {
select { select {
@@ -134,52 +140,24 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
log.Debug().Msgf("No choices in the response, skipping") log.Debug().Msgf("No choices in the response, skipping")
continue continue
} }
respData, err := json.Marshal(ev) var buf bytes.Buffer
if err != nil { enc := json.NewEncoder(&buf)
log.Debug().Msgf("Failed to marshal response: %v", err) enc.Encode(ev)
continue
}
log.Debug().Msgf("Sending chunk: %s", string(respData)) log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil { w.Flush()
return err
}
c.Response().Flush()
case err := <-ended: case err := <-ended:
if err == nil { if err == nil {
break LOOP break LOOP
} }
log.Error().Msgf("Stream ended with error: %v", err) log.Error().Msgf("Stream ended with error: %v", err)
fmt.Fprintf(w, "data: %v\n", "Internal error: "+err.Error())
stopReason := FinishReasonStop w.Flush()
errorResp := schema.OpenAIResponse{ break LOOP
ID: id,
Created: created,
Model: input.Model,
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
Text: "Internal error: " + err.Error(),
},
},
Object: "text_completion",
}
errorData, marshalErr := json.Marshal(errorResp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
}
c.Response().Flush()
return nil
} }
} }
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{ resp := &schema.OpenAIResponse{
ID: id, ID: id,
Created: created, Created: created,
@@ -187,17 +165,18 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
Choices: []schema.Choice{ Choices: []schema.Choice{
{ {
Index: 0, Index: 0,
FinishReason: &stopReason, FinishReason: "stop",
}, },
}, },
Object: "text_completion", Object: "text_completion",
} }
respData, _ := json.Marshal(resp) respData, _ := json.Marshal(resp)
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") w.WriteString("data: [DONE]\n\n")
c.Response().Flush() w.Flush()
return nil }))
return <-ended
} }
var result []schema.Choice var result []schema.Choice
@@ -218,8 +197,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
r, tokenUsage, err := ComputeChoices( r, tokenUsage, err := ComputeChoices(
input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) { input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
stopReason := FinishReasonStop *c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
*c = append(*c, schema.Choice{Text: s, FinishReason: &stopReason, Index: k})
}, nil) }, nil)
if err != nil { if err != nil {
return err return err
@@ -253,6 +231,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
log.Debug().Msgf("Response: %s", jsonResult) log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body // Return the prediction in the response body
return c.JSON(200, resp) return c.JSON(resp)
} }
} }

View File

@@ -1,8 +0,0 @@
package openai
// Finish reason constants for OpenAI API responses
const (
FinishReasonStop = "stop"
FinishReasonToolCalls = "tool_calls"
FinishReasonFunctionCall = "function_call"
)

View File

@@ -4,11 +4,11 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
@@ -23,20 +23,20 @@ import (
// @Param request body schema.OpenAIRequest true "query params" // @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response" // @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/edits [post] // @Router /v1/edits [post]
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
// Opt-in extra usage flag // Opt-in extra usage flag
extraUsage := c.Request().Header.Get("Extra-Usage") != "" extraUsage := c.Get("Extra-Usage", "") != ""
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil { if !ok || config == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
log.Debug().Msgf("Edit Endpoint Input : %+v", input) log.Debug().Msgf("Edit Endpoint Input : %+v", input)
@@ -98,6 +98,6 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
log.Debug().Msgf("Response: %s", jsonResult) log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body // Return the prediction in the response body
return c.JSON(200, resp) return c.JSON(resp)
} }
} }

View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"time" "time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
@@ -13,6 +12,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -21,16 +21,16 @@ import (
// @Param request body schema.OpenAIRequest true "query params" // @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response" // @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/embeddings [post] // @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil { if !ok || config == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
log.Debug().Msgf("Parameter Config: %+v", config) log.Debug().Msgf("Parameter Config: %+v", config)
@@ -78,6 +78,6 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Response: %s", jsonResult) log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body // Return the prediction in the response body
return c.JSON(200, resp) return c.JSON(resp)
} }
} }

View File

@@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@@ -15,13 +14,13 @@ import (
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/gofiber/fiber/v2"
model "github.com/mudler/LocalAI/pkg/model" model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -66,18 +65,18 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.OpenAIRequest true "query params" // @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response" // @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/images/generations [post] // @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
log.Error().Msg("Image Endpoint - Invalid Input") log.Error().Msg("Image Endpoint - Invalid Input")
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil { if !ok || config == nil {
log.Error().Msg("Image Endpoint - Invalid Config") log.Error().Msg("Image Endpoint - Invalid Config")
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
// Process input images (for img2img/inpainting) // Process input images (for img2img/inpainting)
@@ -189,7 +188,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
return err return err
} }
baseURL := middleware.BaseURL(c) baseURL := c.BaseURL()
// Use the first input image as src if available, otherwise use the original src // Use the first input image as src if available, otherwise use the original src
inputSrc := src inputSrc := src
@@ -216,10 +215,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
item.B64JSON = base64.StdEncoding.EncodeToString(data) item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else { } else {
base := filepath.Base(output) base := filepath.Base(output)
item.URL, err = url.JoinPath(baseURL, "generated-images", base) item.URL = baseURL + "/generated-images/" + base
if err != nil {
return err
}
} }
result = append(result, *item) result = append(result, *item)
@@ -238,7 +234,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
log.Debug().Msgf("Response: %s", jsonResult) log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body // Return the prediction in the response body
return c.JSON(200, resp) return c.JSON(resp)
} }
} }

View File

@@ -1,8 +1,6 @@
package openai package openai
import ( import (
"encoding/json"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
@@ -39,50 +37,8 @@ func ComputeChoices(
audios = append(audios, m.StringAudios...) audios = append(audios, m.StringAudios...)
} }
// Serialize tools and tool_choice to JSON strings
toolsJSON := ""
if len(req.Tools) > 0 {
toolsBytes, err := json.Marshal(req.Tools)
if err == nil {
toolsJSON = string(toolsBytes)
}
}
toolChoiceJSON := ""
if req.ToolsChoice != nil {
toolChoiceBytes, err := json.Marshal(req.ToolsChoice)
if err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
}
// Extract logprobs from request
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
var logprobs *int
var topLogprobs *int
if req.Logprobs.IsEnabled() {
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
if req.TopLogprobs != nil {
topLogprobs = req.TopLogprobs
// For backend compatibility, set logprobs to the top_logprobs value
logprobs = req.TopLogprobs
} else {
// Default to 1 if logprobs is true but top_logprobs not specified
val := 1
logprobs = &val
topLogprobs = &val
}
}
// Extract logit_bias from request
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
var logitBias map[string]float64
if len(req.LogitBias) > 0 {
logitBias = req.LogitBias
}
// get the model function to call for the result // get the model function to call for the result
predFunc, err := backend.ModelInference( predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback)
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
if err != nil { if err != nil {
return result, backend.TokenUsage{}, err return result, backend.TokenUsage{}, err
} }
@@ -103,11 +59,6 @@ func ComputeChoices(
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result) cb(finetunedResponse, &result)
// Add logprobs to the last choice if present
if prediction.Logprobs != nil && len(result) > 0 {
result[len(result)-1].Logprobs = prediction.Logprobs
}
//result = append(result, Choice{Text: prediction}) //result = append(result, Choice{Text: prediction})
} }

View File

@@ -1,7 +1,7 @@
package openai package openai
import ( import (
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
@@ -12,15 +12,14 @@ import (
// @Summary List and describe the various models available in the API. // @Summary List and describe the various models available in the API.
// @Success 200 {object} schema.ModelsDataResponse "Response" // @Success 200 {object} schema.ModelsDataResponse "Response"
// @Router /v1/models [get] // @Router /v1/models [get]
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
// If blank, no filter is applied. // If blank, no filter is applied.
filter := c.QueryParam("filter") filter := c.Query("filter")
// By default, exclude any loose files that are already referenced by a configuration file. // By default, exclude any loose files that are already referenced by a configuration file.
var policy services.LooseFilePolicy var policy services.LooseFilePolicy
excludeConfigured := c.QueryParam("excludeConfigured") if c.QueryBool("excludeConfigured", true) {
if excludeConfigured == "" || excludeConfigured == "true" {
policy = services.SKIP_IF_CONFIGURED policy = services.SKIP_IF_CONFIGURED
} else { } else {
policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user? policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
@@ -42,7 +41,7 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
} }
return c.JSON(200, schema.ModelsDataResponse{ return c.JSON(schema.ModelsDataResponse{
Object: "list", Object: "list",
Data: dataModels, Data: dataModels,
}) })

View File

@@ -1,18 +1,17 @@
package openai package openai
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/core/templates"
@@ -26,27 +25,24 @@ import (
// @Param request body schema.OpenAIRequest true "query params" // @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response" // @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /mcp/v1/completions [post] // @Router /mcp/v1/completions [post]
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
// We do not support streaming mode (Yet?) // We do not support streaming mode (Yet?)
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
created := int(time.Now().Unix()) created := int(time.Now().Unix())
ctx := c.Request().Context() ctx := c.Context()
// Handle Correlation // Handle Correlation
id := c.Request().Header.Get("X-Correlation-ID") id := c.Get("X-Correlation-ID", uuid.New().String())
if id == "" {
id = uuid.New().String()
}
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil { if !ok || config == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
if config.MCP.Servers == "" && config.MCP.Stdio == "" { if config.MCP.Servers == "" && config.MCP.Stdio == "" {
@@ -54,15 +50,12 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
} }
// Get MCP config from model config // Get MCP config from model config
remote, stdio, err := config.MCP.MCPConfigFromYAML() remote, stdio := config.MCP.MCPConfigFromYAML()
if err != nil {
return fmt.Errorf("failed to get MCP config: %w", err)
}
// Check if we have tools in cache, or we have to have an initial connection // Check if we have tools in cache, or we have to have an initial connection
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio) sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
if err != nil { if err != nil {
return fmt.Errorf("failed to get MCP sessions: %w", err) return err
} }
if len(sessions) == 0 { if len(sessions) == 0 {
@@ -80,37 +73,46 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
if appConfig.ApiKeys != nil { if appConfig.ApiKeys != nil {
apiKey = appConfig.ApiKeys[0] apiKey = appConfig.ApiKeys[0]
} }
ctxWithCancellation, cancel := context.WithCancel(ctx)
defer cancel()
// TODO: instead of connecting to the API, we should just wire this internally // TODO: instead of connecting to the API, we should just wire this internally
// and act like completion.go. // and act like completion.go.
// We can do this as cogito expects an interface and we can create one that // We can do this as cogito expects an interface and we can create one that
// we satisfy to just call internally ComputeChoices // we satisfy to just call internally ComputeChoices
defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port) defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
// Build cogito options using the consolidated method cogitoOpts := []cogito.Option{
cogitoOpts := config.BuildCogitoOptions()
cogitoOpts = append(
cogitoOpts,
cogito.WithContext(ctxWithCancellation),
cogito.WithMCPs(sessions...),
cogito.WithStatusCallback(func(s string) { cogito.WithStatusCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s) log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
}), }),
cogito.WithReasoningCallback(func(s string) { cogito.WithContext(ctx),
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s) cogito.WithMCPs(sessions...),
}), cogito.WithIterations(3), // default to 3 iterations
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool { cogito.WithMaxAttempts(3), // default to 3 attempts
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments) cogito.WithForceReasoning(),
return true }
}),
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) { if config.Agent.EnableReasoning {
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments) cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
}), }
)
if config.Agent.EnablePlanning {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
}
if config.Agent.EnableMCPPrompts {
cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
}
if config.Agent.EnablePlanReEvaluator {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
}
if config.Agent.MaxIterations != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithIterations(config.Agent.MaxIterations))
}
if config.Agent.MaxAttempts != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(config.Agent.MaxAttempts))
}
f, err := cogito.ExecuteTools( f, err := cogito.ExecuteTools(
defaultLLM, fragment, defaultLLM, fragment,
@@ -137,6 +139,6 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
log.Debug().Msgf("Response: %s", jsonResult) log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body // Return the prediction in the response body
return c.JSON(200, resp) return c.JSON(resp)
} }
} }

View File

@@ -10,11 +10,9 @@ import (
"sync" "sync"
"time" "time"
"net/http"
"github.com/go-audio/audio" "github.com/go-audio/audio"
"github.com/gorilla/websocket" "github.com/gofiber/fiber/v2"
"github.com/labstack/echo/v4" "github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types" "github.com/mudler/LocalAI/core/http/endpoints/openai/types"
@@ -169,50 +167,32 @@ type Model interface {
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
} }
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins
},
}
// TODO: Implement ephemeral keys to allow these endpoints to be used // TODO: Implement ephemeral keys to allow these endpoints to be used
func RealtimeSessions(application *application.Application) echo.HandlerFunc { func RealtimeSessions(application *application.Application) fiber.Handler {
return func(c echo.Context) error { return func(ctx *fiber.Ctx) error {
return c.NoContent(501) return ctx.SendStatus(501)
} }
} }
func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc { func RealtimeTranscriptionSession(application *application.Application) fiber.Handler {
return func(c echo.Context) error { return func(ctx *fiber.Ctx) error {
return c.NoContent(501) return ctx.SendStatus(501)
} }
} }
func Realtime(application *application.Application) echo.HandlerFunc { func Realtime(application *application.Application) fiber.Handler {
return func(c echo.Context) error { return websocket.New(registerRealtime(application))
ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
if err != nil {
return err
}
defer ws.Close()
// Extract query parameters from Echo context before passing to websocket handler
model := c.QueryParam("model")
if model == "" {
model = "gpt-4o"
}
intent := c.QueryParam("intent")
registerRealtime(application, model, intent)(ws)
return nil
}
} }
func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) { func registerRealtime(application *application.Application) func(c *websocket.Conn) {
return func(c *websocket.Conn) { return func(c *websocket.Conn) {
evaluator := application.TemplatesEvaluator() evaluator := application.TemplatesEvaluator()
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String()) log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
model := c.Query("model", "gpt-4o")
intent := c.Query("intent")
if intent != "transcription" { if intent != "transcription" {
sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter") sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter")
} }
@@ -1087,13 +1067,12 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
// For example, the model might return a special token or JSON indicating a function call // For example, the model might return a special token or JSON indicating a function call
/* /*
predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil, "", "", nil, nil, nil) predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) { result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !shouldUseFn { if !shouldUseFn {
// no function is called, just reply and use stop as finish reason // no function is called, just reply and use stop as finish reason
stopReason := FinishReasonStop *c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
*c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return return
} }
@@ -1120,8 +1099,7 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
} }
if len(input.Tools) > 0 { if len(input.Tools) > 0 {
toolCallsReason := FinishReasonToolCalls toolChoice.FinishReason = "tool_calls"
toolChoice.FinishReason = &toolCallsReason
} }
for _, ss := range results { for _, ss := range results {
@@ -1142,9 +1120,8 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
) )
} else { } else {
// otherwise we return more choices directly // otherwise we return more choices directly
functionCallReason := FinishReasonFunctionCall
*c = append(*c, schema.Choice{ *c = append(*c, schema.Choice{
FinishReason: &functionCallReason, FinishReason: "function_call",
Message: &schema.Message{ Message: &schema.Message{
Role: "assistant", Role: "assistant",
Content: &textContentToReturn, Content: &textContentToReturn,

View File

@@ -7,13 +7,13 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model" model "github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -24,19 +24,19 @@ import (
// @Param file formData file true "file" // @Param file formData file true "file"
// @Success 200 {object} map[string]string "Response" // @Success 200 {object} map[string]string "Response"
// @Router /v1/audio/transcriptions [post] // @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil { if !ok || config == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
diarize := c.FormValue("diarize") != "false" diarize := c.FormValue("diarize", "false") != "false"
// retrieve the file data from the request // retrieve the file data from the request
file, err := c.FormFile("file") file, err := c.FormFile("file")
@@ -76,6 +76,6 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Trascribed: %+v", tr) log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here // TODO: handle different outputs here
return c.JSON(http.StatusOK, tr) return c.Status(http.StatusOK).JSON(tr)
} }
} }

View File

@@ -6,7 +6,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
@@ -14,24 +14,20 @@ import (
model "github.com/mudler/LocalAI/pkg/model" model "github.com/mudler/LocalAI/pkg/model"
) )
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error { return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input == nil { if !ok || input == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
var raw map[string]interface{} var raw map[string]interface{}
body := make([]byte, 0) if body := c.Body(); len(body) > 0 {
if c.Request().Body != nil {
c.Request().Body.Read(body)
}
if len(body) > 0 {
_ = json.Unmarshal(body, &raw) _ = json.Unmarshal(body, &raw)
} }
// Build VideoRequest using shared mapper // Build VideoRequest using shared mapper
vr := MapOpenAIToVideo(input, raw) vr := MapOpenAIToVideo(input, raw)
// Place VideoRequest into context so localai.VideoEndpoint can consume it // Place VideoRequest into locals so localai.VideoEndpoint can consume it
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr) c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
// Delegate to existing localai handler // Delegate to existing localai handler
return localai.VideoEndpoint(cl, ml, appConfig)(c) return localai.VideoEndpoint(cl, ml, appConfig)(c)
} }

View File

@@ -1,50 +1,48 @@
package http package http
import ( import (
"io/fs"
"net/http" "net/http"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/mudler/LocalAI/core/explorer" "github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes" "github.com/mudler/LocalAI/core/http/routes"
"github.com/rs/zerolog/log"
) )
func Explorer(db *explorer.Database) *echo.Echo { func Explorer(db *explorer.Database) *fiber.App {
e := echo.New()
// Set renderer fiberCfg := fiber.Config{
e.Renderer = renderEngine() Views: renderEngine(),
// We disable the Fiber startup message as it does not conform to structured logging.
// Hide banner // We register a startup log line with connection information in the OnListen hook to keep things user friendly though
e.HideBanner = true DisableStartupMessage: false,
// Override default error handler
e.Pre(middleware.StripPathPrefix())
routes.RegisterExplorerRoutes(e, db)
// Favicon handler
e.GET("/favicon.svg", func(c echo.Context) error {
data, err := embedDirStatic.ReadFile("static/favicon.svg")
if err != nil {
return c.NoContent(http.StatusNotFound)
} }
c.Response().Header().Set("Content-Type", "image/svg+xml")
return c.Blob(http.StatusOK, "image/svg+xml", data)
})
// Static files - use fs.Sub to create a filesystem rooted at "static" app := fiber.New(fiberCfg)
staticFS, err := fs.Sub(embedDirStatic, "static")
if err != nil { app.Use(middleware.StripPathPrefix())
// Log error but continue - static files might not work routes.RegisterExplorerRoutes(app, db)
log.Error().Err(err).Msg("failed to create static filesystem")
} else { httpFS := http.FS(embedDirStatic)
e.StaticFS("/static", staticFS)
} app.Use(favicon.New(favicon.Config{
URL: "/favicon.svg",
FileSystem: httpFS,
File: "static/favicon.svg",
}))
app.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Define a custom 404 handler // Define a custom 404 handler
// Note: keep this at the bottom! // Note: keep this at the bottom!
e.GET("/*", notFoundHandler) app.Use(notFoundHandler)
return e return app
} }

View File

@@ -3,108 +3,50 @@ package middleware
import ( import (
"crypto/subtle" "crypto/subtle"
"errors" "errors"
"net/http"
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/dave-gray101/v2keyauth"
"github.com/labstack/echo/v4/middleware" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
) )
var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key") // This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
// GetKeyAuthConfig returns Echo's KeyAuth middleware configuration func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) { customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key", "cookie:token"}, keyauth.ConfigDefault.AuthScheme)
// Create validator function
validator := getApiKeyValidationFunction(applicationConfig)
// Create error handler
errorHandler := getApiKeyErrorHandler(applicationConfig)
// Create Next function (skip middleware for certain requests)
skipper := getApiKeyRequiredFilterFunction(applicationConfig)
// Wrap it with our custom key lookup that checks multiple sources
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if len(applicationConfig.ApiKeys) == 0 {
return next(c)
}
// Skip if skipper says so
if skipper != nil && skipper(c) {
return next(c)
}
// Try to extract key from multiple sources
key, err := extractKeyFromMultipleSources(c)
if err != nil { if err != nil {
return errorHandler(err, c) return nil, err
} }
// Validate the key return &v2keyauth.Config{
valid, err := validator(key, c) CustomKeyLookup: customLookup,
if err != nil || !valid { Next: getApiKeyRequiredFilterFunction(applicationConfig),
return errorHandler(ErrMissingOrMalformedAPIKey, c) Validator: getApiKeyValidationFunction(applicationConfig),
} ErrorHandler: getApiKeyErrorHandler(applicationConfig),
AuthScheme: "Bearer",
// Store key in context for later use
c.Set("api_key", key)
return next(c)
}
}, nil }, nil
} }
// extractKeyFromMultipleSources checks multiple sources for the API key func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
// in order: Authorization header, x-api-key header, xi-api-key header, token cookie return func(ctx *fiber.Ctx, err error) error {
func extractKeyFromMultipleSources(c echo.Context) (string, error) { if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
// Check Authorization header first
auth := c.Request().Header.Get("Authorization")
if auth != "" {
// Check for Bearer scheme
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer "), nil
}
// If no Bearer prefix, return as-is (for backward compatibility)
return auth, nil
}
// Check x-api-key header
if key := c.Request().Header.Get("x-api-key"); key != "" {
return key, nil
}
// Check xi-api-key header
if key := c.Request().Header.Get("xi-api-key"); key != "" {
return key, nil
}
// Check token cookie
cookie, err := c.Cookie("token")
if err == nil && cookie != nil && cookie.Value != "" {
return cookie.Value, nil
}
return "", ErrMissingOrMalformedAPIKey
}
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
return func(err error, c echo.Context) error {
if errors.Is(err, ErrMissingOrMalformedAPIKey) {
if len(applicationConfig.ApiKeys) == 0 { if len(applicationConfig.ApiKeys) == 0 {
return nil // if no keys are set up, any error we get here is not an error. return ctx.Next() // if no keys are set up, any error we get here is not an error.
} }
c.Response().Header().Set("WWW-Authenticate", "Bearer") ctx.Set("WWW-Authenticate", "Bearer")
if applicationConfig.OpaqueErrors { if applicationConfig.OpaqueErrors {
return c.NoContent(http.StatusUnauthorized) return ctx.SendStatus(401)
} }
// Check if the request content type is JSON // Check if the request content type is JSON
contentType := c.Request().Header.Get("Content-Type") contentType := string(ctx.Context().Request.Header.ContentType())
if strings.Contains(contentType, "application/json") { if strings.Contains(contentType, "application/json") {
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ return ctx.Status(401).JSON(schema.ErrorResponse{
Error: &schema.APIError{ Error: &schema.APIError{
Message: "An authentication key is required", Message: "An authentication key is required",
Code: 401, Code: 401,
@@ -113,69 +55,50 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(err
}) })
} }
return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{ return ctx.Status(401).Render("views/login", fiber.Map{
"BaseURL": BaseURL(c), "BaseURL": utils.BaseURL(ctx),
}) })
} }
if applicationConfig.OpaqueErrors { if applicationConfig.OpaqueErrors {
return c.NoContent(http.StatusInternalServerError) return ctx.SendStatus(500)
} }
return err return err
} }
} }
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) { func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
if applicationConfig.UseSubtleKeyComparison { if applicationConfig.UseSubtleKeyComparison {
return func(key string, c echo.Context) (bool, error) { return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 { if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything return true, nil // If no keys are setup, accept everything
} }
for _, validKey := range applicationConfig.ApiKeys { for _, validKey := range applicationConfig.ApiKeys {
if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
return true, nil return true, nil
} }
} }
return false, ErrMissingOrMalformedAPIKey return false, v2keyauth.ErrMissingOrMalformedAPIKey
} }
} }
return func(key string, c echo.Context) (bool, error) { return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 { if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything return true, nil // If no keys are setup, accept everything
} }
for _, validKey := range applicationConfig.ApiKeys { for _, validKey := range applicationConfig.ApiKeys {
if key == validKey { if apiKey == validKey {
return true, nil return true, nil
} }
} }
return false, ErrMissingOrMalformedAPIKey return false, v2keyauth.ErrMissingOrMalformedAPIKey
} }
} }
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper { func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
return func(c echo.Context) bool {
path := c.Request().URL.Path
// Always skip authentication for static files
if strings.HasPrefix(path, "/static/") {
return true
}
// Always skip authentication for generated content
if strings.HasPrefix(path, "/generated-audio/") ||
strings.HasPrefix(path, "/generated-images/") ||
strings.HasPrefix(path, "/generated-videos/") {
return true
}
// Skip authentication for favicon
if path == "/favicon.svg" {
return true
}
// Handle GET request exemptions if enabled
if applicationConfig.DisableApiKeyRequirementForHttpGet { if applicationConfig.DisableApiKeyRequirementForHttpGet {
if c.Request().Method != http.MethodGet { return func(c *fiber.Ctx) bool {
if c.Method() != "GET" {
return false return false
} }
for _, rx := range applicationConfig.HttpGetExemptedEndpoints { for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
@@ -183,8 +106,8 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig
return true return true
} }
} }
}
return false return false
} }
}
return func(c *fiber.Ctx) bool { return false }
} }

View File

@@ -1,48 +0,0 @@
package middleware
import (
"strings"
"github.com/labstack/echo/v4"
)
// BaseURL returns the base URL for the given HTTP request context.
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
// The returned URL is guaranteed to end with `/`.
// The method should be used in conjunction with the StripPathPrefix middleware.
func BaseURL(c echo.Context) string {
path := c.Path()
origPath := c.Request().URL.Path
// Check if StripPathPrefix middleware stored the original path
if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" {
origPath = storedPath
}
// Check X-Forwarded-Proto for scheme
scheme := "http"
if c.Request().Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https"
} else if c.Request().TLS != nil {
scheme = "https"
}
// Check X-Forwarded-Host for host
host := c.Request().Host
if forwardedHost := c.Request().Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 {
prefixLen := len(origPath) - len(path)
if prefixLen > 0 && prefixLen <= len(origPath) {
pathPrefix := origPath[:prefixLen]
if !strings.HasSuffix(pathPrefix, "/") {
pathPrefix += "/"
}
return scheme + "://" + host + pathPrefix
}
}
return scheme + "://" + host + "/"
}

View File

@@ -1,58 +0,0 @@
package middleware
import (
"net/http/httptest"
"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("BaseURL", func() {
Context("without prefix", func() {
It("should return base URL without prefix", func() {
app := echo.New()
actualURL := ""
// Register route - use the actual request path so routing works
routePath := "/hello/world"
app.GET(routePath, func(c echo.Context) error {
actualURL = BaseURL(c)
return nil
})
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualURL).To(Equal("http://example.com/"), "base URL")
})
})
Context("with prefix", func() {
It("should return base URL with prefix", func() {
app := echo.New()
actualURL := ""
// Register route with the stripped path (after middleware removes prefix)
routePath := "/hello/world"
app.GET(routePath, func(c echo.Context) error {
// Simulate what StripPathPrefix middleware does - store original path
c.Set("_original_path", "/myprefix/hello/world")
// Modify the request path to simulate prefix stripping
c.Request().URL.Path = "/hello/world"
actualURL = BaseURL(c)
return nil
})
// Make request with stripped path (middleware would have already processed it)
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL")
})
})
})

View File

@@ -0,0 +1,174 @@
package middleware
import (
"encoding/json"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/services"
"github.com/rs/zerolog/log"
)
// MetricsMiddleware creates a middleware that tracks API usage metrics
// Note: Uses CONTEXT_LOCALS_KEY_MODEL_NAME constant defined in request.go
func MetricsMiddleware(metricsStore services.MetricsStore) fiber.Handler {
return func(c *fiber.Ctx) error {
path := c.Path()
// Skip tracking for UI routes, static files, and non-API endpoints
if shouldSkipMetrics(path) {
return c.Next()
}
// Record start time
start := time.Now()
// Get endpoint category
endpoint := categorizeEndpoint(path)
// Continue with the request
err := c.Next()
// Record metrics after request completes
duration := time.Since(start)
success := err == nil && c.Response().StatusCode() < 400
// Extract model name from context (set by RequestExtractor middleware)
// Use the same constant as RequestExtractor
model := "unknown"
if modelVal, ok := c.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string); ok && modelVal != "" {
model = modelVal
log.Debug().Str("model", model).Str("endpoint", endpoint).Msg("Recording metrics for request")
} else {
// Fallback: try to extract from path params or query
model = extractModelFromRequest(c)
log.Debug().Str("model", model).Str("endpoint", endpoint).Msg("Recording metrics for request (fallback)")
}
// Extract backend from response headers if available
backend := string(c.Response().Header.Peek("X-LocalAI-Backend"))
// Record the request
metricsStore.RecordRequest(endpoint, model, backend, success, duration)
return err
}
}
// shouldSkipMetrics determines if a request should be excluded from metrics
func shouldSkipMetrics(path string) bool {
// Skip UI routes
skipPrefixes := []string{
"/views/",
"/static/",
"/browse/",
"/chat/",
"/text2image/",
"/tts/",
"/talk/",
"/models/edit/",
"/import-model",
"/settings",
"/api/models", // UI API endpoints
"/api/backends", // UI API endpoints
"/api/operations", // UI API endpoints
"/api/p2p", // UI API endpoints
"/api/metrics", // Metrics API itself
}
for _, prefix := range skipPrefixes {
if strings.HasPrefix(path, prefix) {
return true
}
}
// Also skip root path and other UI pages
if path == "/" || path == "/index" {
return true
}
return false
}
// categorizeEndpoint maps request paths to friendly endpoint categories
func categorizeEndpoint(path string) string {
// OpenAI-compatible endpoints
if strings.HasPrefix(path, "/v1/chat/completions") || strings.HasPrefix(path, "/chat/completions") {
return "chat"
}
if strings.HasPrefix(path, "/v1/completions") || strings.HasPrefix(path, "/completions") {
return "completions"
}
if strings.HasPrefix(path, "/v1/embeddings") || strings.HasPrefix(path, "/embeddings") {
return "embeddings"
}
if strings.HasPrefix(path, "/v1/images/generations") || strings.HasPrefix(path, "/images/generations") {
return "image-generation"
}
if strings.HasPrefix(path, "/v1/audio/transcriptions") || strings.HasPrefix(path, "/audio/transcriptions") {
return "transcriptions"
}
if strings.HasPrefix(path, "/v1/audio/speech") || strings.HasPrefix(path, "/audio/speech") {
return "text-to-speech"
}
if strings.HasPrefix(path, "/v1/models") || strings.HasPrefix(path, "/models") {
return "models"
}
// LocalAI-specific endpoints
if strings.HasPrefix(path, "/v1/internal") {
return "internal"
}
if strings.Contains(path, "/tts") {
return "text-to-speech"
}
if strings.Contains(path, "/stt") || strings.Contains(path, "/whisper") {
return "speech-to-text"
}
if strings.Contains(path, "/sound-generation") {
return "sound-generation"
}
// Default to the first path segment
parts := strings.Split(strings.Trim(path, "/"), "/")
if len(parts) > 0 {
return parts[0]
}
return "unknown"
}
// extractModelFromRequest attempts to extract the model name from the request
func extractModelFromRequest(c *fiber.Ctx) string {
// Try query parameter first
model := c.Query("model")
if model != "" {
return model
}
// Try to extract from JSON body for POST requests
if c.Method() == fiber.MethodPost {
// Read body
bodyBytes := c.Body()
if len(bodyBytes) > 0 {
// Parse JSON
var reqBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &reqBody); err == nil {
if modelVal, ok := reqBody["model"]; ok {
if modelStr, ok := modelVal.(string); ok {
return modelStr
}
}
}
}
}
// Try path parameter for endpoints like /models/:model
model = c.Params("model")
if model != "" {
return model
}
return "unknown"
}

View File

@@ -1,13 +0,0 @@
package middleware_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestMiddleware(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Middleware test suite")
}

View File

@@ -4,12 +4,10 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"strconv" "strconv"
"strings" "strings"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/services"
@@ -17,6 +15,8 @@ import (
"github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -44,22 +44,21 @@ const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
// TODO: Refactor to not return error if unchanged // TODO: Refactor to not return error if unchanged
func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && model != "" { if ok && model != "" {
return return
} }
model = c.Param("model") model = ctx.Params("model")
if model == "" { if (model == "") && ctx.Query("model") != "" {
model = c.QueryParam("model") model = ctx.Query("model")
} }
if model == "" { if model == "" {
// Set model from bearer token, if available // Set model from bearer token, if available
auth := c.Request().Header.Get("Authorization") bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
bearer := strings.TrimPrefix(auth, "Bearer ") if bearer != "" {
if bearer != "" && bearer != auth {
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE) exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
if err == nil && exists { if err == nil && exists {
model = bearer model = bearer
@@ -67,72 +66,71 @@ func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
} }
} }
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model) ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
} }
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc { func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(ctx *fiber.Ctx) error {
return func(c echo.Context) error { re.setModelNameFromRequest(ctx)
re.setModelNameFromRequest(c) localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || localModelName == "" { if !ok || localModelName == "" {
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default") log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
} }
return next(c) return ctx.Next()
}
} }
} }
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(ctx *fiber.Ctx) error {
return func(c echo.Context) error { re.setModelNameFromRequest(ctx)
re.setModelNameFromRequest(c) localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if localModelName != "" { // Don't overwrite existing values if localModelName != "" { // Don't overwrite existing values
return next(c) return ctx.Next()
} }
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED) modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
if err != nil { if err != nil {
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()") log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
return next(c) return ctx.Next()
} }
if len(modelNames) == 0 { if len(modelNames) == 0 {
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed") log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
// This is non-fatal - making it so was breaking the case of direct installation of raw models // This is non-fatal - making it so was breaking the case of direct installation of raw models
// return errors.New("this endpoint requires at least one model to be installed") // return errors.New("this endpoint requires at least one model to be installed")
return next(c) return ctx.Next()
} }
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model") log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
return next(c) return ctx.Next()
}
} }
} }
// TODO: If context and cancel above belong on all methods, move that part of above into here! // TODO: If context and cancel above belong on all methods, move that part of above into here!
// Otherwise, it's in its own method below for now // Otherwise, it's in its own method below for now
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(ctx *fiber.Ctx) error {
return func(c echo.Context) error {
input := initializer() input := initializer()
if input == nil { if input == nil {
return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") return fmt.Errorf("unable to initialize body")
} }
if err := c.Bind(input); err != nil { if err := ctx.BodyParser(input); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) return fmt.Errorf("failed parsing request body: %w", err)
} }
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
if input.ModelName(nil) == "" { if input.ModelName(nil) == "" {
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && localModelName != "" { if ok && localModelName != "" {
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain") log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
input.ModelName(&localModelName) input.ModelName(&localModelName)
} }
} else {
// Update context locals with the model name from the request body
// This ensures downstream middleware (like metrics) can access it
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, input.ModelName(nil))
} }
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig) cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
@@ -145,47 +143,29 @@ func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIR
cfg.Model = input.ModelName(nil) cfg.Model = input.ModelName(nil)
} }
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return next(c) return ctx.Next()
}
} }
} }
func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" { if !ok || input.Model == "" {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil { if !ok || cfg == nil {
return echo.ErrBadRequest return fiber.ErrBadRequest
} }
// Extract or generate the correlation ID // Extract or generate the correlation ID
correlationID := c.Request().Header.Get("X-Correlation-ID") correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
if correlationID == "" { ctx.Set("X-Correlation-ID", correlationID)
correlationID = uuid.New().String()
}
c.Response().Header().Set("X-Correlation-ID", correlationID)
// Use the request context directly - Echo properly supports context cancellation!
// No need for workarounds like handleConnectionCancellation
reqCtx := c.Request().Context()
c1, cancel := context.WithCancel(re.applicationConfig.Context) c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Cancel when request context is cancelled (client disconnects)
go func() {
select {
case <-reqCtx.Done():
cancel()
case <-c1.Done():
// Already cancelled
}
}()
// Add the correlation ID to the new context // Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
@@ -202,10 +182,10 @@ func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
cfg.Model = input.Model cfg.Model = input.Model
} }
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return nil return ctx.Next()
} }
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error { func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {

View File

@@ -3,55 +3,34 @@ package middleware
import ( import (
"strings" "strings"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
) )
// StripPathPrefix returns middleware that strips a path prefix from the request path. // StripPathPrefix returns a middleware that strips a path prefix from the request path.
// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header. // The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
// This must be registered as Pre middleware (using e.Pre()) to modify the path before routing. func StripPathPrefix() fiber.Handler {
func StripPathPrefix() echo.MiddlewareFunc { return func(c *fiber.Ctx) error {
return func(next echo.HandlerFunc) echo.HandlerFunc { for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
return func(c echo.Context) error {
prefixes := c.Request().Header.Values("X-Forwarded-Prefix")
originalPath := c.Request().URL.Path
for _, prefix := range prefixes {
if prefix != "" { if prefix != "" {
normalizedPrefix := prefix path := c.Path()
if !strings.HasSuffix(prefix, "/") { pos := len(prefix)
normalizedPrefix = prefix + "/"
}
if strings.HasPrefix(originalPath, normalizedPrefix) { if prefix[pos-1] == '/' {
// Update the request path by stripping the normalized prefix pos--
newPath := originalPath[len(normalizedPrefix):]
if newPath == "" {
newPath = "/"
}
// Ensure path starts with / for proper routing
if !strings.HasPrefix(newPath, "/") {
newPath = "/" + newPath
}
// Update the URL path - Echo's router uses URL.Path for routing
c.Request().URL.Path = newPath
c.Request().URL.RawPath = ""
// Update RequestURI to match the new path (needed for proper routing)
if c.Request().URL.RawQuery != "" {
c.Request().RequestURI = newPath + "?" + c.Request().URL.RawQuery
} else { } else {
c.Request().RequestURI = newPath prefix += "/"
} }
// Store original path for BaseURL utility
c.Set("_original_path", originalPath) if strings.HasPrefix(path, prefix) {
c.Path(path[pos:])
break break
} else if originalPath == prefix || originalPath == prefix+"/" { } else if prefix[:pos] == path {
// Redirect to prefix with trailing slash (use 302 to match test expectations) c.Redirect(prefix)
return c.Redirect(302, normalizedPrefix) return nil
} }
} }
} }
return next(c) return c.Next()
}
} }
} }

View File

@@ -2,133 +2,120 @@ package middleware
import ( import (
"net/http/httptest" "net/http/httptest"
"testing"
"github.com/labstack/echo/v4" "github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2" "github.com/stretchr/testify/require"
. "github.com/onsi/gomega"
) )
var _ = Describe("StripPathPrefix", func() { func TestStripPathPrefix(t *testing.T) {
var app *echo.Echo
var actualPath string var actualPath string
var appInitialized bool
BeforeEach(func() { app := fiber.New()
app.Use(StripPathPrefix())
app.Get("/hello/world", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})
app.Get("/", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})
for _, tc := range []struct {
name string
path string
prefixHeader []string
expectStatus int
expectPath string
}{
{
name: "without prefix and header",
path: "/hello/world",
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "without prefix and headers on root path",
path: "/",
expectStatus: 200,
expectPath: "/",
},
{
name: "without prefix but header",
path: "/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix but non-matching header",
path: "/prefix/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 404,
},
{
name: "with prefix and matching header",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 1st header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/", "/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 2nd header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/otherprefix/", "/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and header not ending with slash",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and non-matching header not ending with slash",
path: "/myprefix-suffix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 404,
},
{
name: "redirect when prefix does not end with a slash",
path: "/myprefix",
prefixHeader: []string{"/myprefix"},
expectStatus: 302,
expectPath: "/myprefix/",
},
} {
t.Run(tc.name, func(t *testing.T) {
actualPath = "" actualPath = ""
if !appInitialized { req := httptest.NewRequest("GET", tc.path, nil)
app = echo.New() if tc.prefixHeader != nil {
app.Pre(StripPathPrefix()) req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
}
app.GET("/hello/world", func(c echo.Context) error { resp, err := app.Test(req, -1)
actualPath = c.Request().URL.Path
return nil
})
app.GET("/", func(c echo.Context) error { require.NoError(t, err)
actualPath = c.Request().URL.Path require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")
return nil
}) if tc.expectStatus == 200 {
appInitialized = true require.Equal(t, tc.expectPath, actualPath, "rewritten path")
} else if tc.expectStatus == 302 {
require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
} }
}) })
}
Context("without prefix", func() { }
It("should not modify path when no header is present", func() {
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should not modify root path when no header is present", func() {
req := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/"), "rewritten path")
})
It("should not modify path when header does not match", func() {
req := httptest.NewRequest("GET", "/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
})
Context("with prefix", func() {
It("should return 404 when prefix does not match header", func() {
req := httptest.NewRequest("GET", "/prefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(404), "response status code")
})
It("should strip matching prefix from path", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when it matches the first header value", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when it matches the second header value", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/", "/myprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when header does not end with slash", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should return 404 when prefix does not match header without trailing slash", func() {
req := httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(404), "response status code")
})
It("should redirect when prefix does not end with a slash", func() {
req := httptest.NewRequest("GET", "/myprefix", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(302), "response status code")
Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location")
})
})
})

Some files were not shown because too many files have changed in this diff Show More