Compare commits

..

1 Commits

Author SHA1 Message Date
Ettore Di Giacinto
15da375c4d wip
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
2025-10-28 15:56:58 +01:00
204 changed files with 5585 additions and 16596 deletions

View File

@@ -1,8 +0,0 @@
# .air.toml
[build]
cmd = "make build"
bin = "./local-ai"
args_bin = [ "--debug" ]
include_ext = ["go", "html", "yaml", "toml", "json", "txt", "md"]
exclude_dir = ["pkg/grpc/proto"]
delay = 1000

View File

@@ -7,8 +7,8 @@ import (
"slices"
"strings"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
cogito "github.com/mudler/cogito"
"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
"github.com/mudler/cogito"
"github.com/mudler/cogito/structures"
"github.com/sashabaranov/go-openai/jsonschema"

39
.github/gallery-agent/go.mod vendored Normal file
View File

@@ -0,0 +1,39 @@
module github.com/go-skynet/LocalAI/.github/gallery-agent
go 1.24.1
require (
github.com/mudler/cogito v0.3.0
github.com/onsi/ginkgo/v2 v2.25.3
github.com/onsi/gomega v1.38.2
github.com/sashabaranov/go-openai v1.41.2
github.com/tmc/langchaingo v0.1.13
gopkg.in/yaml.v3 v3.0.1
)
require (
dario.cat/mergo v1.0.1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.4.0 // indirect
github.com/Masterminds/sprig/v3 v3.3.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/jsonschema-go v0.3.0 // indirect
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/huandu/xstrings v1.5.0 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/modelcontextprotocol/go-sdk v1.0.0 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
github.com/spf13/cast v1.7.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.uber.org/automaxprocs v1.6.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.41.0 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/tools v0.36.0 // indirect
)

168
.github/gallery-agent/go.sum vendored Normal file
View File

@@ -0,0 +1,168 @@
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU=
github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs=
github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw=
github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA=
github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v28.2.2+incompatible h1:CjwRSksz8Yo4+RmQ339Dp/D2tGO5JxwYeqtMOEe0LDw=
github.com/docker/docker v28.2.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw=
github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.3.0 h1:6AH2TxVNtk3IlvkkhjrtbUc4S8AvO0Xii0DxIygDg+Q=
github.com/google/jsonschema-go v0.3.0/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8=
github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI=
github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ81pIr0yLvtUWk2if982qA3F3QD6H4=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s=
github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ=
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
github.com/moby/go-archive v0.1.0/go.mod h1:G9B+YoujNohJmrIYFBpSd54GTUB4lt9S+xVQvsJyFuo=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU=
github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko=
github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs=
github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/modelcontextprotocol/go-sdk v1.0.0 h1:Z4MSjLi38bTgLrd/LjSmofqRqyBiVKRyQSJgw8q8V74=
github.com/modelcontextprotocol/go-sdk v1.0.0/go.mod h1:nYtYQroQ2KQiM0/SbyEPUWQ6xs4B95gJjEalc9AQyOs=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/mudler/cogito v0.3.0 h1:NbVAO3bLkK5oGSY0xq87jlz8C9OIsLW55s+8Hfzeu9s=
github.com/mudler/cogito v0.3.0/go.mod h1:abMwl+CUjCp87IufA2quZdZt0bbLaHHN79o17HbUKxU=
github.com/onsi/ginkgo/v2 v2.25.3 h1:Ty8+Yi/ayDAGtk4XxmmfUy4GabvM+MegeB4cDLRi6nw=
github.com/onsi/ginkgo/v2 v2.25.3/go.mod h1:43uiyQC4Ed2tkOzLsEYm7hnrb7UJTWHYNsuy3bG/snE=
github.com/onsi/gomega v1.38.2 h1:eZCjf2xjZAqe+LeWvKb5weQ+NcPwX84kqJ0cZNxok2A=
github.com/onsi/gomega v1.38.2/go.mod h1:W2MJcYxRGV63b418Ai34Ud0hEdTVXq9NW9+Sx6uXf3k=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw=
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/sashabaranov/go-openai v1.41.2 h1:vfPRBZNMpnqu8ELsclWcAvF19lDNgh1t6TVfFFOPiSM=
github.com/sashabaranov/go-openai v1.41.2/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/shirou/gopsutil/v4 v4.25.5 h1:rtd9piuSMGeU8g1RMXjZs9y9luK5BwtnG7dZaQUJAsc=
github.com/shirou/gopsutil/v4 v4.25.5/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w=
github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/testcontainers/testcontainers-go v0.38.0 h1:d7uEapLcv2P8AvH8ahLqDMMxda2W9gQN1nRbHS28HBw=
github.com/testcontainers/testcontainers-go v0.38.0/go.mod h1:C52c9MoHpWO+C4aqmgSU+hxlR5jlEayWtgYrb8Pzz1w=
github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU=
github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI=
github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk=
github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY=
github.com/tmc/langchaingo v0.1.13 h1:rcpMWBIi2y3B90XxfE4Ao8dhCQPVDMaNPnN5cGB1CaA=
github.com/tmc/langchaingo v0.1.13/go.mod h1:vpQ5NOIhpzxDfTZK9B6tf2GM/MoaHewPWM5KXXGh7hg=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 h1:Xs2Ncz0gNihqu9iosIZ5SkBbWo5T8JhhLJFMQL1qmLI=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0/go.mod h1:vy+2G/6NvVMpwGX/NyLqcC41fxepnuKHk16E6IZUcJc=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs=
go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg=
golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -51,7 +51,6 @@ type ModelFile struct {
Size int64
SHA256 string
IsReadme bool
URL string
}
// ModelDetails represents detailed information about a model
@@ -216,7 +215,6 @@ func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) {
}
// Process each file
baseURL := strings.TrimSuffix(c.baseURL, "/api/models")
for _, file := range files {
fileName := filepath.Base(file.Path)
isReadme := strings.Contains(strings.ToLower(fileName), "readme")
@@ -229,16 +227,11 @@ func (c *Client) GetModelDetails(repoID string) (*ModelDetails, error) {
sha256 = file.Oid
}
// Construct the full URL for the file
// Use /resolve/main/ for downloading files (handles LFS properly)
fileURL := fmt.Sprintf("%s/%s/resolve/main/%s", baseURL, repoID, file.Path)
modelFile := ModelFile{
Path: file.Path,
Size: file.Size,
SHA256: sha256,
IsReadme: isReadme,
URL: fileURL,
}
details.Files = append(details.Files, modelFile)

View File

@@ -1,7 +1,6 @@
package hfapi_test
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
@@ -9,7 +8,7 @@ import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
)
var _ = Describe("HuggingFace API Client", func() {
@@ -271,15 +270,6 @@ var _ = Describe("HuggingFace API Client", func() {
})
})
Context("when getting file SHA on remote model", func() {
It("should get file SHA successfully", func() {
sha, err := client.GetFileSHA(
"mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF", "localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf")
Expect(err).ToNot(HaveOccurred())
Expect(sha).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"))
})
})
Context("when listing files", func() {
BeforeEach(func() {
mockFilesResponse := `[
@@ -339,25 +329,23 @@ var _ = Describe("HuggingFace API Client", func() {
Context("when getting file SHA", func() {
BeforeEach(func() {
mockFilesResponse := `[
{
"type": "file",
"path": "model-Q4_K_M.gguf",
mockFileInfoResponse := `{
"path": "model-Q4_K_M.gguf",
"size": 1000000,
"oid": "abc123",
"lfs": {
"oid": "sha256:def456",
"size": 1000000,
"oid": "abc123",
"lfs": {
"oid": "def456789",
"size": 1000000,
"pointerSize": 135
}
"pointer": "version https://git-lfs.github.com/spec/v1",
"sha256": "def456789"
}
]`
}`
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/tree/main") {
if strings.Contains(r.URL.Path, "/paths-info") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(mockFilesResponse))
w.Write([]byte(mockFileInfoResponse))
} else {
w.WriteHeader(http.StatusNotFound)
}
@@ -375,29 +363,18 @@ var _ = Describe("HuggingFace API Client", func() {
It("should handle missing SHA gracefully", func() {
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/tree/main") {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`[
{
"type": "file",
"path": "file.txt",
"size": 100,
"oid": "file123"
}
]`))
} else {
w.WriteHeader(http.StatusNotFound)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"path": "file.txt", "size": 100}`))
}))
client.SetBaseURL(server.URL)
sha, err := client.GetFileSHA("test/model", "file.txt")
Expect(err).ToNot(HaveOccurred())
// When there's no LFS, it should return the OID
Expect(sha).To(Equal("file123"))
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no SHA256 found"))
Expect(sha).To(Equal(""))
})
})
@@ -462,13 +439,6 @@ var _ = Describe("HuggingFace API Client", func() {
Expect(details.ReadmeFile).ToNot(BeNil())
Expect(details.ReadmeFile.Path).To(Equal("README.md"))
Expect(details.ReadmeFile.IsReadme).To(BeTrue())
// Verify URLs are set for all files
baseURL := strings.TrimSuffix(server.URL, "/api/models")
for _, file := range details.Files {
expectedURL := fmt.Sprintf("%s/test/model/resolve/main/%s", baseURL, file.Path)
Expect(file.URL).To(Equal(expectedURL))
}
})
})

View File

@@ -11,5 +11,3 @@ func TestHfapi(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "HuggingFace API Suite")
}

View File

@@ -9,7 +9,7 @@ import (
"strings"
"time"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
)
// ProcessedModelFile represents a processed model file with additional metadata

View File

@@ -3,9 +3,9 @@ package main
import (
"fmt"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
openai "github.com/sashabaranov/go-openai"
jsonschema "github.com/sashabaranov/go-openai/jsonschema"
"github.com/go-skynet/LocalAI/.github/gallery-agent/hfapi"
"github.com/sashabaranov/go-openai"
"github.com/tmc/langchaingo/jsonschema"
)
// Get repository README from HF
@@ -13,7 +13,7 @@ type HFReadmeTool struct {
client *hfapi.Client
}
func (s *HFReadmeTool) Execute(args map[string]any) (string, error) {
func (s *HFReadmeTool) Run(args map[string]any) (string, error) {
q, ok := args["repository"].(string)
if !ok {
return "", fmt.Errorf("no query")

View File

@@ -1,10 +1,10 @@
name: Bump Backend dependencies
name: Bump dependencies
on:
schedule:
- cron: 0 20 * * *
workflow_dispatch:
jobs:
bump-backends:
bump:
strategy:
fail-fast: false
matrix:

View File

@@ -1,10 +1,10 @@
name: Bump Documentation
name: Bump dependencies
on:
schedule:
- cron: 0 20 * * *
workflow_dispatch:
jobs:
bump-docs:
bump:
strategy:
fail-fast: false
matrix:

View File

@@ -33,7 +33,7 @@ jobs:
run: |
CGO_ENABLED=0 make build
- name: rm
uses: appleboy/ssh-action@v1.2.3
uses: appleboy/ssh-action@v1.2.2
with:
host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }}
@@ -53,7 +53,7 @@ jobs:
rm: true
target: ./local-ai
- name: restarting
uses: appleboy/ssh-action@v1.2.3
uses: appleboy/ssh-action@v1.2.2
with:
host: ${{ secrets.EXPLORER_SSH_HOST }}
username: ${{ secrets.EXPLORER_SSH_USERNAME }}

View File

@@ -2,7 +2,7 @@ name: Gallery Agent
on:
schedule:
- cron: '0 */3 * * *' # Run every 4 hours
- cron: '0 */1 * * *' # Run every 4 hours
workflow_dispatch:
inputs:
search_term:
@@ -39,6 +39,11 @@ jobs:
with:
go-version: '1.21'
- name: Build gallery agent
run: |
cd .github/gallery-agent
go mod download
go build -o gallery-agent .
- name: Run gallery agent
env:
@@ -51,7 +56,9 @@ jobs:
MAX_MODELS: ${{ github.event.inputs.max_models || '1' }}
run: |
export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml
go run .github/gallery-agent
cd .github/gallery-agent
./gallery-agent
rm -rf gallery-agent
- name: Check for changes
id: check_changes

View File

@@ -30,7 +30,6 @@ Thank you for your interest in contributing to LocalAI! We appreciate your time
3. Install the required dependencies ( see https://localai.io/basics/build/#build-localai-locally )
4. Build LocalAI: `make build`
5. Run LocalAI: `./local-ai`
6. To Build and live reload: `make build-dev`
## Contributing
@@ -77,7 +76,7 @@ LOCALAI_IMAGE_TAG=test LOCALAI_IMAGE=local-ai-aio make run-e2e-aio
## Documentation
We are welcome the contribution of the documents, please open new PR or create a new issue. The documentation is available under `docs/` https://github.com/mudler/LocalAI/tree/master/docs
## Community and Communication
- You can reach out via the Github issue tracker.

View File

@@ -103,10 +103,6 @@ build-launcher: ## Build the launcher application
build-all: build build-launcher ## Build both server and launcher
build-dev: ## Run LocalAI in dev mode with live reload
@command -v air >/dev/null 2>&1 || go install github.com/air-verse/air@latest
air -c .air.toml
dev-dist:
$(GORELEASER) build --snapshot --clean

View File

@@ -43,7 +43,7 @@
> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/)
>
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
> [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🌍 Explorer](https://explorer.localai.io) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on
[![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white)](https://t.me/localaiofficial_bot)
[![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai)
@@ -116,8 +116,6 @@ For more installation options, see [Installer Options](https://localai.io/docs/a
<img src="https://img.shields.io/badge/Download-macOS-blue?style=for-the-badge&logo=apple&logoColor=white" alt="Download LocalAI for macOS"/>
</a>
> Note: the DMGs are not signed by Apple as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround, fix is tracked here: https://github.com/mudler/LocalAI/issues/6244
Or run with docker:
> **💡 Docker Run vs Docker Start**
@@ -202,7 +200,7 @@ local-ai run oci://localai/phi-2:latest
> ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection).
For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html), if you are interested in our roadmap items and future enhancements, you can see the [Issues labeled as Roadmap here](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap)
For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html)
## 📰 Latest project news

View File

@@ -154,10 +154,6 @@ message PredictOptions {
repeated string Videos = 45;
repeated string Audios = 46;
string CorrelationId = 47;
string Tools = 48; // JSON array of available tools/functions for tool calling
string ToolChoice = 49; // JSON string or object specifying tool choice behavior
int32 Logprobs = 50; // Number of top logprobs to return (maps to OpenAI logprobs parameter)
int32 TopLogprobs = 51; // Number of top logprobs to return per token (maps to OpenAI top_logprobs parameter)
}
// The response message containing the result
@@ -168,7 +164,6 @@ message Reply {
double timing_prompt_processing = 4;
double timing_token_generation = 5;
bytes audio = 6;
bytes logprobs = 7; // JSON-encoded logprobs data matching OpenAI format
}
message GrammarTrigger {
@@ -387,11 +382,6 @@ message StatusResponse {
message Message {
string role = 1;
string content = 2;
// Optional fields for OpenAI-compatible message format
string name = 3; // Tool name (for tool messages)
string tool_call_id = 4; // Tool call ID (for tool messages)
string reasoning_content = 5; // Reasoning content (for thinking models)
string tool_calls = 6; // Tool calls as JSON string (for assistant messages with tool calls)
}
message DetectOptions {

View File

@@ -1,5 +1,5 @@
LLAMA_VERSION?=80deff3648b93727422461c41c7279ef1dac7452
LLAMA_VERSION?=5a4ff43e7dd049e35942bc3d12361dab2f155544
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp
CMAKE_ARGS?=

View File

File diff suppressed because it is too large Load Diff

View File

@@ -8,7 +8,7 @@ JOBS?=$(shell nproc --ignore=1)
# whisper.cpp version
WHISPER_REPO?=https://github.com/ggml-org/whisper.cpp
WHISPER_CPP_VERSION?=d9b7613b34a343848af572cc14467fc5e82fc788
WHISPER_CPP_VERSION?=f16c12f3f55f5bd3d6ac8cf2f31ab90a42c884d5
SO_TARGET?=libgowhisper.so
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF

View File

@@ -2,7 +2,6 @@
accelerate
torch
torchaudio
numpy>=1.24.0,<1.26.0
transformers
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster

View File

@@ -2,7 +2,6 @@
torch==2.6.0+cu118
torchaudio==2.6.0+cu118
transformers==4.46.3
numpy>=1.24.0,<1.26.0
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate

View File

@@ -1,7 +1,6 @@
torch
torchaudio
transformers
numpy>=1.24.0,<1.26.0
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate

View File

@@ -2,7 +2,6 @@
torch==2.6.0+rocm6.1
torchaudio==2.6.0+rocm6.1
transformers
numpy>=1.24.0,<1.26.0
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate

View File

@@ -3,7 +3,6 @@ intel-extension-for-pytorch==2.3.110+xpu
torch==2.3.1+cxx11.abi
torchaudio==2.3.1+cxx11.abi
transformers
numpy>=1.24.0,<1.26.0
# https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate

View File

@@ -2,6 +2,5 @@
torch
torchaudio
transformers
numpy>=1.24.0,<1.26.0
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster
accelerate

View File

@@ -61,7 +61,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
if request.PipelineType != "": # Reuse the PipelineType field for language
kwargs['lang'] = request.PipelineType
self.model_name = model_name
self.model = Reranker(model_name, **kwargs)
self.model = Reranker(model_name, **kwargs)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
@@ -75,13 +75,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
documents.append(doc)
ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents))))
# Prepare results to return
cropped_results = ranked_results.top_k(request.top_n) if request.top_n > 0 else ranked_results
results = [
backend_pb2.DocumentResult(
index=res.doc_id,
text=res.text,
relevance_score=res.score
) for res in (cropped_results)
) for res in ranked_results.results
]
# Calculate the usage and total tokens

View File

@@ -76,7 +76,7 @@ class TestBackendServicer(unittest.TestCase):
)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
self.assertTrue(response.success)
rerank_response = stub.Rerank(request)
print(rerank_response.results[0])
self.assertIsNotNone(rerank_response.results)
@@ -87,60 +87,4 @@ class TestBackendServicer(unittest.TestCase):
print(err)
self.fail("Reranker service failed")
finally:
self.tearDown()
def test_rerank_omit_top_n(self):
"""
This method tests if the embeddings are generated successfully even top_n is omitted
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
request = backend_pb2.RerankRequest(
query="I love you",
documents=["I hate you", "I really like you"],
top_n=0 #
)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
self.assertTrue(response.success)
rerank_response = stub.Rerank(request)
print(rerank_response.results[0])
self.assertIsNotNone(rerank_response.results)
self.assertEqual(len(rerank_response.results), 2)
self.assertEqual(rerank_response.results[0].text, "I really like you")
self.assertEqual(rerank_response.results[1].text, "I hate you")
except Exception as err:
print(err)
self.fail("Reranker service failed")
finally:
self.tearDown()
def test_rerank_crop(self):
"""
This method tests top_n cropping
"""
try:
self.setUp()
with grpc.insecure_channel("localhost:50051") as channel:
stub = backend_pb2_grpc.BackendStub(channel)
request = backend_pb2.RerankRequest(
query="I love you",
documents=["I hate you", "I really like you", "I hate ignoring top_n"],
top_n=2
)
response = stub.LoadModel(backend_pb2.ModelOptions(Model="cross-encoder"))
self.assertTrue(response.success)
rerank_response = stub.Rerank(request)
print(rerank_response.results[0])
self.assertIsNotNone(rerank_response.results)
self.assertEqual(len(rerank_response.results), 2)
self.assertEqual(rerank_response.results[0].text, "I really like you")
self.assertEqual(rerank_response.results[1].text, "I hate you")
except Exception as err:
print(err)
self.fail("Reranker service failed")
finally:
self.tearDown()
self.tearDown()

View File

@@ -22,15 +22,9 @@ func New(opts ...config.AppOption) (*Application, error) {
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.SystemState.Model.ModelsPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
if err := application.start(); err != nil {
return nil, err
}
caps, err := xsysinfo.CPUCapabilities()
if err == nil {
log.Debug().Msgf("CPU capabilities: %v", caps)
}
gpus, err := xsysinfo.GPUs()
if err == nil {
@@ -62,12 +56,12 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
if err := coreStartup.InstallModels(options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
log.Error().Err(err).Msg("error installing models")
}
for _, backend := range options.ExternalBackends {
if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
log.Error().Err(err).Msg("error installing external backend")
}
}
@@ -158,6 +152,10 @@ func New(opts ...config.AppOption) (*Application, error) {
// Watch the configuration directory
startWatcher(options)
if err := application.start(); err != nil {
return nil, err
}
log.Info().Msg("core/startup process completed!")
return application, nil
}

View File

@@ -3,6 +3,7 @@ package backend
import (
"context"
"encoding/json"
"fmt"
"regexp"
"slices"
"strings"
@@ -25,7 +26,6 @@ type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
AudioOutput string
Logprobs *schema.Logprobs // Logprobs from the backend response
}
type TokenUsage struct {
@@ -35,7 +35,7 @@ type TokenUsage struct {
TimingTokenGeneration float64
}
func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model
// Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -47,7 +47,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
if !slices.Contains(modelNames, c.Name) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil {
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
//return nil, err
@@ -65,8 +65,29 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
var protoMessages []*proto.Message
// if we are using the tokenizer template, we need to convert the messages to proto messages
// unless the prompt has already been tokenized (non-chat endpoints + functions)
if c.TemplateConfig.UseTokenizerTemplate && len(messages) > 0 {
protoMessages = messages.ToProto()
if c.TemplateConfig.UseTokenizerTemplate && s == "" {
protoMessages = make([]*proto.Message, len(messages), len(messages))
for i, message := range messages {
protoMessages[i] = &proto.Message{
Role: message.Role,
}
switch ct := message.Content.(type) {
case string:
protoMessages[i].Content = ct
case []interface{}:
// If using the tokenizer template, in case of multimodal we want to keep the multimodal content as and return only strings here
data, _ := json.Marshal(ct)
resultData := []struct {
Text string `json:"text"`
}{}
json.Unmarshal(data, &resultData)
for _, r := range resultData {
protoMessages[i].Content += r.Text
}
default:
return nil, fmt.Errorf("unsupported type for schema.Message.Content for inference: %T", ct)
}
}
}
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
@@ -78,21 +99,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
opts.Images = images
opts.Videos = videos
opts.Audios = audios
opts.Tools = tools
opts.ToolChoice = toolChoice
if logprobs != nil {
opts.Logprobs = int32(*logprobs)
}
if topLogprobs != nil {
opts.TopLogprobs = int32(*topLogprobs)
}
if len(logitBias) > 0 {
// Serialize logit_bias map to JSON string for proto
logitBiasJSON, err := json.Marshal(logitBias)
if err == nil {
opts.LogitBias = string(logitBiasJSON)
}
}
tokenUsage := TokenUsage{}
@@ -124,7 +130,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
}
ss := ""
var logprobs *schema.Logprobs
var partialRune []byte
err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) {
@@ -136,14 +141,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration
tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing
// Parse logprobs from reply if present (collect from last chunk that has them)
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}
// Process complete runes and accumulate them
var completeRunes []byte
for len(partialRune) > 0 {
@@ -169,7 +166,6 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
return LLMResponse{
Response: ss,
Usage: tokenUsage,
Logprobs: logprobs,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
@@ -192,19 +188,9 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
response = c.TemplateConfig.ReplyPrefix + response
}
// Parse logprobs from reply if present
var logprobs *schema.Logprobs
if len(reply.Logprobs) > 0 {
var parsedLogprobs schema.Logprobs
if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil {
logprobs = &parsedLogprobs
}
}
return LLMResponse{
Response: response,
Usage: tokenUsage,
Logprobs: logprobs,
}, err
}
}

View File

@@ -212,7 +212,7 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
}
}
pbOpts := &pb.PredictOptions{
return &pb.PredictOptions{
Temperature: float32(*c.Temperature),
TopP: float32(*c.TopP),
NDraft: c.NDraft,
@@ -249,6 +249,4 @@ func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions
TailFreeSamplingZ: float32(*c.TFZ),
TypicalP: float32(*c.TypicalP),
}
// Logprobs and TopLogprobs are set by the caller if provided
return pbOpts
}

View File

@@ -1,7 +1,6 @@
package cli
import (
"context"
"encoding/json"
"fmt"
@@ -103,7 +102,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil {
return err
}

View File

@@ -48,12 +48,10 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error {
appHTTP := http.Explorer(db)
signals.RegisterGracefulTerminationHandler(func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := appHTTP.Shutdown(ctx); err != nil {
if err := appHTTP.Shutdown(); err != nil {
log.Error().Err(err).Msg("error during shutdown")
}
})
return appHTTP.Start(e.Address)
return appHTTP.Listen(e.Address)
}

View File

@@ -1,14 +1,12 @@
package cli
import (
"context"
"encoding/json"
"errors"
"fmt"
cliContext "github.com/mudler/LocalAI/core/cli/context"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/startup"
@@ -80,12 +78,6 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
return err
}
galleryService := services.NewGalleryService(&config.ApplicationConfig{}, model.NewModelLoader(systemState, true))
err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState)
if err != nil {
return err
}
var galleries []config.Gallery
if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil {
log.Error().Err(err).Msg("unable to load galleries")
@@ -135,7 +127,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
}
modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
err = startup.InstallModels(galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
if err != nil {
return err
}

View File

@@ -232,5 +232,5 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
}
})
return appHTTP.Start(r.Address)
return appHTTP.Listen(r.Address)
}

View File

@@ -1,7 +1,6 @@
package worker
import (
"context"
"encoding/json"
"errors"
"fmt"
@@ -43,7 +42,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
log.Error().Err(err).Msg("failed loading galleries")
return "", err
}
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
if err != nil {
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
return "", err

View File

@@ -1,17 +1,151 @@
package config
import (
"strings"
"github.com/mudler/LocalAI/pkg/xsysinfo"
"github.com/rs/zerolog/log"
gguf "github.com/gpustack/gguf-parser-go"
)
type familyType uint8
const (
Unknown familyType = iota
LLaMa3
CommandR
Phi3
ChatML
Mistral03
Gemma
DeepSeek2
)
const (
defaultContextSize = 1024
defaultNGPULayers = 99999999
)
type settingsConfig struct {
StopWords []string
TemplateConfig TemplateConfig
RepeatPenalty float64
}
// default settings to adopt with a given model family
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
Gemma: {
RepeatPenalty: 1.0,
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
TemplateConfig: TemplateConfig{
Chat: "{{.Input }}\n<start_of_turn>model\n",
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
Completion: "{{.Input}}",
},
},
DeepSeek2: {
StopWords: []string{"<end▁of▁sentence>"},
TemplateConfig: TemplateConfig{
ChatMessage: `{{if eq .RoleName "user" -}}User: {{.Content }}
{{ end -}}
{{if eq .RoleName "assistant" -}}Assistant: {{.Content}}<end▁of▁sentence>{{end}}
{{if eq .RoleName "system" -}}{{.Content}}
{{end -}}`,
Chat: "{{.Input -}}\nAssistant: ",
},
},
LLaMa3: {
StopWords: []string{"<|eot_id|>"},
TemplateConfig: TemplateConfig{
Chat: "<|begin_of_text|>{{.Input }}\n<|start_header_id|>assistant<|end_header_id|>",
ChatMessage: "<|start_header_id|>{{ .RoleName }}<|end_header_id|>\n\n{{.Content }}<|eot_id|>",
},
},
CommandR: {
TemplateConfig: TemplateConfig{
Chat: "{{.Input -}}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>",
Functions: `<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
You are a function calling AI model, you can call the following functions:
## Available Tools
{{range .Functions}}
- {"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }}
{{end}}
When using a tool, reply with JSON, for instance {"name": "tool_name", "arguments": {"param1": "value1", "param2": "value2"}}
<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Input -}}`,
ChatMessage: `{{if eq .RoleName "user" -}}
<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if eq .RoleName "system" -}}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if eq .RoleName "assistant" -}}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if eq .RoleName "tool" -}}
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{.Content}}<|END_OF_TURN_TOKEN|>
{{- else if .FunctionCall -}}
<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{{toJson .FunctionCall}}}<|END_OF_TURN_TOKEN|>
{{- end -}}`,
},
StopWords: []string{"<|END_OF_TURN_TOKEN|>"},
},
Phi3: {
TemplateConfig: TemplateConfig{
Chat: "{{.Input}}\n<|assistant|>",
ChatMessage: "<|{{ .RoleName }}|>\n{{.Content}}<|end|>",
Completion: "{{.Input}}",
},
StopWords: []string{"<|end|>", "<|endoftext|>"},
},
ChatML: {
TemplateConfig: TemplateConfig{
Chat: "{{.Input -}}\n<|im_start|>assistant",
Functions: `<|im_start|>system
You are a function calling AI model. You are provided with functions to execute. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools:
{{range .Functions}}
{'type': 'function', 'function': {'name': '{{.Name}}', 'description': '{{.Description}}', 'parameters': {{toJson .Parameters}} }}
{{end}}
For each function call return a json object with function name and arguments
<|im_end|>
{{.Input -}}
<|im_start|>assistant`,
ChatMessage: `<|im_start|>{{ .RoleName }}
{{ if .FunctionCall -}}
Function call:
{{ else if eq .RoleName "tool" -}}
Function response:
{{ end -}}
{{ if .Content -}}
{{.Content }}
{{ end -}}
{{ if .FunctionCall -}}
{{toJson .FunctionCall}}
{{ end -}}<|im_end|>`,
},
StopWords: []string{"<|im_end|>", "<dummy32000>", "</s>"},
},
Mistral03: {
TemplateConfig: TemplateConfig{
Chat: "{{.Input -}}",
Functions: `[AVAILABLE_TOOLS] [{{range .Functions}}{"type": "function", "function": {"name": "{{.Name}}", "description": "{{.Description}}", "parameters": {{toJson .Parameters}} }}{{end}} ] [/AVAILABLE_TOOLS]{{.Input }}`,
ChatMessage: `{{if eq .RoleName "user" -}}
[INST] {{.Content }} [/INST]
{{- else if .FunctionCall -}}
[TOOL_CALLS] {{toJson .FunctionCall}} [/TOOL_CALLS]
{{- else if eq .RoleName "tool" -}}
[TOOL_RESULTS] {{.Content}} [/TOOL_RESULTS]
{{- else -}}
{{ .Content -}}
{{ end -}}`,
},
StopWords: []string{"<|im_end|>", "<dummy32000>", "</tool_call>", "<|eot_id|>", "<|end_of_text|>", "</s>", "[/TOOL_CALLS]", "[/ACTIONS]"},
},
}
// this maps well known template used in HF to model families defined above
var knownTemplates = map[string]familyType{
`{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\n' + content + '<|im_end|>\n<|im_start|>assistant\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\n' }}{% endif %}{% endfor %}`: ChatML,
`{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}`: Mistral03,
}
func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
if defaultCtx == 0 && cfg.ContextSize == nil {
@@ -82,9 +216,81 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
cfg.Name = f.Metadata().Name
}
// Instruct to use template from llama.cpp
cfg.TemplateConfig.UseTokenizerTemplate = true
cfg.FunctionsConfig.GrammarConfig.NoGrammar = true
cfg.Options = append(cfg.Options, "use_jinja:true")
cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT")
family := identifyFamily(f)
if family == Unknown {
log.Debug().Msgf("guessDefaultsFromFile: %s", "family not identified")
return
}
// identify template
settings, ok := defaultsSettings[family]
if ok {
cfg.TemplateConfig = settings.TemplateConfig
log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: guessed template %+v", cfg.TemplateConfig)
if len(cfg.StopWords) == 0 {
cfg.StopWords = settings.StopWords
}
if cfg.RepeatPenalty == 0.0 {
cfg.RepeatPenalty = settings.RepeatPenalty
}
} else {
log.Debug().Any("family", family).Msgf("guessDefaultsFromFile: no template found for family")
}
if cfg.HasTemplate() {
return
}
// identify from well known templates first, otherwise use the raw jinja template
chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
if found {
// try to use the jinja template
cfg.TemplateConfig.JinjaTemplate = true
cfg.TemplateConfig.ChatMessage = chatTemplate.ValueString()
}
}
func identifyFamily(f *gguf.GGUFFile) familyType {
// identify from well known templates first
chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template")
if found && chatTemplate.ValueString() != "" {
if family, ok := knownTemplates[chatTemplate.ValueString()]; ok {
return family
}
}
// otherwise try to identify from the model properties
arch := f.Architecture().Architecture
eosTokenID := f.Tokenizer().EOSTokenID
bosTokenID := f.Tokenizer().BOSTokenID
isYI := arch == "llama" && bosTokenID == 1 && eosTokenID == 2
// WTF! Mistral0.3 and isYi have same bosTokenID and eosTokenID
llama3 := arch == "llama" && eosTokenID == 128009
commandR := arch == "command-r" && eosTokenID == 255001
qwen2 := arch == "qwen2"
phi3 := arch == "phi-3"
gemma := strings.HasPrefix(arch, "gemma") || strings.Contains(strings.ToLower(f.Metadata().Name), "gemma")
deepseek2 := arch == "deepseek2"
switch {
case deepseek2:
return DeepSeek2
case gemma:
return Gemma
case llama3:
return LLaMa3
case commandR:
return CommandR
case phi3:
return Phi3
case qwen2, isYI:
return ChatML
default:
return Unknown
}
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/cogito"
"gopkg.in/yaml.v3"
)
@@ -17,31 +16,30 @@ const (
RAND_SEED = -1
)
// @Description TTS configuration
type TTSConfig struct {
// Voice wav path or id
Voice string `yaml:"voice,omitempty" json:"voice,omitempty"`
Voice string `yaml:"voice" json:"voice"`
AudioPath string `yaml:"audio_path,omitempty" json:"audio_path,omitempty"`
AudioPath string `yaml:"audio_path" json:"audio_path"`
}
// @Description ModelConfig represents a model configuration
// ModelConfig represents a model configuration
type ModelConfig struct {
modelConfigFile string `yaml:"-" json:"-"`
schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"`
Name string `yaml:"name,omitempty" json:"name,omitempty"`
schema.PredictionOptions `yaml:"parameters" json:"parameters"`
Name string `yaml:"name" json:"name"`
F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"`
Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"`
Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"`
Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"`
Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"`
Backend string `yaml:"backend,omitempty" json:"backend,omitempty"`
TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"`
KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"`
F16 *bool `yaml:"f16" json:"f16"`
Threads *int `yaml:"threads" json:"threads"`
Debug *bool `yaml:"debug" json:"debug"`
Roles map[string]string `yaml:"roles" json:"roles"`
Embeddings *bool `yaml:"embeddings" json:"embeddings"`
Backend string `yaml:"backend" json:"backend"`
TemplateConfig TemplateConfig `yaml:"template" json:"template"`
KnownUsecaseStrings []string `yaml:"known_usecases" json:"known_usecases"`
KnownUsecases *ModelConfigUsecases `yaml:"-" json:"-"`
Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`
Pipeline Pipeline `yaml:"pipeline" json:"pipeline"`
PromptStrings, InputStrings []string `yaml:"-" json:"-"`
InputToken [][]int `yaml:"-" json:"-"`
@@ -49,101 +47,96 @@ type ModelConfig struct {
ResponseFormat string `yaml:"-" json:"-"`
ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`
FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
FunctionsConfig functions.FunctionsConfig `yaml:"function" json:"function"`
FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
FeatureFlag FeatureFlag `yaml:"feature_flags" json:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
// LLM configs (GPT4ALL, Llama.cpp, ...)
LLMConfig `yaml:",inline" json:",inline"`
// Diffusers
Diffusers Diffusers `yaml:"diffusers,omitempty" json:"diffusers,omitempty"`
Step int `yaml:"step,omitempty" json:"step,omitempty"`
Diffusers Diffusers `yaml:"diffusers" json:"diffusers"`
Step int `yaml:"step" json:"step"`
// GRPC Options
GRPC GRPC `yaml:"grpc,omitempty" json:"grpc,omitempty"`
GRPC GRPC `yaml:"grpc" json:"grpc"`
// TTS specifics
TTSConfig `yaml:"tts,omitempty" json:"tts,omitempty"`
TTSConfig `yaml:"tts" json:"tts"`
// CUDA
// Explicitly enable CUDA or not (some backends might need it)
CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`
CUDA bool `yaml:"cuda" json:"cuda"`
DownloadFiles []File `yaml:"download_files,omitempty" json:"download_files,omitempty"`
DownloadFiles []File `yaml:"download_files" json:"download_files"`
Description string `yaml:"description,omitempty" json:"description,omitempty"`
Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`
Description string `yaml:"description" json:"description"`
Usage string `yaml:"usage" json:"usage"`
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
Options []string `yaml:"options" json:"options"`
Overrides []string `yaml:"overrides" json:"overrides"`
MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
MCP MCPConfig `yaml:"mcp" json:"mcp"`
Agent AgentConfig `yaml:"agent" json:"agent"`
}
// @Description MCP configuration
type MCPConfig struct {
Servers string `yaml:"remote,omitempty" json:"remote,omitempty"`
Stdio string `yaml:"stdio,omitempty" json:"stdio,omitempty"`
Servers string `yaml:"remote" json:"remote"`
Stdio string `yaml:"stdio" json:"stdio"`
}
// @Description Agent configuration
type AgentConfig struct {
MaxAttempts int `yaml:"max_attempts,omitempty" json:"max_attempts,omitempty"`
MaxIterations int `yaml:"max_iterations,omitempty" json:"max_iterations,omitempty"`
EnableReasoning bool `yaml:"enable_reasoning,omitempty" json:"enable_reasoning,omitempty"`
EnablePlanning bool `yaml:"enable_planning,omitempty" json:"enable_planning,omitempty"`
EnableMCPPrompts bool `yaml:"enable_mcp_prompts,omitempty" json:"enable_mcp_prompts,omitempty"`
EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator,omitempty" json:"enable_plan_re_evaluator,omitempty"`
MaxAttempts int `yaml:"max_attempts" json:"max_attempts"`
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
EnableReasoning bool `yaml:"enable_reasoning" json:"enable_reasoning"`
EnablePlanning bool `yaml:"enable_planning" json:"enable_planning"`
EnableMCPPrompts bool `yaml:"enable_mcp_prompts" json:"enable_mcp_prompts"`
EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator" json:"enable_plan_re_evaluator"`
}
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) {
func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers]) {
var remote MCPGenericConfig[MCPRemoteServers]
var stdio MCPGenericConfig[MCPSTDIOServers]
if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil {
return remote, stdio, err
return remote, stdio
}
if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil {
return remote, stdio, err
return remote, stdio
}
return remote, stdio, nil
return remote, stdio
}
// @Description MCP generic configuration
type MCPGenericConfig[T any] struct {
Servers T `yaml:"mcpServers,omitempty" json:"mcpServers,omitempty"`
Servers T `yaml:"mcpServers" json:"mcpServers"`
}
type MCPRemoteServers map[string]MCPRemoteServer
type MCPSTDIOServers map[string]MCPSTDIOServer
// @Description MCP remote server configuration
type MCPRemoteServer struct {
URL string `json:"url,omitempty"`
Token string `json:"token,omitempty"`
URL string `json:"url"`
Token string `json:"token"`
}
// @Description MCP STDIO server configuration
type MCPSTDIOServer struct {
Args []string `json:"args,omitempty"`
Env map[string]string `json:"env,omitempty"`
Command string `json:"command,omitempty"`
Args []string `json:"args"`
Env map[string]string `json:"env"`
Command string `json:"command"`
}
// @Description Pipeline defines other models to use for audio-to-audio
// Pipeline defines other models to use for audio-to-audio
type Pipeline struct {
TTS string `yaml:"tts,omitempty" json:"tts,omitempty"`
LLM string `yaml:"llm,omitempty" json:"llm,omitempty"`
Transcription string `yaml:"transcription,omitempty" json:"transcription,omitempty"`
VAD string `yaml:"vad,omitempty" json:"vad,omitempty"`
TTS string `yaml:"tts" json:"tts"`
LLM string `yaml:"llm" json:"llm"`
Transcription string `yaml:"transcription" json:"transcription"`
VAD string `yaml:"vad" json:"vad"`
}
// @Description File configuration for model downloads
type File struct {
Filename string `yaml:"filename,omitempty" json:"filename,omitempty"`
SHA256 string `yaml:"sha256,omitempty" json:"sha256,omitempty"`
URI downloader.URI `yaml:"uri,omitempty" json:"uri,omitempty"`
Filename string `yaml:"filename" json:"filename"`
SHA256 string `yaml:"sha256" json:"sha256"`
URI downloader.URI `yaml:"uri" json:"uri"`
}
type FeatureFlag map[string]*bool
@@ -155,136 +148,126 @@ func (ff FeatureFlag) Enabled(s string) bool {
return false
}
// @Description GRPC configuration
type GRPC struct {
Attempts int `yaml:"attempts,omitempty" json:"attempts,omitempty"`
AttemptsSleepTime int `yaml:"attempts_sleep_time,omitempty" json:"attempts_sleep_time,omitempty"`
Attempts int `yaml:"attempts" json:"attempts"`
AttemptsSleepTime int `yaml:"attempts_sleep_time" json:"attempts_sleep_time"`
}
// @Description Diffusers configuration
type Diffusers struct {
CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`
PipelineType string `yaml:"pipeline_type,omitempty" json:"pipeline_type,omitempty"`
SchedulerType string `yaml:"scheduler_type,omitempty" json:"scheduler_type,omitempty"`
EnableParameters string `yaml:"enable_parameters,omitempty" json:"enable_parameters,omitempty"` // A list of comma separated parameters to specify
IMG2IMG bool `yaml:"img2img,omitempty" json:"img2img,omitempty"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip,omitempty" json:"clip_skip,omitempty"` // Skip every N frames
ClipModel string `yaml:"clip_model,omitempty" json:"clip_model,omitempty"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder,omitempty" json:"clip_subfolder,omitempty"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net,omitempty" json:"control_net,omitempty"`
CUDA bool `yaml:"cuda" json:"cuda"`
PipelineType string `yaml:"pipeline_type" json:"pipeline_type"`
SchedulerType string `yaml:"scheduler_type" json:"scheduler_type"`
EnableParameters string `yaml:"enable_parameters" json:"enable_parameters"` // A list of comma separated parameters to specify
IMG2IMG bool `yaml:"img2img" json:"img2img"` // Image to Image Diffuser
ClipSkip int `yaml:"clip_skip" json:"clip_skip"` // Skip every N frames
ClipModel string `yaml:"clip_model" json:"clip_model"` // Clip model to use
ClipSubFolder string `yaml:"clip_subfolder" json:"clip_subfolder"` // Subfolder to use for clip model
ControlNet string `yaml:"control_net" json:"control_net"`
}
// @Description LLMConfig is a struct that holds the configuration that are generic for most of the LLM backends.
// LLMConfig is a struct that holds the configuration that are
// generic for most of the LLM backends.
type LLMConfig struct {
SystemPrompt string `yaml:"system_prompt,omitempty" json:"system_prompt,omitempty"`
TensorSplit string `yaml:"tensor_split,omitempty" json:"tensor_split,omitempty"`
MainGPU string `yaml:"main_gpu,omitempty" json:"main_gpu,omitempty"`
RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
Mirostat *int `yaml:"mirostat,omitempty" json:"mirostat,omitempty"`
NGPULayers *int `yaml:"gpu_layers,omitempty" json:"gpu_layers,omitempty"`
MMap *bool `yaml:"mmap,omitempty" json:"mmap,omitempty"`
MMlock *bool `yaml:"mmlock,omitempty" json:"mmlock,omitempty"`
LowVRAM *bool `yaml:"low_vram,omitempty" json:"low_vram,omitempty"`
Reranking *bool `yaml:"reranking,omitempty" json:"reranking,omitempty"`
Grammar string `yaml:"grammar,omitempty" json:"grammar,omitempty"`
StopWords []string `yaml:"stopwords,omitempty" json:"stopwords,omitempty"`
Cutstrings []string `yaml:"cutstrings,omitempty" json:"cutstrings,omitempty"`
ExtractRegex []string `yaml:"extract_regex,omitempty" json:"extract_regex,omitempty"`
TrimSpace []string `yaml:"trimspace,omitempty" json:"trimspace,omitempty"`
TrimSuffix []string `yaml:"trimsuffix,omitempty" json:"trimsuffix,omitempty"`
SystemPrompt string `yaml:"system_prompt" json:"system_prompt"`
TensorSplit string `yaml:"tensor_split" json:"tensor_split"`
MainGPU string `yaml:"main_gpu" json:"main_gpu"`
RMSNormEps float32 `yaml:"rms_norm_eps" json:"rms_norm_eps"`
NGQA int32 `yaml:"ngqa" json:"ngqa"`
PromptCachePath string `yaml:"prompt_cache_path" json:"prompt_cache_path"`
PromptCacheAll bool `yaml:"prompt_cache_all" json:"prompt_cache_all"`
PromptCacheRO bool `yaml:"prompt_cache_ro" json:"prompt_cache_ro"`
MirostatETA *float64 `yaml:"mirostat_eta" json:"mirostat_eta"`
MirostatTAU *float64 `yaml:"mirostat_tau" json:"mirostat_tau"`
Mirostat *int `yaml:"mirostat" json:"mirostat"`
NGPULayers *int `yaml:"gpu_layers" json:"gpu_layers"`
MMap *bool `yaml:"mmap" json:"mmap"`
MMlock *bool `yaml:"mmlock" json:"mmlock"`
LowVRAM *bool `yaml:"low_vram" json:"low_vram"`
Reranking *bool `yaml:"reranking" json:"reranking"`
Grammar string `yaml:"grammar" json:"grammar"`
StopWords []string `yaml:"stopwords" json:"stopwords"`
Cutstrings []string `yaml:"cutstrings" json:"cutstrings"`
ExtractRegex []string `yaml:"extract_regex" json:"extract_regex"`
TrimSpace []string `yaml:"trimspace" json:"trimspace"`
TrimSuffix []string `yaml:"trimsuffix" json:"trimsuffix"`
ContextSize *int `yaml:"context_size,omitempty" json:"context_size,omitempty"`
NUMA bool `yaml:"numa,omitempty" json:"numa,omitempty"`
LoraAdapter string `yaml:"lora_adapter,omitempty" json:"lora_adapter,omitempty"`
LoraBase string `yaml:"lora_base,omitempty" json:"lora_base,omitempty"`
LoraAdapters []string `yaml:"lora_adapters,omitempty" json:"lora_adapters,omitempty"`
LoraScales []float32 `yaml:"lora_scales,omitempty" json:"lora_scales,omitempty"`
LoraScale float32 `yaml:"lora_scale,omitempty" json:"lora_scale,omitempty"`
NoMulMatQ bool `yaml:"no_mulmatq,omitempty" json:"no_mulmatq,omitempty"`
DraftModel string `yaml:"draft_model,omitempty" json:"draft_model,omitempty"`
NDraft int32 `yaml:"n_draft,omitempty" json:"n_draft,omitempty"`
Quantization string `yaml:"quantization,omitempty" json:"quantization,omitempty"`
LoadFormat string `yaml:"load_format,omitempty" json:"load_format,omitempty"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization,omitempty" json:"gpu_memory_utilization,omitempty"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code,omitempty" json:"trust_remote_code,omitempty"` // vLLM
EnforceEager bool `yaml:"enforce_eager,omitempty" json:"enforce_eager,omitempty"` // vLLM
SwapSpace int `yaml:"swap_space,omitempty" json:"swap_space,omitempty"` // vLLM
MaxModelLen int `yaml:"max_model_len,omitempty" json:"max_model_len,omitempty"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size,omitempty" json:"tensor_parallel_size,omitempty"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM
DType string `yaml:"dtype,omitempty" json:"dtype,omitempty"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM
MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`
ContextSize *int `yaml:"context_size" json:"context_size"`
NUMA bool `yaml:"numa" json:"numa"`
LoraAdapter string `yaml:"lora_adapter" json:"lora_adapter"`
LoraBase string `yaml:"lora_base" json:"lora_base"`
LoraAdapters []string `yaml:"lora_adapters" json:"lora_adapters"`
LoraScales []float32 `yaml:"lora_scales" json:"lora_scales"`
LoraScale float32 `yaml:"lora_scale" json:"lora_scale"`
NoMulMatQ bool `yaml:"no_mulmatq" json:"no_mulmatq"`
DraftModel string `yaml:"draft_model" json:"draft_model"`
NDraft int32 `yaml:"n_draft" json:"n_draft"`
Quantization string `yaml:"quantization" json:"quantization"`
LoadFormat string `yaml:"load_format" json:"load_format"`
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization" json:"gpu_memory_utilization"` // vLLM
TrustRemoteCode bool `yaml:"trust_remote_code" json:"trust_remote_code"` // vLLM
EnforceEager bool `yaml:"enforce_eager" json:"enforce_eager"` // vLLM
SwapSpace int `yaml:"swap_space" json:"swap_space"` // vLLM
MaxModelLen int `yaml:"max_model_len" json:"max_model_len"` // vLLM
TensorParallelSize int `yaml:"tensor_parallel_size" json:"tensor_parallel_size"` // vLLM
DisableLogStatus bool `yaml:"disable_log_stats" json:"disable_log_stats"` // vLLM
DType string `yaml:"dtype" json:"dtype"` // vLLM
LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt" json:"limit_mm_per_prompt"` // vLLM
MMProj string `yaml:"mmproj" json:"mmproj"`
FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"`
NoKVOffloading bool `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"`
CacheTypeK string `yaml:"cache_type_k,omitempty" json:"cache_type_k,omitempty"`
CacheTypeV string `yaml:"cache_type_v,omitempty" json:"cache_type_v,omitempty"`
FlashAttention *string `yaml:"flash_attention" json:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading" json:"no_kv_offloading"`
CacheTypeK string `yaml:"cache_type_k" json:"cache_type_k"`
CacheTypeV string `yaml:"cache_type_v" json:"cache_type_v"`
RopeScaling string `yaml:"rope_scaling,omitempty" json:"rope_scaling,omitempty"`
ModelType string `yaml:"type,omitempty" json:"type,omitempty"`
RopeScaling string `yaml:"rope_scaling" json:"rope_scaling"`
ModelType string `yaml:"type" json:"type"`
YarnExtFactor float32 `yaml:"yarn_ext_factor,omitempty" json:"yarn_ext_factor,omitempty"`
YarnAttnFactor float32 `yaml:"yarn_attn_factor,omitempty" json:"yarn_attn_factor,omitempty"`
YarnBetaFast float32 `yaml:"yarn_beta_fast,omitempty" json:"yarn_beta_fast,omitempty"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow,omitempty" json:"yarn_beta_slow,omitempty"`
YarnExtFactor float32 `yaml:"yarn_ext_factor" json:"yarn_ext_factor"`
YarnAttnFactor float32 `yaml:"yarn_attn_factor" json:"yarn_attn_factor"`
YarnBetaFast float32 `yaml:"yarn_beta_fast" json:"yarn_beta_fast"`
YarnBetaSlow float32 `yaml:"yarn_beta_slow" json:"yarn_beta_slow"`
CFGScale float32 `yaml:"cfg_scale,omitempty" json:"cfg_scale,omitempty"` // Classifier-Free Guidance Scale
CFGScale float32 `yaml:"cfg_scale" json:"cfg_scale"` // Classifier-Free Guidance Scale
}
// @Description LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
// LimitMMPerPrompt is a struct that holds the configuration for the limit-mm-per-prompt config in vLLM
type LimitMMPerPrompt struct {
LimitImagePerPrompt int `yaml:"image,omitempty" json:"image,omitempty"`
LimitVideoPerPrompt int `yaml:"video,omitempty" json:"video,omitempty"`
LimitAudioPerPrompt int `yaml:"audio,omitempty" json:"audio,omitempty"`
LimitImagePerPrompt int `yaml:"image" json:"image"`
LimitVideoPerPrompt int `yaml:"video" json:"video"`
LimitAudioPerPrompt int `yaml:"audio" json:"audio"`
}
// @Description TemplateConfig is a struct that holds the configuration of the templating system
// TemplateConfig is a struct that holds the configuration of the templating system
type TemplateConfig struct {
// Chat is the template used in the chat completion endpoint
Chat string `yaml:"chat,omitempty" json:"chat,omitempty"`
Chat string `yaml:"chat" json:"chat"`
// ChatMessage is the template used for chat messages
ChatMessage string `yaml:"chat_message,omitempty" json:"chat_message,omitempty"`
ChatMessage string `yaml:"chat_message" json:"chat_message"`
// Completion is the template used for completion requests
Completion string `yaml:"completion,omitempty" json:"completion,omitempty"`
Completion string `yaml:"completion" json:"completion"`
// Edit is the template used for edit completion requests
Edit string `yaml:"edit,omitempty" json:"edit,omitempty"`
Edit string `yaml:"edit" json:"edit"`
// Functions is the template used when tools are present in the client requests
Functions string `yaml:"function,omitempty" json:"function,omitempty"`
Functions string `yaml:"function" json:"function"`
// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
// Note: this is mostly consumed for backends such as vllm and transformers
// that can use the tokenizers specified in the JSON config files of the models
UseTokenizerTemplate bool `yaml:"use_tokenizer_template,omitempty" json:"use_tokenizer_template,omitempty"`
UseTokenizerTemplate bool `yaml:"use_tokenizer_template" json:"use_tokenizer_template"`
// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
// It defaults to \n
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character,omitempty" json:"join_chat_messages_by_character,omitempty"`
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character" json:"join_chat_messages_by_character"`
Multimodal string `yaml:"multimodal,omitempty" json:"multimodal,omitempty"`
Multimodal string `yaml:"multimodal" json:"multimodal"`
ReplyPrefix string `yaml:"reply_prefix,omitempty" json:"reply_prefix,omitempty"`
}
JinjaTemplate bool `yaml:"jinja_template" json:"jinja_template"`
func (c *ModelConfig) syncKnownUsecasesFromString() {
c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
// Make sure the usecases are valid, we rewrite with what we identified
c.KnownUsecaseStrings = []string{}
for k, usecase := range GetAllModelConfigUsecases() {
if c.HasUsecases(usecase) {
c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
}
}
ReplyPrefix string `yaml:"reply_prefix" json:"reply_prefix"`
}
func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
@@ -295,7 +278,14 @@ func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error {
}
*c = ModelConfig(aux)
c.syncKnownUsecasesFromString()
c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
// Make sure the usecases are valid, we rewrite with what we identified
c.KnownUsecaseStrings = []string{}
for k, usecase := range GetAllModelConfigUsecases() {
if c.HasUsecases(usecase) {
c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k)
}
}
return nil
}
@@ -472,7 +462,6 @@ func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) {
}
guessDefaultsFromFile(cfg, lo.modelPath, ctx)
cfg.syncKnownUsecasesFromString()
}
func (c *ModelConfig) Validate() bool {
@@ -503,7 +492,7 @@ func (c *ModelConfig) Validate() bool {
}
func (c *ModelConfig) HasTemplate() bool {
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate
return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
}
func (c *ModelConfig) GetModelConfigFile() string {
@@ -584,7 +573,7 @@ func (c *ModelConfig) HasUsecases(u ModelConfigUsecases) bool {
// This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently.
func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
if (u & FLAG_CHAT) == FLAG_CHAT {
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate {
if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
return false
}
}
@@ -669,40 +658,3 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecases) bool {
return true
}
// BuildCogitoOptions generates cogito options from the model configuration
// It accepts a context, MCP sessions, and optional callback functions for status, reasoning, tool calls, and tool results
func (c *ModelConfig) BuildCogitoOptions() []cogito.Option {
cogitoOpts := []cogito.Option{
cogito.WithIterations(3), // default to 3 iterations
cogito.WithMaxAttempts(3), // default to 3 attempts
cogito.WithForceReasoning(),
}
// Apply agent configuration options
if c.Agent.EnableReasoning {
cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
}
if c.Agent.EnablePlanning {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
}
if c.Agent.EnableMCPPrompts {
cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
}
if c.Agent.EnablePlanReEvaluator {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
}
if c.Agent.MaxIterations != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations))
}
if c.Agent.MaxAttempts != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts))
}
return cogitoOpts
}

View File

@@ -3,9 +3,7 @@
package gallery
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
@@ -70,7 +68,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
}
// InstallBackendFromGallery installs a backend from the gallery.
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
if !force {
// check if we already have the backend installed
backends, err := ListSystemBackends(systemState)
@@ -110,7 +108,7 @@ func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery,
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
// Then, let's install the best backend
if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
return err
}
@@ -135,10 +133,10 @@ func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery,
return nil
}
return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
return InstallBackend(systemState, modelLoader, backend, downloadStatus)
}
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
if err != nil {
@@ -165,17 +163,11 @@ func InstallBackend(ctx context.Context, systemState *system.SystemState, modelL
}
} else {
uri := downloader.URI(config.URI)
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
success := false
// Try to download from mirrors
for _, mirror := range config.Mirrors {
// Check for cancellation before trying next mirror
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
success = true
break
}
@@ -318,10 +310,8 @@ func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error)
}
}
}
} else if !errors.Is(err, os.ErrNotExist) {
} else {
log.Warn().Err(err).Msg("Failed to read system backends, proceeding with user-managed backends")
} else if errors.Is(err, os.ErrNotExist) {
log.Debug().Msg("No system backends found")
}
// User-managed backends and alias collection

View File

@@ -1,7 +1,6 @@
package gallery
import (
"context"
"encoding/json"
"os"
"path/filepath"
@@ -56,7 +55,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
)
must(err)
sysDefault.GPUVendor = "" // force default selection
backs, err := ListSystemBackends(sysDefault)
backs, err := ListSystemBackends(sysDefault)
must(err)
aliasBack, ok := backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
@@ -78,7 +77,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
must(err)
sysNvidia.GPUVendor = "nvidia"
sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
backs, err = ListSystemBackends(sysNvidia)
backs, err = ListSystemBackends(sysNvidia)
must(err)
aliasBack, ok = backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
@@ -117,13 +116,13 @@ var _ = Describe("Gallery Backends", func() {
Describe("InstallBackendFromGallery", func() {
It("should return error when backend is not found", func() {
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
})
It("should install backend from gallery", func() {
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
})
@@ -299,7 +298,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -379,7 +378,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -463,7 +462,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -562,7 +561,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(newPath),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
err = InstallBackend(systemState, ml, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory())
})
@@ -594,7 +593,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
err = InstallBackend(systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
@@ -627,7 +626,7 @@ var _ = Describe("Gallery Backends", func() {
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
err = InstallBackend(systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
})
@@ -648,7 +647,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
err = InstallBackend(systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())

91
core/gallery/btree.go Normal file
View File

@@ -0,0 +1,91 @@
package gallery
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/google/btree"
"gopkg.in/yaml.v3"
)
type Gallery[T GalleryElement] struct {
*btree.BTreeG[T]
}
func less[T GalleryElement](a, b T) bool {
return a.GetName() < b.GetName()
}
func (g *Gallery[T]) Search(term string) GalleryElements[T] {
var filteredModels GalleryElements[T]
g.Ascend(func(item T) bool {
if strings.Contains(item.GetName(), term) ||
strings.Contains(item.GetDescription(), term) ||
strings.Contains(item.GetGallery().Name, term) ||
strings.Contains(strings.Join(item.GetTags(), ","), term) {
filteredModels = append(filteredModels, item)
}
return true
})
return filteredModels
}
// processYAMLFile takes a single file path and adds its models to the existing tree.
func processYAMLFile[T GalleryElement](filePath string, tree *Gallery[T]) error {
file, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("could not open %s: %w", filePath, err)
}
defer file.Close()
decoder := yaml.NewDecoder(file)
// Stream documents from the file
for {
var model T
err := decoder.Decode(&model)
if err == io.EOF {
break // End of file
}
if err != nil {
return fmt.Errorf("error decoding %s: %w", filePath, err)
}
tree.ReplaceOrInsert(model)
}
return nil
}
// loadModelsFromDirectory scans a directory and processes all .yaml/.yml files.
func LoadGalleryFromDirectory[T GalleryElement](dirPath string) (*Gallery[T], error) {
tree := &Gallery[T]{btree.NewG[T](2, less[T])}
entries, err := os.ReadDir(dirPath)
if err != nil {
return nil, fmt.Errorf("could not read directory: %w", err)
}
for _, entry := range entries {
// Skip subdirectories and non-YAML files.
if entry.IsDir() {
continue
}
fileName := entry.Name()
if !strings.HasSuffix(fileName, ".yaml") && !strings.HasSuffix(fileName, ".yml") {
continue
}
fullPath := filepath.Join(dirPath, fileName)
err := processYAMLFile(fullPath, tree)
if err != nil {
return nil, err
}
}
return tree, nil
}

View File

@@ -1,7 +1,6 @@
package gallery
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -29,19 +28,6 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
return config, nil
}
func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
var config T
uri := downloader.URI(url)
err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
return config, err
}
return config, nil
}
func ReadConfigFile[T any](filePath string) (*T, error) {
// Read the YAML file
yamlFile, err := os.ReadFile(filePath)
@@ -75,15 +61,12 @@ func (gm GalleryElements[T]) Search(term string) GalleryElements[T] {
term = strings.ToLower(term)
for _, m := range gm {
if fuzzy.Match(term, strings.ToLower(m.GetName())) ||
fuzzy.Match(term, strings.ToLower(m.GetDescription())) ||
fuzzy.Match(term, strings.ToLower(m.GetGallery().Name)) ||
strings.Contains(strings.ToLower(m.GetName()), term) ||
strings.Contains(strings.ToLower(m.GetDescription()), term) ||
strings.Contains(strings.ToLower(m.GetGallery().Name), term) ||
strings.Contains(strings.ToLower(strings.Join(m.GetTags(), ",")), term) {
filteredModels = append(filteredModels, m)
}
}
return filteredModels
}

View File

@@ -1,67 +0,0 @@
package importers
import (
"encoding/json"
"strings"
"github.com/rs/zerolog/log"
"github.com/mudler/LocalAI/core/gallery"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
)
var defaultImporters = []Importer{
&LlamaCPPImporter{},
&MLXImporter{},
&VLLMImporter{},
&TransformersImporter{},
}
type Details struct {
HuggingFace *hfapi.ModelDetails
URI string
Preferences json.RawMessage
}
type Importer interface {
Match(details Details) bool
Import(details Details) (gallery.ModelConfig, error)
}
func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) {
var err error
var modelConfig gallery.ModelConfig
hf := hfapi.NewClient()
hfrepoID := strings.ReplaceAll(uri, "huggingface://", "")
hfrepoID = strings.ReplaceAll(hfrepoID, "hf://", "")
hfrepoID = strings.ReplaceAll(hfrepoID, "https://huggingface.co/", "")
hfDetails, err := hf.GetModelDetails(hfrepoID)
if err != nil {
// maybe not a HF repository
// TODO: maybe we can check if the URI is a valid HF repository
log.Debug().Str("uri", uri).Msg("Failed to get model details, maybe not a HF repository")
} else {
log.Debug().Str("uri", uri).Msg("Got model details")
log.Debug().Any("details", hfDetails).Msg("Model details")
}
details := Details{
HuggingFace: hfDetails,
URI: uri,
Preferences: preferences,
}
for _, importer := range defaultImporters {
if importer.Match(details) {
modelConfig, err = importer.Import(details)
if err != nil {
continue
}
break
}
}
return modelConfig, err
}

View File

@@ -1,13 +0,0 @@
package importers_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestImporters(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Importers test suite")
}

View File

@@ -1,215 +0,0 @@
package importers_test
import (
"encoding/json"
"fmt"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("DiscoverModelConfig", func() {
Context("With only a repository URI", func() {
It("should discover and import using LlamaCPPImporter", func() {
uri := "https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
Expect(modelConfig.Name).To(Equal("LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/resolve/main/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].SHA256).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"), fmt.Sprintf("Model config: %+v", modelConfig))
})
It("should discover and import using LlamaCPPImporter", func() {
uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].Filename).To(Equal("mmproj/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
})
It("should discover and import using LlamaCPPImporter", func() {
uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"
preferences := json.RawMessage(`{ "quantizations": "Q8_0", "mmproj_quantizations": "f16" }`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err))
Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].Filename).To(Equal("mmproj/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig))
})
})
Context("with .gguf URI", func() {
It("should discover and import using LlamaCPPImporter", func() {
uri := "https://example.com/my-model.gguf"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model.gguf"))
Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
})
It("should use custom preferences when provided", func() {
uri := "https://example.com/my-model.gguf"
preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-name"))
Expect(modelConfig.Description).To(Equal("Custom description"))
})
})
Context("with mlx-community URI", func() {
It("should discover and import using MLXImporter", func() {
uri := "https://huggingface.co/mlx-community/test-model"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
})
It("should use custom preferences when provided", func() {
uri := "https://huggingface.co/mlx-community/test-model"
preferences := json.RawMessage(`{"name": "custom-mlx", "description": "Custom MLX description"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-mlx"))
Expect(modelConfig.Description).To(Equal("Custom MLX description"))
})
})
Context("with backend preference", func() {
It("should use llama-cpp backend when specified", func() {
uri := "https://example.com/model"
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
})
It("should use mlx backend when specified", func() {
uri := "https://example.com/model"
preferences := json.RawMessage(`{"backend": "mlx"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
})
It("should use mlx-vlm backend when specified", func() {
uri := "https://example.com/model"
preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
})
})
Context("with HuggingFace URI formats", func() {
It("should handle huggingface:// prefix", func() {
uri := "huggingface://mlx-community/test-model"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
})
It("should handle hf:// prefix", func() {
uri := "hf://mlx-community/test-model"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
})
It("should handle https://huggingface.co/ prefix", func() {
uri := "https://huggingface.co/mlx-community/test-model"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
})
})
Context("with invalid or non-matching URI", func() {
It("should return error when no importer matches", func() {
uri := "https://example.com/unknown-model.bin"
preferences := json.RawMessage(`{}`)
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
// When no importer matches, the function returns empty config and error
// The exact behavior depends on implementation, but typically an error is returned
Expect(modelConfig.Name).To(BeEmpty())
Expect(err).To(HaveOccurred())
})
})
Context("with invalid JSON preferences", func() {
It("should return error when JSON is invalid even if URI matches", func() {
uri := "https://example.com/model.gguf"
preferences := json.RawMessage(`invalid json`)
// Even though Match() returns true for .gguf extension,
// Import() will fail when trying to unmarshal invalid JSON preferences
modelConfig, err := importers.DiscoverModelConfig(uri, preferences)
Expect(err).To(HaveOccurred())
Expect(modelConfig.Name).To(BeEmpty())
})
})
})

View File

@@ -1,209 +0,0 @@
package importers
import (
"encoding/json"
"path/filepath"
"slices"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/functions"
"go.yaml.in/yaml/v2"
)
var _ Importer = &LlamaCPPImporter{}
type LlamaCPPImporter struct{}
func (i *LlamaCPPImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
if preferencesMap["backend"] == "llama-cpp" {
return true
}
if strings.HasSuffix(details.URI, ".gguf") {
return true
}
if details.HuggingFace != nil {
for _, file := range details.HuggingFace.Files {
if strings.HasSuffix(file.Path, ".gguf") {
return true
}
}
}
return false
}
func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
preferedQuantizations, _ := preferencesMap["quantizations"].(string)
quants := []string{"q4_k_m"}
if preferedQuantizations != "" {
quants = strings.Split(preferedQuantizations, ",")
}
mmprojQuants, _ := preferencesMap["mmproj_quantizations"].(string)
mmprojQuantsList := []string{"fp16"}
if mmprojQuants != "" {
mmprojQuantsList = strings.Split(mmprojQuants, ",")
}
embeddings, _ := preferencesMap["embeddings"].(string)
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Options: []string{"use_jinja:true"},
Backend: "llama-cpp",
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
FunctionsConfig: functions.FunctionsConfig{
GrammarConfig: functions.GrammarConfig{
NoGrammar: true,
},
},
}
if embeddings != "" && strings.ToLower(embeddings) == "true" || strings.ToLower(embeddings) == "yes" {
trueV := true
modelConfig.Embeddings = &trueV
}
cfg := gallery.ModelConfig{
Name: name,
Description: description,
}
if strings.HasSuffix(details.URI, ".gguf") {
cfg.Files = append(cfg.Files, gallery.File{
URI: details.URI,
Filename: filepath.Base(details.URI),
})
modelConfig.PredictionOptions = schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: filepath.Base(details.URI),
},
}
} else if details.HuggingFace != nil {
// We want to:
// Get first the chosen quants that match filenames
// OR the first mmproj/gguf file found
var lastMMProjFile *gallery.File
var lastGGUFFile *gallery.File
foundPreferedQuant := false
foundPreferedMMprojQuant := false
for _, file := range details.HuggingFace.Files {
// Get the mmproj prefered quants
if strings.Contains(strings.ToLower(file.Path), "mmproj") {
lastMMProjFile = &gallery.File{
URI: file.URL,
Filename: filepath.Join("mmproj", filepath.Base(file.Path)),
SHA256: file.SHA256,
}
if slices.ContainsFunc(mmprojQuantsList, func(quant string) bool {
return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
}) {
cfg.Files = append(cfg.Files, *lastMMProjFile)
foundPreferedMMprojQuant = true
}
} else if strings.HasSuffix(strings.ToLower(file.Path), "gguf") {
lastGGUFFile = &gallery.File{
URI: file.URL,
Filename: filepath.Base(file.Path),
SHA256: file.SHA256,
}
// get the files of the prefered quants
if slices.ContainsFunc(quants, func(quant string) bool {
return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant))
}) {
foundPreferedQuant = true
cfg.Files = append(cfg.Files, *lastGGUFFile)
}
}
}
// Make sure to add at least one file if not already present (which is the latest one)
if lastMMProjFile != nil && !foundPreferedMMprojQuant {
if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
return f.Filename == lastMMProjFile.Filename
}) {
cfg.Files = append(cfg.Files, *lastMMProjFile)
}
}
if lastGGUFFile != nil && !foundPreferedQuant {
if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool {
return f.Filename == lastGGUFFile.Filename
}) {
cfg.Files = append(cfg.Files, *lastGGUFFile)
}
}
// Find first mmproj file and configure it in the config file
for _, file := range cfg.Files {
if !strings.Contains(strings.ToLower(file.Filename), "mmproj") {
continue
}
modelConfig.MMProj = file.Filename
break
}
// Find first non-mmproj file and configure it in the config file
for _, file := range cfg.Files {
if strings.Contains(strings.ToLower(file.Filename), "mmproj") {
continue
}
modelConfig.PredictionOptions = schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: file.Filename,
},
}
break
}
}
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
cfg.ConfigFile = string(data)
return cfg, nil
}

View File

@@ -1,132 +0,0 @@
package importers_test
import (
"encoding/json"
"fmt"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("LlamaCPPImporter", func() {
var importer *LlamaCPPImporter
BeforeEach(func() {
importer = &LlamaCPPImporter{}
})
Context("Match", func() {
It("should match when URI ends with .gguf", func() {
details := Details{
URI: "https://example.com/model.gguf",
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when backend preference is llama-cpp", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI does not end with .gguf and no backend preference", func() {
details := Details{
URI: "https://example.com/model.bin",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "mlx"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/model.gguf",
Preferences: preferences,
}
// Invalid JSON causes Match to return false early
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := Details{
URI: "https://example.com/my-model.gguf",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model.gguf"))
Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"))
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := Details{
URI: "https://example.com/my-model.gguf",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-model"))
Expect(modelConfig.Description).To(Equal("Custom description"))
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/my-model.gguf",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://example.com/path/to/model.gguf",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/path/to/model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
Expect(modelConfig.Files[0].Filename).To(Equal("model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig))
})
})
})

View File

@@ -1,94 +0,0 @@
package importers
import (
"encoding/json"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"go.yaml.in/yaml/v2"
)
var _ Importer = &MLXImporter{}
type MLXImporter struct{}
func (i *MLXImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
b, ok := preferencesMap["backend"].(string)
if ok && b == "mlx" || b == "mlx-vlm" {
return true
}
// All https://huggingface.co/mlx-community/*
if strings.Contains(details.URI, "mlx-community/") {
return true
}
return false
}
func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
backend := "mlx"
b, ok := preferencesMap["backend"].(string)
if ok {
backend = b
}
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Backend: backend,
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: details.URI,
},
},
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
}
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
return gallery.ModelConfig{
Name: name,
Description: description,
ConfigFile: string(data),
}, nil
}

View File

@@ -1,147 +0,0 @@
package importers_test
import (
"encoding/json"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("MLXImporter", func() {
var importer *importers.MLXImporter
BeforeEach(func() {
importer = &importers.MLXImporter{}
})
Context("Match", func() {
It("should match when URI contains mlx-community/", func() {
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when backend preference is mlx", func() {
preferences := json.RawMessage(`{"backend": "mlx"}`)
details := importers.Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when backend preference is mlx-vlm", func() {
preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
details := importers.Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI does not contain mlx-community/ and no backend preference", func() {
details := importers.Details{
URI: "https://huggingface.co/other-org/test-model",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := importers.Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
Preferences: preferences,
}
// Invalid JSON causes Match to return false early
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("test-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/mlx-community/test-model"))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-mlx-model", "description": "Custom MLX description"}`)
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-mlx-model"))
Expect(modelConfig.Description).To(Equal("Custom MLX description"))
})
It("should use custom backend from preferences", func() {
preferences := json.RawMessage(`{"backend": "mlx-vlm"}`)
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm"))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := importers.Details{
URI: "https://huggingface.co/mlx-community/test-model",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://huggingface.co/mlx-community/path/to/model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("model"))
})
})
})

View File

@@ -1,110 +0,0 @@
package importers
import (
"encoding/json"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"go.yaml.in/yaml/v2"
)
var _ Importer = &TransformersImporter{}
type TransformersImporter struct{}
func (i *TransformersImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
b, ok := preferencesMap["backend"].(string)
if ok && b == "transformers" {
return true
}
if details.HuggingFace != nil {
for _, file := range details.HuggingFace.Files {
if strings.Contains(file.Path, "tokenizer.json") ||
strings.Contains(file.Path, "tokenizer_config.json") {
return true
}
}
}
return false
}
func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
backend := "transformers"
b, ok := preferencesMap["backend"].(string)
if ok {
backend = b
}
modelType, ok := preferencesMap["type"].(string)
if !ok {
modelType = "AutoModelForCausalLM"
}
quantization, ok := preferencesMap["quantization"].(string)
if !ok {
quantization = ""
}
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Backend: backend,
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: details.URI,
},
},
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
}
modelConfig.ModelType = modelType
modelConfig.Quantization = quantization
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
return gallery.ModelConfig{
Name: name,
Description: description,
ConfigFile: string(data),
}, nil
}

View File

@@ -1,219 +0,0 @@
package importers_test
import (
"encoding/json"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("TransformersImporter", func() {
var importer *TransformersImporter
BeforeEach(func() {
importer = &TransformersImporter{}
})
Context("Match", func() {
It("should match when backend preference is transformers", func() {
preferences := json.RawMessage(`{"backend": "transformers"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer_config.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer_config.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI has no tokenizer files and no backend preference", func() {
details := Details{
URI: "https://example.com/model.bin",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-model"))
Expect(modelConfig.Description).To(Equal("Custom description"))
})
It("should use custom model type from preferences", func() {
preferences := json.RawMessage(`{"type": "SentenceTransformer"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: SentenceTransformer"))
})
It("should use default model type when not specified", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM"))
})
It("should use custom backend from preferences", func() {
preferences := json.RawMessage(`{"backend": "transformers"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers"))
})
It("should use quantization from preferences", func() {
preferences := json.RawMessage(`{"quantization": "int8"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("quantization: int8"))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://huggingface.co/test/path/to/model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("model"))
})
It("should include use_tokenizer_template in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
})
It("should include known_usecases in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
})
})
})

View File

@@ -1,98 +0,0 @@
package importers
import (
"encoding/json"
"path/filepath"
"strings"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"go.yaml.in/yaml/v2"
)
var _ Importer = &VLLMImporter{}
type VLLMImporter struct{}
func (i *VLLMImporter) Match(details Details) bool {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return false
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return false
}
b, ok := preferencesMap["backend"].(string)
if ok && b == "vllm" {
return true
}
if details.HuggingFace != nil {
for _, file := range details.HuggingFace.Files {
if strings.Contains(file.Path, "tokenizer.json") ||
strings.Contains(file.Path, "tokenizer_config.json") {
return true
}
}
}
return false
}
func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) {
preferences, err := details.Preferences.MarshalJSON()
if err != nil {
return gallery.ModelConfig{}, err
}
preferencesMap := make(map[string]any)
err = json.Unmarshal(preferences, &preferencesMap)
if err != nil {
return gallery.ModelConfig{}, err
}
name, ok := preferencesMap["name"].(string)
if !ok {
name = filepath.Base(details.URI)
}
description, ok := preferencesMap["description"].(string)
if !ok {
description = "Imported from " + details.URI
}
backend := "vllm"
b, ok := preferencesMap["backend"].(string)
if ok {
backend = b
}
modelConfig := config.ModelConfig{
Name: name,
Description: description,
KnownUsecaseStrings: []string{"chat"},
Backend: backend,
PredictionOptions: schema.PredictionOptions{
BasicModelRequest: schema.BasicModelRequest{
Model: details.URI,
},
},
TemplateConfig: config.TemplateConfig{
UseTokenizerTemplate: true,
},
}
data, err := yaml.Marshal(modelConfig)
if err != nil {
return gallery.ModelConfig{}, err
}
return gallery.ModelConfig{
Name: name,
Description: description,
ConfigFile: string(data),
}, nil
}

View File

@@ -1,181 +0,0 @@
package importers_test
import (
"encoding/json"
"github.com/mudler/LocalAI/core/gallery/importers"
. "github.com/mudler/LocalAI/core/gallery/importers"
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("VLLMImporter", func() {
var importer *VLLMImporter
BeforeEach(func() {
importer = &VLLMImporter{}
})
Context("Match", func() {
It("should match when backend preference is vllm", func() {
preferences := json.RawMessage(`{"backend": "vllm"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should match when HuggingFace details contain tokenizer_config.json", func() {
hfDetails := &hfapi.ModelDetails{
Files: []hfapi.ModelFile{
{Path: "tokenizer_config.json"},
},
}
details := Details{
URI: "https://huggingface.co/test/model",
HuggingFace: hfDetails,
}
result := importer.Match(details)
Expect(result).To(BeTrue())
})
It("should not match when URI has no tokenizer files and no backend preference", func() {
details := Details{
URI: "https://example.com/model.bin",
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should not match when backend preference is different", func() {
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
It("should return false when JSON preferences are invalid", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://example.com/model",
Preferences: preferences,
}
result := importer.Match(details)
Expect(result).To(BeFalse())
})
})
Context("Import", func() {
It("should import model config with default name and description", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("my-model"))
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model"))
})
It("should import model config with custom name and description from preferences", func() {
preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("custom-model"))
Expect(modelConfig.Description).To(Equal("Custom description"))
})
It("should use custom backend from preferences", func() {
preferences := json.RawMessage(`{"backend": "vllm"}`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm"))
})
It("should handle invalid JSON preferences", func() {
preferences := json.RawMessage(`invalid json`)
details := Details{
URI: "https://huggingface.co/test/my-model",
Preferences: preferences,
}
_, err := importer.Import(details)
Expect(err).To(HaveOccurred())
})
It("should extract filename correctly from URI with path", func() {
details := importers.Details{
URI: "https://huggingface.co/test/path/to/model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.Name).To(Equal("model"))
})
It("should include use_tokenizer_template in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true"))
})
It("should include known_usecases in config", func() {
details := Details{
URI: "https://huggingface.co/test/my-model",
}
modelConfig, err := importer.Import(details)
Expect(err).ToNot(HaveOccurred())
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat"))
})
})
})

View File

@@ -1,7 +1,6 @@
package gallery
import (
"context"
"errors"
"fmt"
"os"
@@ -73,7 +72,6 @@ type PromptTemplate struct {
// Installs a model from the gallery
func InstallModelFromGallery(
ctx context.Context,
modelGalleries, backendGalleries []config.Gallery,
systemState *system.SystemState,
modelLoader *model.ModelLoader,
@@ -86,7 +84,7 @@ func InstallModelFromGallery(
if len(model.URL) > 0 {
var err error
config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
if err != nil {
return err
}
@@ -127,7 +125,7 @@ func InstallModelFromGallery(
return err
}
installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
if err != nil {
return err
}
@@ -135,7 +133,7 @@ func InstallModelFromGallery(
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
@@ -156,7 +154,7 @@ func InstallModelFromGallery(
return applyModel(model)
}
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
basePath := systemState.Model.ModelsPath
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750)
@@ -170,13 +168,6 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
// Download files and verify their SHA
for i, file := range config.Files {
// Check for cancellation before each file
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, basePath); err != nil {
@@ -194,7 +185,7 @@ func InstallModel(ctx context.Context, systemState *system.SystemState, nameOver
}
}
uri := downloader.URI(file.URI)
if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
return nil, err
}
}

View File

@@ -1,7 +1,6 @@
package gallery_test
import (
"context"
"errors"
"os"
"path/filepath"
@@ -35,7 +34,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -89,7 +88,7 @@ var _ = Describe("Model test", func() {
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
Expect(models[0].Installed).To(BeFalse())
err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
@@ -130,7 +129,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -150,7 +149,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -180,7 +179,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).To(HaveOccurred())
})
})

View File

@@ -4,23 +4,30 @@ import (
"embed"
"errors"
"fmt"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
httpMiddleware "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/gofiber/contrib/fiberzerolog"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/csrf"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/gofiber/fiber/v2/middleware/recover"
// swagger handler
"github.com/rs/zerolog/log"
)
@@ -42,85 +49,85 @@ var embedDirStatic embed.FS
// @in header
// @name Authorization
func API(application *application.Application) (*echo.Echo, error) {
e := echo.New()
func API(application *application.Application) (*fiber.App, error) {
// Set body limit
if application.ApplicationConfig().UploadLimitMB > 0 {
e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB)))
fiberCfg := fiber.Config{
Views: renderEngine(),
BodyLimit: application.ApplicationConfig().UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: true,
// Override default error handler
}
// Set error handler
if !application.ApplicationConfig().OpaqueErrors {
e.HTTPErrorHandler = func(err error, c echo.Context) {
code := http.StatusInternalServerError
var he *echo.HTTPError
if errors.As(err, &he) {
code = he.Code
}
// Normally, return errors as JSON responses
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500
code := fiber.StatusInternalServerError
// Handle 404 errors with HTML rendering when appropriate
if code == http.StatusNotFound {
notFoundHandler(c)
return
// Retrieve the custom status code if it's a *fiber.Error
var e *fiber.Error
if errors.As(err, &e) {
code = e.Code
}
// Send custom error page
c.JSON(code, schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
})
return ctx.Status(code).JSON(
schema.ErrorResponse{
Error: &schema.APIError{Message: err.Error(), Code: code},
},
)
}
} else {
e.HTTPErrorHandler = func(err error, c echo.Context) {
code := http.StatusInternalServerError
var he *echo.HTTPError
if errors.As(err, &he) {
code = he.Code
}
c.NoContent(code)
// If OpaqueErrors are required, replace everything with a blank 500.
fiberCfg.ErrorHandler = func(ctx *fiber.Ctx, _ error) error {
return ctx.Status(500).SendString("")
}
}
// Set renderer
e.Renderer = renderEngine()
router := fiber.New(fiberCfg)
// Hide banner
e.HideBanner = true
// Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing
e.Pre(httpMiddleware.StripPathPrefix())
router.Use(middleware.StripPathPrefix())
if application.ApplicationConfig().MachineTag != "" {
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag)
return next(c)
}
router.Use(func(c *fiber.Ctx) error {
c.Response().Header.Set("Machine-Tag", application.ApplicationConfig().MachineTag)
return c.Next()
})
}
// Custom logger middleware using zerolog
e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
res := c.Response()
start := log.Logger.Info()
err := next(c)
start.
Str("method", req.Method).
Str("path", req.URL.Path).
Int("status", res.Status).
Msg("HTTP request")
return err
router.Use("/v1/realtime", func(c *fiber.Ctx) error {
if websocket.IsWebSocketUpgrade(c) {
// Returns true if the client requested upgrade to the WebSocket protocol
return c.Next()
}
return nil
})
// Recover middleware
router.Hooks().OnListen(func(listenData fiber.ListenData) error {
scheme := "http"
if listenData.TLS {
scheme = "https"
}
log.Info().Str("endpoint", scheme+"://"+listenData.Host+":"+listenData.Port).Msg("LocalAI API is listening! Please connect to the endpoint for API documentation.")
return nil
})
// Have Fiber use zerolog like the rest of the application rather than it's built-in logger
logger := log.Logger
router.Use(fiberzerolog.New(fiberzerolog.Config{
Logger: &logger,
}))
// Default middleware config
if !application.ApplicationConfig().Debug {
e.Use(middleware.Recover())
router.Use(recover.New())
}
// Metrics middleware
if !application.ApplicationConfig().DisableMetrics {
metricsService, err := services.NewLocalAIMetricsService()
if err != nil {
@@ -128,40 +135,34 @@ func API(application *application.Application) (*echo.Echo, error) {
}
if metricsService != nil {
e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
e.Server.RegisterOnShutdown(func() {
metricsService.Shutdown()
router.Use(localai.LocalAIMetricsAPIMiddleware(metricsService))
router.Hooks().OnShutdown(func() error {
return metricsService.Shutdown()
})
}
}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(e)
routes.HealthRoutes(router)
// Get key auth middleware
keyAuthMiddleware, err := httpMiddleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil {
kaConfig, err := middleware.GetKeyAuthConfig(application.ApplicationConfig())
if err != nil || kaConfig == nil {
return nil, fmt.Errorf("failed to create key auth config: %w", err)
}
// Favicon handler
e.GET("/favicon.svg", func(c echo.Context) error {
data, err := embedDirStatic.ReadFile("static/favicon.svg")
if err != nil {
return c.NoContent(http.StatusNotFound)
}
c.Response().Header().Set("Content-Type", "image/svg+xml")
return c.Blob(http.StatusOK, "image/svg+xml", data)
})
httpFS := http.FS(embedDirStatic)
// Static files - use fs.Sub to create a filesystem rooted at "static"
staticFS, err := fs.Sub(embedDirStatic, "static")
if err != nil {
return nil, fmt.Errorf("failed to create static filesystem: %w", err)
}
e.StaticFS("/static", staticFS)
router.Use(favicon.New(favicon.Config{
URL: "/favicon.svg",
FileSystem: httpFS,
File: "static/favicon.svg",
}))
router.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Generated content directories
if application.ApplicationConfig().GeneratedContentDir != "" {
os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750)
audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio")
@@ -172,53 +173,46 @@ func API(application *application.Application) (*echo.Echo, error) {
os.MkdirAll(imagePath, 0750)
os.MkdirAll(videoPath, 0750)
e.Static("/generated-audio", audioPath)
e.Static("/generated-images", imagePath)
e.Static("/generated-videos", videoPath)
router.Static("/generated-audio", audioPath)
router.Static("/generated-images", imagePath)
router.Static("/generated-videos", videoPath)
}
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Skipper property of the KeyAuth Configuration
e.Use(keyAuthMiddleware)
// Auth is applied to _all_ endpoints. No exceptions. Filtering out endpoints to bypass is the role of the Filter property of the KeyAuth Configuration
router.Use(v2keyauth.New(*kaConfig))
// CORS middleware
if application.ApplicationConfig().CORS {
corsConfig := middleware.CORSConfig{}
if application.ApplicationConfig().CORSAllowOrigins != "" {
corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",")
var c func(ctx *fiber.Ctx) error
if application.ApplicationConfig().CORSAllowOrigins == "" {
c = cors.New()
} else {
c = cors.New(cors.Config{AllowOrigins: application.ApplicationConfig().CORSAllowOrigins})
}
e.Use(middleware.CORSWithConfig(corsConfig))
router.Use(c)
}
// CSRF middleware
if application.ApplicationConfig().CSRF {
log.Debug().Msg("Enabling CSRF middleware. Tokens are now required for state-modifying requests")
e.Use(middleware.CSRF())
router.Use(csrf.New())
}
requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
requestExtractor := middleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Create opcache for tracking UI operations (used by both UI and LocalAI routes)
var opcache *services.OpCache
routes.RegisterElevenLabsRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
routes.RegisterOpenAIRoutes(router, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
opcache = services.NewOpCache(application.GalleryService())
// Create opcache for tracking UI operations
opcache := services.NewOpCache(application.GalleryService())
routes.RegisterUIAPIRoutes(router, application.ModelConfigLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(router, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
}
routes.RegisterJINARoutes(router, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator())
routes.RegisterOpenAIRoutes(e, requestExtractor, application)
if !application.ApplicationConfig().DisableWebUI {
routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache)
routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService())
}
routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
// Define a custom 404 handler
// Note: keep this at the bottom!
router.Use(notFoundHandler)
// Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route
// Log startup message
e.Server.RegisterOnShutdown(func() {
log.Info().Msg("LocalAI API server shutting down")
})
return e, nil
return router, nil
}

View File

@@ -10,14 +10,13 @@ import (
"os"
"path/filepath"
"runtime"
"time"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http"
"github.com/mudler/LocalAI/core/schema"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/downloader"
"github.com/mudler/LocalAI/pkg/system"
@@ -26,7 +25,6 @@ import (
"gopkg.in/yaml.v3"
openaigo "github.com/otiai10/openaigo"
"github.com/rs/zerolog/log"
"github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
)
@@ -87,7 +85,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
@@ -268,7 +266,7 @@ const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b8
var _ = Describe("API test", func() {
var app *echo.Echo
var app *fiber.App
var client *openai.Client
var client2 *openaigo.Client
var c context.Context
@@ -341,11 +339,7 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig(apiKey)
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -364,9 +358,7 @@ var _ = Describe("API test", func() {
AfterEach(func(sc SpecContext) {
cancel()
if app != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
@@ -555,11 +547,7 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -578,9 +566,7 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
err := os.RemoveAll(tmpdir)
@@ -769,11 +755,7 @@ var _ = Describe("API test", func() {
Expect(err).ToNot(HaveOccurred())
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -791,9 +773,7 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
})
@@ -816,83 +796,6 @@ var _ = Describe("API test", func() {
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
})
It("returns logprobs in chat completions when requested", func() {
topLogprobsVal := 3
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
LogProbs: true,
TopLogProbs: topLogprobsVal,
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
// Verify logprobs are present and have correct structure
Expect(response.Choices[0].LogProbs).ToNot(BeNil())
Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty())
Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1))
foundatLeastToken := ""
foundAtLeastBytes := []byte{}
foundAtLeastTopLogprobBytes := []byte{}
foundatLeastTopLogprob := ""
// Verify logprobs content structure matches OpenAI format
for _, logprobContent := range response.Choices[0].LogProbs.Content {
// Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it
if len(logprobContent.Bytes) > 0 {
foundAtLeastBytes = logprobContent.Bytes
}
if len(logprobContent.Token) > 0 {
foundatLeastToken = logprobContent.Token
}
Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1))
// If top_logprobs is requested, verify top_logprobs array respects the limit
if len(logprobContent.TopLogProbs) > 0 {
// Should respect top_logprobs limit (3 in this test)
Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal))
for _, topLogprob := range logprobContent.TopLogProbs {
if len(topLogprob.Bytes) > 0 {
foundAtLeastTopLogprobBytes = topLogprob.Bytes
}
if len(topLogprob.Token) > 0 {
foundatLeastTopLogprob = topLogprob.Token
}
Expect(topLogprob.LogProb).To(BeNumerically("<=", 0))
}
}
}
Expect(foundAtLeastBytes).ToNot(BeEmpty())
Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty())
Expect(foundatLeastToken).ToNot(BeEmpty())
Expect(foundatLeastTopLogprob).ToNot(BeEmpty())
})
It("applies logit_bias to chat completions when requested", func() {
// logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
// According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion
logitBias := map[string]int{
"15043": 1, // Bias token ID 15043 (example token ID) with bias value 1
}
response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{
Model: "testmodel.ggml",
Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}},
LogitBias: logitBias,
})
Expect(err).ToNot(HaveOccurred())
Expect(len(response.Choices)).To(Equal(1))
Expect(response.Choices[0].Message).ToNot(BeNil())
Expect(response.Choices[0].Message.Content).ToNot(BeEmpty())
// If logit_bias is applied, the response should be generated successfully
// We can't easily verify the bias effect without knowing the actual token IDs for the model,
// but the fact that the request succeeds confirms the API accepts and processes logit_bias
})
It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt})
Expect(err).To(HaveOccurred())
@@ -1103,11 +1006,7 @@ var _ = Describe("API test", func() {
app, err = API(application)
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed {
log.Error().Err(err).Msg("server error")
}
}()
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
@@ -1123,9 +1022,7 @@ var _ = Describe("API test", func() {
AfterEach(func() {
cancel()
if app != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := app.Shutdown(ctx)
err := app.Shutdown()
Expect(err).ToNot(HaveOccurred())
}
})

View File

@@ -1,9 +1,7 @@
package elevenlabs
import (
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -17,17 +15,17 @@ import (
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/sound-generation [post]
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
if !ok || input.ModelID == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend")
@@ -37,7 +35,7 @@ func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader
if err != nil {
return err
}
return c.Attachment(filePath, filepath.Base(filePath))
return c.Download(filePath)
}
}

View File

@@ -1,14 +1,13 @@
package elevenlabs
import (
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log"
)
@@ -18,19 +17,19 @@ import (
// @Param request body schema.TTSRequest true "query params"
// @Success 200 {string} binary "Response"
// @Router /v1/text-to-speech/{voice-id} [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
voiceID := c.Param("voice-id")
voiceID := c.Params("voice-id")
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest)
if !ok || input.ModelID == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request received")
@@ -39,6 +38,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
if err != nil {
return err
}
return c.Attachment(filePath, filepath.Base(filePath))
return c.Download(filePath)
}
}

View File

@@ -2,32 +2,28 @@ package explorer
import (
"encoding/base64"
"net/http"
"sort"
"strings"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/internal"
)
func Dashboard() echo.HandlerFunc {
return func(c echo.Context) error {
summary := map[string]interface{}{
func Dashboard() func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": middleware.BaseURL(c),
"BaseURL": utils.BaseURL(c),
}
contentType := c.Request().Header.Get("Content-Type")
accept := c.Request().Header.Get("Accept")
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) {
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
// The client expects a JSON response
return c.JSON(http.StatusOK, summary)
return c.Status(fiber.StatusOK).JSON(summary)
} else {
// Render index
return c.Render(http.StatusOK, "views/explorer", summary)
return c.Render("views/explorer", summary)
}
}
}
@@ -43,8 +39,8 @@ type Network struct {
Token string `json:"token"`
}
func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error {
func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
results := []Network{}
for _, token := range db.TokenList() {
networkData, exists := db.Get(token) // get the token data
@@ -65,44 +61,44 @@ func ShowNetworks(db *explorer.Database) echo.HandlerFunc {
return len(results[i].Clusters) > len(results[j].Clusters)
})
return c.JSON(http.StatusOK, results)
return c.JSON(results)
}
}
func AddNetwork(db *explorer.Database) echo.HandlerFunc {
return func(c echo.Context) error {
func AddNetwork(db *explorer.Database) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
request := new(AddNetworkRequest)
if err := c.Bind(request); err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Cannot parse JSON"})
if err := c.BodyParser(request); err != nil {
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Cannot parse JSON"})
}
if request.Token == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"})
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token is required"})
}
if request.Name == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"})
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Name is required"})
}
if request.Description == "" {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"})
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Description is required"})
}
// TODO: check if token is valid, otherwise reject
// try to decode the token from base64
_, err := base64.StdEncoding.DecodeString(request.Token)
if err != nil {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"})
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid token"})
}
if _, exists := db.Get(request.Token); exists {
return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"})
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Token already exists"})
}
err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description})
if err != nil {
return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"})
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Cannot add token"})
}
return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"})
return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Token added"})
}
}

View File

@@ -1,12 +1,11 @@
package jina
import (
"net/http"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/model"
@@ -18,36 +17,24 @@ import (
// @Param request body schema.JINARerankRequest true "query params"
// @Success 200 {object} schema.JINARerankResponse "Response"
// @Router /v1/rerank [post]
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
log.Debug().Str("model", input.Model).Msg("JINA Rerank Request received")
var requestTopN int32
docs := int32(len(input.Documents))
if input.TopN == nil { // omit top_n to get all
requestTopN = docs
} else {
requestTopN = int32(*input.TopN)
if requestTopN < 1 {
return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1")
}
if requestTopN > docs { // make it more obvious for backends
requestTopN = docs
}
}
request := &proto.RerankRequest{
Query: input.Query,
TopN: requestTopN,
TopN: int32(input.TopN),
Documents: input.Documents,
}
@@ -71,6 +58,6 @@ func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
response.Usage.TotalTokens = int(results.Usage.TotalTokens)
response.Usage.PromptTokens = int(results.Usage.PromptTokens)
return c.JSON(http.StatusOK, response)
return c.Status(fiber.StatusOK).JSON(response)
}
}

View File

@@ -4,11 +4,11 @@ import (
"encoding/json"
"fmt"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
@@ -39,13 +39,13 @@ func CreateBackendEndpointService(galleries []config.Gallery, systemState *syste
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /backends/jobs/{uuid} [get]
func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
status := mgs.backendApplier.GetStatus(c.Param("uuid"))
func (mgs *BackendEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.backendApplier.GetStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(200, status)
return c.JSON(status)
}
}
@@ -53,9 +53,9 @@ func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /backends/jobs [get]
func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
return c.JSON(200, mgs.backendApplier.GetAllStatus())
func (mgs *BackendEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.backendApplier.GetAllStatus())
}
}
@@ -64,11 +64,11 @@ func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
// @Param request body GalleryBackend true "query params"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/apply [post]
func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
func (mgs *BackendEndpointService) ApplyBackendEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryBackend)
// Get input data from the request body
if err := c.Bind(input); err != nil {
if err := c.BodyParser(input); err != nil {
return err
}
@@ -76,13 +76,13 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
if err != nil {
return err
}
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend]{
ID: uuid.String(),
GalleryElementName: input.ID,
Galleries: mgs.galleries,
}
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}
@@ -91,11 +91,11 @@ func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc {
// @Param name path string true "Backend name"
// @Success 200 {object} schema.BackendResponse "Response"
// @Router /backends/delete/{name} [post]
func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
backendName := c.Param("name")
func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backendName := c.Params("name")
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{
mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend]{
Delete: true,
GalleryElementName: backendName,
Galleries: mgs.galleries,
@@ -106,7 +106,7 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
return err
}
return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())})
return c.JSON(schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}
@@ -114,13 +114,13 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc {
// @Summary List all Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends [get]
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backends, err := gallery.ListSystemBackends(systemState)
if err != nil {
return err
}
return c.JSON(200, backends.GetAll())
return c.JSON(backends.GetAll())
}
}
@@ -129,14 +129,14 @@ func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.Syst
// @Success 200 {object} []config.Gallery "Response"
// @Router /backends/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing backend galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Blob(200, "application/json", dat)
return c.Send(dat)
}
}
@@ -144,12 +144,12 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFu
// @Summary List all available Backends
// @Success 200 {object} []gallery.GalleryBackend "Response"
// @Router /backends/available [get]
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
if err != nil {
return err
}
return c.JSON(200, backends)
return c.JSON(backends)
}
}

View File

@@ -1,45 +1,45 @@
package localai
import (
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.Bind(input); err != nil {
return err
}
resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(200, resp)
}
}
// BackendShutdownEndpoint shuts down the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.Bind(input); err != nil {
return err
}
return bm.ShutdownModel(input.Model)
}
}
package localai
import (
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
)
// BackendMonitorEndpoint returns the status of the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Success 200 {object} proto.StatusResponse "Response"
// @Router /backend/monitor [get]
func BackendMonitorEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
resp, err := bm.CheckAndSample(input.Model)
if err != nil {
return err
}
return c.JSON(resp)
}
}
// BackendShutdownEndpoint shuts down the specified backend
// @Summary Backend monitor endpoint
// @Param request body schema.BackendMonitorRequest true "Backend statistics request"
// @Router /backend/shutdown [post]
func BackendShutdownEndpoint(bm *services.BackendMonitorService) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.BackendMonitorRequest)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
return bm.ShutdownModel(input.Model)
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -16,17 +16,17 @@ import (
// @Param request body schema.DetectionRequest true "query params"
// @Success 200 {object} schema.DetectionResponse "Response"
// @Router /v1/detection [post]
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
log.Debug().Str("image", input.Image).Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Detection")
@@ -54,6 +54,6 @@ func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appC
}
}
return c.JSON(200, response)
return c.JSON(response)
}
}

View File

@@ -2,13 +2,11 @@ package localai
import (
"fmt"
"io"
"net/http"
"os"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
httpUtils "github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/utils"
@@ -16,15 +14,15 @@ import (
)
// GetEditModelPage renders the edit model page with current configuration
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
if modelName == "" {
response := ModelResponse{
Success: false,
Error: "Model name is required",
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
modelConfig, exists := cl.GetModelConfig(modelName)
@@ -33,7 +31,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false,
Error: "Model configuration not found",
}
return c.JSON(http.StatusNotFound, response)
return c.Status(404).JSON(response)
}
modelConfigFile := modelConfig.GetModelConfigFile()
@@ -42,7 +40,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false,
Error: "Model configuration file not found",
}
return c.JSON(http.StatusNotFound, response)
return c.Status(404).JSON(response)
}
configData, err := os.ReadFile(modelConfigFile)
if err != nil {
@@ -50,7 +48,7 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Success: false,
Error: "Failed to read configuration file: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Render the edit page with the current configuration
@@ -71,20 +69,20 @@ func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.Applicatio
Version: internal.PrintableVersion(),
}
return c.Render(http.StatusOK, "views/model-editor", templateData)
return c.Render("views/model-editor", templateData)
}
}
// EditModelEndpoint handles updating existing model configurations
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
if modelName == "" {
response := ModelResponse{
Success: false,
Error: "Model name is required",
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
modelConfig, exists := cl.GetModelConfig(modelName)
@@ -93,24 +91,17 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Existing model configuration not found",
}
return c.JSON(http.StatusNotFound, response)
return c.Status(404).JSON(response)
}
// Get the raw body
body, err := io.ReadAll(c.Request().Body)
if err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to read request body: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
}
body := c.Body()
if len(body) == 0 {
response := ModelResponse{
Success: false,
Error: "Request body is empty",
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
// Check content to see if it's a valid model config
@@ -122,7 +113,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
// Validate required fields
@@ -131,7 +122,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Name is required",
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
// Validate the configuration
@@ -141,7 +132,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Error: "Validation failed",
Details: []string{"Configuration validation failed. Please check your YAML syntax and required fields."},
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
// Load the existing configuration
@@ -151,7 +142,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Model configuration not trusted: " + err.Error(),
}
return c.JSON(http.StatusNotFound, response)
return c.Status(404).JSON(response)
}
// Write new content to file
@@ -160,16 +151,16 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to write configuration file: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil {
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Preload the model
@@ -178,7 +169,7 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Success: false,
Error: "Failed to preload model: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Return success response
@@ -188,20 +179,20 @@ func EditModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applicati
Filename: configPath,
Config: req,
}
return c.JSON(200, response)
return c.JSON(response)
}
}
// ReloadModelsEndpoint handles reloading model configurations from disk
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Preload the models
@@ -210,7 +201,7 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
Success: false,
Error: "Failed to preload models: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Return success response
@@ -218,6 +209,6 @@ func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applic
Success: true,
Message: "Model configurations reloaded successfully",
}
return c.JSON(http.StatusOK, response)
return c.Status(fiber.StatusOK).JSON(response)
}
}

View File

@@ -2,14 +2,12 @@ package localai_test
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/pkg/system"
@@ -17,14 +15,6 @@ import (
. "github.com/onsi/gomega"
)
// testRenderer is a simple renderer for tests that returns JSON
type testRenderer struct{}
func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error {
// For tests, just return the data as JSON
return json.NewEncoder(w).Encode(data)
}
var _ = Describe("Edit Model test", func() {
var tempDir string
@@ -50,35 +40,33 @@ var _ = Describe("Edit Model test", func() {
//modelLoader := model.NewModelLoader(systemState, true)
modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath)
// Define Echo app and register all routes upfront
app := echo.New()
// Set up a simple renderer for the test
app.Renderer = &testRenderer{}
app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig))
// Define Fiber app.
app := fiber.New()
app.Put("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig))
requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`)
req := httptest.NewRequest("POST", "/import-model", requestBody)
req.Header.Set("Content-Type", "application/json")
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
req := httptest.NewRequest("PUT", "/import-model", requestBody)
resp, err := app.Test(req, 5000)
Expect(err).ToNot(HaveOccurred())
body, err := io.ReadAll(rec.Body)
body, err := io.ReadAll(resp.Body)
defer resp.Body.Close()
Expect(err).ToNot(HaveOccurred())
Expect(string(body)).To(ContainSubstring("Model configuration created successfully"))
Expect(rec.Code).To(Equal(http.StatusOK))
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
req = httptest.NewRequest("GET", "/edit-model/foo", nil)
rec = httptest.NewRecorder()
app.ServeHTTP(rec, req)
app.Get("/edit-model/:name", EditModelEndpoint(modelConfigLoader, applicationConfig))
requestBody = bytes.NewBufferString(`{"name": "foo", "parameters": { "model": "foo"}}`)
body, err = io.ReadAll(rec.Body)
req = httptest.NewRequest("GET", "/edit-model/foo", requestBody)
resp, _ = app.Test(req, 1)
body, err = io.ReadAll(resp.Body)
defer resp.Body.Close()
Expect(err).ToNot(HaveOccurred())
// The response contains the model configuration with backend field
Expect(string(body)).To(ContainSubstring(`"backend":"foo"`))
Expect(string(body)).To(ContainSubstring(`"name":"foo"`))
Expect(rec.Code).To(Equal(http.StatusOK))
Expect(string(body)).To(ContainSubstring(`"model":"foo"`))
Expect(resp.StatusCode).To(Equal(fiber.StatusOK))
})
})
})

View File

@@ -1,160 +1,160 @@
package localai
import (
"encoding/json"
"fmt"
"github.com/labstack/echo/v4"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
type ModelGalleryEndpointService struct {
galleries []config.Gallery
backendGalleries []config.Gallery
modelPath string
galleryApplier *services.GalleryService
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
backendGalleries: backendGalleries,
modelPath: systemState.Model.ModelsPath,
galleryApplier: galleryApplier,
}
}
// GetOpStatusEndpoint returns the job status
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /models/jobs/{uuid} [get]
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
status := mgs.galleryApplier.GetStatus(c.Param("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(200, status)
}
}
// GetAllStatusEndpoint returns all the jobs status progress
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /models/jobs [get]
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
return c.JSON(200, mgs.galleryApplier.GetAllStatus())
}
}
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
// @Summary Install models to LocalAI.
// @Param request body GalleryModel true "query params"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/apply [post]
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.Bind(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
Req: input.GalleryModel,
ID: uuid.String(),
GalleryElementName: input.ID,
Galleries: mgs.galleries,
BackendGalleries: mgs.backendGalleries,
}
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
// @Summary delete models to LocalAI.
// @Param name path string true "Model name"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/delete/{name} [post]
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
modelName := c.Param("name")
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
Delete: true,
GalleryElementName: modelName,
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())})
}
}
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
// @Summary List installable models.
// @Success 200 {object} []gallery.GalleryModel "Response"
// @Router /models/available [get]
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc {
return func(c echo.Context) error {
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
if err != nil {
log.Error().Err(err).Msg("could not list models from galleries")
return err
}
log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))
m := []gallery.Metadata{}
for _, mm := range models {
m = append(m, mm.Metadata)
}
log.Debug().Msgf("Models %#v", m)
dat, err := json.Marshal(m)
if err != nil {
return fmt.Errorf("could not marshal models: %w", err)
}
return c.Blob(200, "application/json", dat)
}
}
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
// @Summary List all Galleries
// @Success 200 {object} []config.Gallery "Response"
// @Router /models/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc {
return func(c echo.Context) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Blob(200, "application/json", dat)
}
}
package localai
import (
"encoding/json"
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/system"
"github.com/rs/zerolog/log"
)
type ModelGalleryEndpointService struct {
galleries []config.Gallery
backendGalleries []config.Gallery
modelPath string
galleryApplier *services.GalleryService
}
type GalleryModel struct {
ID string `json:"id"`
gallery.GalleryModel
}
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
return ModelGalleryEndpointService{
galleries: galleries,
backendGalleries: backendGalleries,
modelPath: systemState.Model.ModelsPath,
galleryApplier: galleryApplier,
}
}
// GetOpStatusEndpoint returns the job status
// @Summary Returns the job status
// @Success 200 {object} services.GalleryOpStatus "Response"
// @Router /models/jobs/{uuid} [get]
func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
status := mgs.galleryApplier.GetStatus(c.Params("uuid"))
if status == nil {
return fmt.Errorf("could not find any status for ID")
}
return c.JSON(status)
}
}
// GetAllStatusEndpoint returns all the jobs status progress
// @Summary Returns all the jobs status progress
// @Success 200 {object} map[string]services.GalleryOpStatus "Response"
// @Router /models/jobs [get]
func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
return c.JSON(mgs.galleryApplier.GetAllStatus())
}
}
// ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery
// @Summary Install models to LocalAI.
// @Param request body GalleryModel true "query params"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/apply [post]
func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(GalleryModel)
// Get input data from the request body
if err := c.BodyParser(input); err != nil {
return err
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel]{
Req: input.GalleryModel,
ID: uuid.String(),
GalleryElementName: input.ID,
Galleries: mgs.galleries,
BackendGalleries: mgs.backendGalleries,
}
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}
// DeleteModelGalleryEndpoint lets delete models from a LocalAI instance
// @Summary delete models to LocalAI.
// @Param name path string true "Model name"
// @Success 200 {object} schema.GalleryResponse "Response"
// @Router /models/delete/{name} [post]
func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelName := c.Params("name")
mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel]{
Delete: true,
GalleryElementName: modelName,
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
return c.JSON(schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", utils.BaseURL(c), uuid.String())})
}
}
// ListModelFromGalleryEndpoint list the available models for installation from the active galleries
// @Summary List installable models.
// @Success 200 {object} []gallery.GalleryModel "Response"
// @Router /models/available [get]
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
if err != nil {
log.Error().Err(err).Msg("could not list models from galleries")
return err
}
log.Debug().Msgf("Available %d models from %d galleries\n", len(models), len(mgs.galleries))
m := []gallery.Metadata{}
for _, mm := range models {
m = append(m, mm.Metadata)
}
log.Debug().Msgf("Models %#v", m)
dat, err := json.Marshal(m)
if err != nil {
return fmt.Errorf("could not marshal models: %w", err)
}
return c.Send(dat)
}
}
// ListModelGalleriesEndpoint list the available galleries configured in LocalAI
// @Summary List all Galleries
// @Success 200 {object} []config.Gallery "Response"
// @Router /models/galleries [get]
// NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents!
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
log.Debug().Msgf("Listing model galleries %+v", mgs.galleries)
dat, err := json.Marshal(mgs.galleries)
if err != nil {
return err
}
return c.Send(dat)
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -21,17 +21,17 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/tokenMetrics [get]
// @Router /tokenMetrics [get]
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.TokenMetricsRequest)
// Get input data from the request body
if err := c.Bind(input); err != nil {
if err := c.BodyParser(input); err != nil {
return err
}
modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || modelFile != "" {
modelFile = input.Model
log.Warn().Msgf("Model not found in context: %s", input.Model)
@@ -52,6 +52,6 @@ func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
if err != nil {
return err
}
return c.JSON(200, response)
return c.JSON(response)
}
}

View File

@@ -2,97 +2,33 @@ package localai
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/gallery/importers"
httpUtils "github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/pkg/utils"
"gopkg.in/yaml.v3"
)
// ImportModelURIEndpoint handles creating new model configurations from a URI
func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc {
return func(c echo.Context) error {
input := new(schema.ImportModelRequest)
if err := c.Bind(input); err != nil {
return err
}
modelConfig, err := importers.DiscoverModelConfig(input.URI, input.Preferences)
if err != nil {
return fmt.Errorf("failed to discover model config: %w", err)
}
uuid, err := uuid.NewUUID()
if err != nil {
return err
}
// Determine gallery ID for tracking - use model name if available, otherwise use URI
galleryID := input.URI
if modelConfig.Name != "" {
galleryID = modelConfig.Name
}
// Register operation in opcache if available (for UI progress tracking)
if opcache != nil {
opcache.Set(galleryID, uuid.String())
}
galleryService.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
Req: gallery.GalleryModel{
Overrides: map[string]interface{}{},
},
ID: uuid.String(),
GalleryElementName: galleryID,
GalleryElement: &modelConfig,
BackendGalleries: appConfig.BackendGalleries,
}
return c.JSON(200, schema.GalleryResponse{
ID: uuid.String(),
StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()),
})
}
}
// ImportModelEndpoint handles creating new model configurations
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) fiber.Handler {
return func(c *fiber.Ctx) error {
// Get the raw body
body, err := io.ReadAll(c.Request().Body)
if err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to read request body: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
}
body := c.Body()
if len(body) == 0 {
response := ModelResponse{
Success: false,
Error: "Request body is empty",
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
// Check content type to determine how to parse
contentType := c.Request().Header.Get("Content-Type")
contentType := string(c.Context().Request.Header.ContentType())
var modelConfig config.ModelConfig
var err error
if strings.Contains(contentType, "application/json") {
// Parse JSON
@@ -101,7 +37,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to parse JSON: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
} else if strings.Contains(contentType, "application/x-yaml") || strings.Contains(contentType, "text/yaml") {
// Parse YAML
@@ -110,18 +46,18 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
} else {
// Try to auto-detect format
if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' {
if strings.TrimSpace(string(body))[0] == '{' {
// Looks like JSON
if err := json.Unmarshal(body, &modelConfig); err != nil {
response := ModelResponse{
Success: false,
Error: "Failed to parse JSON: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
} else {
// Assume YAML
@@ -130,7 +66,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to parse YAML: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
}
}
@@ -141,7 +77,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Name is required",
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
// Set defaults
@@ -153,7 +89,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Invalid configuration",
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
// Create the configuration file
@@ -163,7 +99,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Model path not trusted: " + err.Error(),
}
return c.JSON(http.StatusBadRequest, response)
return c.Status(400).JSON(response)
}
// Marshal to YAML for storage
@@ -173,7 +109,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to marshal configuration: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Write the file
@@ -182,7 +118,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to write configuration file: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Reload configurations
if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath); err != nil {
@@ -190,7 +126,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to reload configurations: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Preload the model
@@ -199,7 +135,7 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Success: false,
Error: "Failed to preload model: " + err.Error(),
}
return c.JSON(http.StatusInternalServerError, response)
return c.Status(500).JSON(response)
}
// Return success response
response := ModelResponse{
@@ -207,6 +143,6 @@ func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.Applica
Message: "Model configuration created successfully",
Filename: filepath.Base(configPath),
}
return c.JSON(200, response)
return c.JSON(response)
}
}

View File

@@ -1,323 +0,0 @@
package localai
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/cogito"
"github.com/rs/zerolog/log"
)
// MCP SSE Event Types
type MCPReasoningEvent struct {
Type string `json:"type"`
Content string `json:"content"`
}
type MCPToolCallEvent struct {
Type string `json:"type"`
Name string `json:"name"`
Arguments map[string]interface{} `json:"arguments"`
Reasoning string `json:"reasoning"`
}
type MCPToolResultEvent struct {
Type string `json:"type"`
Name string `json:"name"`
Result string `json:"result"`
}
type MCPStatusEvent struct {
Type string `json:"type"`
Message string `json:"message"`
}
type MCPAssistantEvent struct {
Type string `json:"type"`
Content string `json:"content"`
}
type MCPErrorEvent struct {
Type string `json:"type"`
Message string `json:"message"`
}
// MCPStreamEndpoint is the SSE streaming endpoint for MCP chat completions
// @Summary Stream MCP chat completions with reasoning, tool calls, and results
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/mcp/chat/completions [post]
func MCPStreamEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
created := int(time.Now().Unix())
// Handle Correlation
id := c.Request().Header.Get("X-Correlation-ID")
if id == "" {
id = fmt.Sprintf("mcp-%d", time.Now().UnixNano())
}
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return echo.ErrBadRequest
}
if config.MCP.Servers == "" && config.MCP.Stdio == "" {
return fmt.Errorf("no MCP servers configured")
}
// Get MCP config from model config
remote, stdio, err := config.MCP.MCPConfigFromYAML()
if err != nil {
return fmt.Errorf("failed to get MCP config: %w", err)
}
// Check if we have tools in cache, or we have to have an initial connection
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
if err != nil {
return fmt.Errorf("failed to get MCP sessions: %w", err)
}
if len(sessions) == 0 {
return fmt.Errorf("no working MCP servers found")
}
// Build fragment from messages
fragment := cogito.NewEmptyFragment()
for _, message := range input.Messages {
fragment = fragment.AddMessage(message.Role, message.StringContent)
}
port := appConfig.APIAddress[strings.LastIndex(appConfig.APIAddress, ":")+1:]
apiKey := ""
if len(appConfig.ApiKeys) > 0 {
apiKey = appConfig.ApiKeys[0]
}
ctxWithCancellation, cancel := context.WithCancel(ctx)
defer cancel()
// TODO: instead of connecting to the API, we should just wire this internally
// and act like completion.go.
// We can do this as cogito expects an interface and we can create one that
// we satisfy to just call internally ComputeChoices
defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
// Build cogito options using the consolidated method
cogitoOpts := config.BuildCogitoOptions()
cogitoOpts = append(
cogitoOpts,
cogito.WithContext(ctxWithCancellation),
cogito.WithMCPs(sessions...),
)
// Check if streaming is requested
toStream := input.Stream
if !toStream {
// Non-streaming mode: execute synchronously and return JSON response
cogitoOpts = append(
cogitoOpts,
cogito.WithStatusCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
}),
cogito.WithReasoningCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
}),
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("reasoning", t.Reasoning).Interface("arguments", t.Arguments).Msg("[model agent] Tool call")
return true
}),
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
log.Debug().Str("model", config.Name).Str("tool", t.Name).Str("result", t.Result).Interface("tool_arguments", t.ToolArguments).Msg("[model agent] Tool call result")
}),
)
f, err := cogito.ExecuteTools(
defaultLLM, fragment,
cogitoOpts...,
)
if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
return err
}
f, err = defaultLLM.Ask(ctxWithCancellation, f)
if err != nil {
return err
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Message: &schema.Message{Role: "assistant", Content: &f.LastMessage().Content}}},
Object: "chat.completion",
}
jsonResult, _ := json.Marshal(resp)
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(200, resp)
}
// Streaming mode: use SSE
// Set up SSE headers
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().Header().Set("X-Correlation-ID", id)
// Create channel for streaming events
events := make(chan interface{})
ended := make(chan error, 1)
// Set up callbacks for streaming
statusCallback := func(s string) {
events <- MCPStatusEvent{
Type: "status",
Message: s,
}
}
reasoningCallback := func(s string) {
events <- MCPReasoningEvent{
Type: "reasoning",
Content: s,
}
}
toolCallCallback := func(t *cogito.ToolChoice) bool {
events <- MCPToolCallEvent{
Type: "tool_call",
Name: t.Name,
Arguments: t.Arguments,
Reasoning: t.Reasoning,
}
return true
}
toolCallResultCallback := func(t cogito.ToolStatus) {
events <- MCPToolResultEvent{
Type: "tool_result",
Name: t.Name,
Result: t.Result,
}
}
cogitoOpts = append(cogitoOpts,
cogito.WithStatusCallback(statusCallback),
cogito.WithReasoningCallback(reasoningCallback),
cogito.WithToolCallBack(toolCallCallback),
cogito.WithToolCallResultCallback(toolCallResultCallback),
)
// Execute tools in a goroutine
go func() {
defer close(events)
f, err := cogito.ExecuteTools(
defaultLLM, fragment,
cogitoOpts...,
)
if err != nil && !errors.Is(err, cogito.ErrNoToolSelected) {
events <- MCPErrorEvent{
Type: "error",
Message: fmt.Sprintf("Failed to execute tools: %v", err),
}
ended <- err
return
}
// Get final response
f, err = defaultLLM.Ask(ctxWithCancellation, f)
if err != nil {
events <- MCPErrorEvent{
Type: "error",
Message: fmt.Sprintf("Failed to get response: %v", err),
}
ended <- err
return
}
// Stream final assistant response
content := f.LastMessage().Content
events <- MCPAssistantEvent{
Type: "assistant",
Content: content,
}
ended <- nil
}()
// Stream events to client
LOOP:
for {
select {
case <-ctx.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
cancel()
break LOOP
case event := <-events:
if event == nil {
// Channel closed
break LOOP
}
eventData, err := json.Marshal(event)
if err != nil {
log.Debug().Msgf("Failed to marshal event: %v", err)
continue
}
log.Debug().Msgf("Sending event: %s", string(eventData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(eventData))
if err != nil {
log.Debug().Msgf("Sending event failed: %v", err)
cancel()
return err
}
c.Response().Flush()
case err := <-ended:
if err == nil {
// Send done signal
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
errorEvent := MCPErrorEvent{
Type: "error",
Message: err.Error(),
}
errorData, marshalErr := json.Marshal(errorEvent)
if marshalErr != nil {
fmt.Fprintf(c.Response().Writer, "data: {\"type\":\"error\",\"message\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
}
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
}
}
log.Debug().Msgf("Stream ended")
return nil
}
}

View File

@@ -1,47 +1,46 @@
package localai
import (
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/services"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
// @Summary Prometheus metrics endpoint
// @Param request body config.Gallery true "Gallery details"
// @Router /metrics [get]
func LocalAIMetricsEndpoint() echo.HandlerFunc {
return echo.WrapHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c echo.Context) bool
metricsService *services.LocalAIMetricsService
}
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc {
cfg := apiMiddlewareConfig{
metricsService: metrics,
Filter: func(c echo.Context) bool {
return c.Path() == "/metrics"
},
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if cfg.Filter != nil && cfg.Filter(c) {
return next(c)
}
path := c.Path()
method := c.Request().Method
start := time.Now()
err := next(c)
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metricsService.ObserveAPICall(method, path, elapsed)
return err
}
}
}
package localai
import (
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/adaptor"
"github.com/mudler/LocalAI/core/services"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI
// @Summary Prometheus metrics endpoint
// @Param request body config.Gallery true "Gallery details"
// @Router /metrics [get]
func LocalAIMetricsEndpoint() fiber.Handler {
return adaptor.HTTPHandler(promhttp.Handler())
}
type apiMiddlewareConfig struct {
Filter func(c *fiber.Ctx) bool
metricsService *services.LocalAIMetricsService
}
func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) fiber.Handler {
cfg := apiMiddlewareConfig{
metricsService: metrics,
Filter: func(c *fiber.Ctx) bool {
return c.Path() == "/metrics"
},
}
return func(c *fiber.Ctx) error {
if cfg.Filter != nil && cfg.Filter(c) {
return c.Next()
}
path := c.Path()
method := c.Method()
start := time.Now()
err := c.Next()
elapsed := float64(time.Since(start)) / float64(time.Second)
cfg.metricsService.ObserveAPICall(method, path, elapsed)
return err
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/p2p"
"github.com/mudler/LocalAI/core/schema"
@@ -11,10 +11,10 @@ import (
// @Summary Returns available P2P nodes
// @Success 200 {object} []schema.P2PNodesResponse "Response"
// @Router /api/p2p [get]
func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
func ShowP2PNodes(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
// Render index
return func(c echo.Context) error {
return c.JSON(200, schema.P2PNodesResponse{
return func(c *fiber.Ctx) error {
return c.JSON(schema.P2PNodesResponse{
Nodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.WorkerID)),
FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)),
})
@@ -25,6 +25,6 @@ func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc {
// @Summary Show the P2P token
// @Success 200 {string} string "Response"
// @Router /api/p2p/token [get]
func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) }
func ShowP2PToken(appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error { return c.Send([]byte(appConfig.P2PToken)) }
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
@@ -9,11 +9,11 @@ import (
"github.com/mudler/LocalAI/pkg/store"
)
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresSet)
if err := c.Bind(input); err != nil {
if err := c.BodyParser(input); err != nil {
return err
}
@@ -28,20 +28,20 @@ func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
vals[i] = []byte(v)
}
err = store.SetCols(c.Request().Context(), sb, input.Keys, vals)
err = store.SetCols(c.Context(), sb, input.Keys, vals)
if err != nil {
return err
}
return c.NoContent(200)
return c.Send(nil)
}
}
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresDelete)
if err := c.Bind(input); err != nil {
if err := c.BodyParser(input); err != nil {
return err
}
@@ -51,19 +51,19 @@ func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationCo
}
defer sl.Close()
if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil {
if err := store.DeleteCols(c.Context(), sb, input.Keys); err != nil {
return err
}
return c.NoContent(200)
return c.Send(nil)
}
}
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresGet)
if err := c.Bind(input); err != nil {
if err := c.BodyParser(input); err != nil {
return err
}
@@ -73,7 +73,7 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
}
defer sl.Close()
keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys)
keys, vals, err := store.GetCols(c.Context(), sb, input.Keys)
if err != nil {
return err
}
@@ -87,15 +87,15 @@ func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfi
res.Values[i] = string(v)
}
return c.JSON(200, res)
return c.JSON(res)
}
}
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input := new(schema.StoresFind)
if err := c.Bind(input); err != nil {
if err := c.BodyParser(input); err != nil {
return err
}
@@ -105,7 +105,7 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
}
defer sl.Close()
keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk)
keys, vals, similarities, err := store.Find(c.Context(), sb, input.Key, input.Topk)
if err != nil {
return err
}
@@ -120,6 +120,6 @@ func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConf
res.Values[i] = string(v)
}
return c.JSON(200, res)
return c.JSON(res)
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/pkg/model"
@@ -11,8 +11,8 @@ import (
// @Summary Show the LocalAI instance information
// @Success 200 {object} schema.SystemInformationResponse "Response"
// @Router /system [get]
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
availableBackends := []string{}
loadedModels := ml.ListLoadedModels()
for b := range appConfig.ExternalGRPCBackends {
@@ -26,7 +26,7 @@ func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConf
for _, m := range loadedModels {
sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID})
}
return c.JSON(200,
return c.JSON(
schema.SystemInformationResponse{
Backends: availableBackends,
Models: sysmodels,

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -14,22 +14,22 @@ import (
// @Param request body schema.TokenizeRequest true "Request"
// @Success 200 {object} schema.TokenizeResponse "Response"
// @Router /v1/tokenize [post]
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig)
if err != nil {
return err
}
return c.JSON(200, tokenResponse)
return ctx.JSON(tokenResponse)
}
}

View File

@@ -1,14 +1,12 @@
package localai
import (
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/schema"
"github.com/rs/zerolog/log"
@@ -24,16 +22,16 @@ import (
// @Success 200 {string} binary "generated audio/wav file"
// @Router /v1/audio/speech [post]
// @Router /tts [post]
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request received")
@@ -61,6 +59,6 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
return err
}
return c.Attachment(filePath, filepath.Base(filePath))
return c.Download(filePath)
}
}

View File

@@ -1,7 +1,7 @@
package localai
import (
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -16,26 +16,26 @@ import (
// @Param request body schema.VADRequest true "query params"
// @Success 200 {object} proto.VADResponse "Response"
// @Router /vad [post]
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
log.Debug().Str("model", input.Model).Msg("LocalAI VAD Request received")
resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg)
resp, err := backend.VAD(input, c.Context(), ml, appConfig, *cfg)
if err != nil {
return err
}
return c.JSON(200, resp)
return c.JSON(resp)
}
}

View File

@@ -7,20 +7,19 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
"github.com/gofiber/fiber/v2"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -65,18 +64,18 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.VideoRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /video [post]
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
if !ok || input.Model == "" {
log.Error().Msg("Video Endpoint - Invalid Input")
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
log.Error().Msg("Video Endpoint - Invalid Config")
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
src := ""
@@ -165,7 +164,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
return err
}
baseURL := middleware.BaseURL(c)
baseURL := c.BaseURL()
fn, err := backend.VideoGeneration(
height,
@@ -202,10 +201,7 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL, err = url.JoinPath(baseURL, "generated-videos", base)
if err != nil {
return err
}
item.URL = baseURL + "/generated-videos/" + base
}
id := uuid.New().String()
@@ -220,6 +216,6 @@ func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(200, resp)
return c.JSON(resp)
}
}

View File

@@ -1,20 +1,18 @@
package localai
import (
"strings"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/internal"
"github.com/mudler/LocalAI/pkg/model"
)
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc {
return func(c echo.Context) error {
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
return func(c *fiber.Ctx) error {
modelConfigs := cl.GetAllModelsConfigs()
galleryConfigs := map[string]*gallery.ModelConfig{}
@@ -42,10 +40,10 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
// Get model statuses to display in the UI the operation in progress
processingModels, taskTypes := opcache.GetStatus()
summary := map[string]interface{}{
summary := fiber.Map{
"Title": "LocalAI API - " + internal.PrintableVersion(),
"Version": internal.PrintableVersion(),
"BaseURL": middleware.BaseURL(c),
"BaseURL": utils.BaseURL(c),
"Models": modelsWithoutConfig,
"ModelsConfig": modelConfigs,
"GalleryConfig": galleryConfigs,
@@ -56,21 +54,12 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
"InstalledBackends": installedBackends,
}
contentType := c.Request().Header.Get("Content-Type")
accept := c.Request().Header.Get("Accept")
// Default to HTML if Accept header is empty (browser behavior)
// Only return JSON if explicitly requested or Content-Type is application/json
if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) {
if string(c.Context().Request.Header.ContentType()) == "application/json" || len(c.Accepts("html")) == 0 {
// The client expects a JSON response
return c.JSON(200, summary)
return c.Status(fiber.StatusOK).JSON(summary)
} else {
// Check if this is the manage route
templateName := "views/index"
if strings.HasSuffix(c.Request().URL.Path, "/manage") || c.Request().URL.Path == "/manage" {
templateName = "views/manage"
}
// Render appropriate template
return c.Render(200, templateName, summary)
// Render index
return c.Render("views/index", summary)
}
}
}

View File

@@ -1,12 +1,14 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -17,6 +19,7 @@ import (
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create
@@ -24,7 +27,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/chat/completions [post]
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc {
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
var id, textContentToReturn string
var created int
@@ -33,7 +36,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
@@ -53,7 +56,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0, FinishReason: nil}},
Choices: []schema.Choice{{Delta: &schema.Message{Content: &s}, Index: 0}},
Object: "chat.completion.chunk",
Usage: usage,
}
@@ -87,7 +90,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}},
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant", Content: &textContentToReturn}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
@@ -111,7 +114,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
ID: id,
Created: created,
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0, FinishReason: nil}},
Choices: []schema.Choice{{Delta: &schema.Message{Content: &result}, Index: 0}},
Object: "chat.completion.chunk",
Usage: usage,
}
@@ -139,10 +142,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
},
},
},
},
Index: 0,
FinishReason: nil,
}},
}}},
Object: "chat.completion.chunk",
}
responses <- initialMessage
@@ -165,10 +165,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
},
},
},
},
Index: 0,
FinishReason: nil,
}},
}}},
Object: "chat.completion.chunk",
}
}
@@ -178,21 +175,21 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
return err
}
return func(c echo.Context) error {
return func(c *fiber.Ctx) error {
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
extraUsage := c.Get("Extra-Usage", "") != ""
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
log.Debug().Msgf("Chat endpoint configuration read: %+v", config)
@@ -220,7 +217,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
}
// If we are using a response format, we need to generate a grammar for it
if config.ResponseFormatMap != nil {
d := schema.ChatCompletionResponseFormat{}
dat, err := json.Marshal(config.ResponseFormatMap)
@@ -264,7 +260,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
switch {
// Generates grammar with internal's LocalAI engine
case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn:
noActionGrammar := functions.Function{
Name: noActionName,
@@ -288,7 +283,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
funcs = funcs.Select(config.FunctionToCall())
}
// Update input grammar or json_schema based on use_llama_grammar option
// Update input grammar
jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey)
g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...)
if err == nil {
@@ -303,7 +298,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
} else {
log.Error().Err(err).Msg("Failed generating grammar")
}
default:
// Force picking one of the functions by the request
if config.FunctionToCall() != "" {
@@ -322,7 +316,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
// If we are using the tokenizer template, we don't need to process the messages
// unless we are processing functions
if !config.TemplateConfig.UseTokenizerTemplate {
if !config.TemplateConfig.UseTokenizerTemplate || shouldUseFn {
predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn)
log.Debug().Msgf("Prompt (after templating): %s", predInput)
@@ -335,10 +329,13 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
case toStream:
log.Debug().Msgf("Stream request received")
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().Header().Set("X-Correlation-ID", id)
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
// c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
c.Set("X-Correlation-ID", id)
responses := make(chan schema.OpenAIResponse)
ended := make(chan error, 1)
@@ -351,101 +348,89 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
}()
usage := &schema.OpenAIUsage{}
toolsCalled := false
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
usage := &schema.OpenAIUsage{}
toolsCalled := false
LOOP:
for {
select {
case <-input.Context.Done():
// Context was cancelled (client disconnected or request cancelled)
log.Debug().Msgf("Request context cancelled, stopping stream")
input.Cancel()
break LOOP
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
input.Cancel()
continue
}
log.Debug().Msgf("Sending chunk: %s", string(respData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
return err
}
c.Response().Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
LOOP:
for {
select {
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
usage = &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it
if len(ev.Choices[0].Delta.ToolCalls) > 0 {
toolsCalled = true
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
_, err := fmt.Fprintf(w, "data: %v\n", buf.String())
if err != nil {
log.Debug().Msgf("Sending chunk failed: %v", err)
input.Cancel()
}
w.Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &stopReason,
Index: 0,
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, marshalErr := json.Marshal(resp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
}
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: "stop",
Index: 0,
Delta: &schema.Message{Content: "Internal error: " + err.Error()},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
return nil
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
return
}
}
}
finishReason := FinishReasonStop
if toolsCalled && len(input.Tools) > 0 {
finishReason = FinishReasonToolCalls
} else if toolsCalled {
finishReason = FinishReasonFunctionCall
}
finishReason := "stop"
if toolsCalled && len(input.Tools) > 0 {
finishReason = "tool_calls"
} else if toolsCalled {
finishReason = "function_call"
}
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: &finishReason,
Index: 0,
Delta: &schema.Message{},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
FinishReason: finishReason,
Index: 0,
Delta: &schema.Message{Content: &textContentToReturn},
}},
Object: "chat.completion.chunk",
Usage: *usage,
}
respData, _ := json.Marshal(resp)
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
log.Debug().Msgf("Stream ended")
}))
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
log.Debug().Msgf("Stream ended")
return nil
// no streaming mode
@@ -454,8 +439,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
tokenCallback := func(s string, c *[]schema.Choice) {
if !shouldUseFn {
// no function is called, just reply and use stop as finish reason
stopReason := FinishReasonStop
*c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}
@@ -473,14 +457,12 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
return
}
stopReason := FinishReasonStop
*c = append(*c, schema.Choice{
FinishReason: &stopReason,
FinishReason: "stop",
Message: &schema.Message{Role: "assistant", Content: &result}})
default:
toolCallsReason := FinishReasonToolCalls
toolChoice := schema.Choice{
FinishReason: &toolCallsReason,
FinishReason: "tool_calls",
Message: &schema.Message{
Role: "assistant",
},
@@ -504,9 +486,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
)
} else {
// otherwise we return more choices directly (deprecated)
functionCallReason := FinishReasonFunctionCall
*c = append(*c, schema.Choice{
FinishReason: &functionCallReason,
FinishReason: "function_call",
Message: &schema.Message{
Role: "assistant",
Content: &textContentToReturn,
@@ -527,9 +508,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
}
// Echo properly supports context cancellation via c.Request().Context()
// No workaround needed!
result, tokenUsage, err := ComputeChoices(
input,
predInput,
@@ -565,7 +543,7 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
log.Debug().Msgf("Response: %s", respData)
// Return the prediction in the response body
return c.JSON(200, resp)
return c.JSON(resp)
}
}
}
@@ -619,48 +597,7 @@ func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, in
audios = append(audios, m.StringAudios...)
}
// Serialize tools and tool_choice to JSON strings
toolsJSON := ""
if len(input.Tools) > 0 {
toolsBytes, err := json.Marshal(input.Tools)
if err == nil {
toolsJSON = string(toolsBytes)
}
}
toolChoiceJSON := ""
if input.ToolsChoice != nil {
toolChoiceBytes, err := json.Marshal(input.ToolsChoice)
if err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
}
// Extract logprobs from request
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
var logprobs *int
var topLogprobs *int
if input.Logprobs.IsEnabled() {
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
if input.TopLogprobs != nil {
topLogprobs = input.TopLogprobs
// For backend compatibility, set logprobs to the top_logprobs value
logprobs = input.TopLogprobs
} else {
// Default to 1 if logprobs is true but top_logprobs not specified
val := 1
logprobs = &val
topLogprobs = &val
}
}
// Extract logit_bias from request
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
var logitBias map[string]float64
if len(input.LogitBias) > 0 {
logitBias = input.LogitBias
}
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, cl, o, nil)
if err != nil {
log.Error().Err(err).Msg("model inference failed")
return "", err

View File

@@ -1,22 +1,25 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions
@@ -24,7 +27,7 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/completions [post]
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
created := int(time.Now().Unix())
@@ -44,9 +47,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
Text: s,
FinishReason: nil,
Index: 0,
Text: s,
},
},
Object: "text_completion",
@@ -62,25 +64,22 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
return err
}
return func(c echo.Context) error {
return func(c *fiber.Ctx) error {
created := int(time.Now().Unix())
// Handle Correlation
id := c.Request().Header.Get("X-Correlation-ID")
if id == "" {
id = uuid.New().String()
}
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
id := c.Get("X-Correlation-ID", uuid.New().String())
extraUsage := c.Get("Extra-Usage", "") != ""
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
if config.ResponseFormatMap != nil {
@@ -98,10 +97,15 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
if input.Stream {
log.Debug().Msgf("Stream request received")
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Context().SetContentType("text/event-stream")
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
//c.Set("Content-Type", "text/event-stream")
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
}
if input.Stream {
if len(config.PromptStrings) > 1 {
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
}
@@ -126,78 +130,53 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
ended <- process(id, predInput, input, config, ml, responses, extraUsage)
}()
LOOP:
for {
select {
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
respData, err := json.Marshal(ev)
if err != nil {
log.Debug().Msgf("Failed to marshal response: %v", err)
continue
}
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
log.Debug().Msgf("Sending chunk: %s", string(respData))
_, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData))
if err != nil {
return err
}
c.Response().Flush()
case err := <-ended:
if err == nil {
LOOP:
for {
select {
case ev := <-responses:
if len(ev.Choices) == 0 {
log.Debug().Msgf("No choices in the response, skipping")
continue
}
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
enc.Encode(ev)
log.Debug().Msgf("Sending chunk: %s", buf.String())
fmt.Fprintf(w, "data: %v\n", buf.String())
w.Flush()
case err := <-ended:
if err == nil {
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
fmt.Fprintf(w, "data: %v\n", "Internal error: "+err.Error())
w.Flush()
break LOOP
}
log.Error().Msgf("Stream ended with error: %v", err)
stopReason := FinishReasonStop
errorResp := schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model,
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
Text: "Internal error: " + err.Error(),
},
},
Object: "text_completion",
}
errorData, marshalErr := json.Marshal(errorResp)
if marshalErr != nil {
log.Error().Msgf("Failed to marshal error response: %v", marshalErr)
// Send a simple error message as fallback
fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n")
} else {
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData))
}
c.Response().Flush()
return nil
}
}
stopReason := FinishReasonStop
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: &stopReason,
resp := &schema.OpenAIResponse{
ID: id,
Created: created,
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
Choices: []schema.Choice{
{
Index: 0,
FinishReason: "stop",
},
},
},
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
Object: "text_completion",
}
respData, _ := json.Marshal(resp)
fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData)
fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n")
c.Response().Flush()
return nil
w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
w.WriteString("data: [DONE]\n\n")
w.Flush()
}))
return <-ended
}
var result []schema.Choice
@@ -218,8 +197,7 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
r, tokenUsage, err := ComputeChoices(
input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) {
stopReason := FinishReasonStop
*c = append(*c, schema.Choice{Text: s, FinishReason: &stopReason, Index: k})
*c = append(*c, schema.Choice{Text: s, FinishReason: "stop", Index: k})
}, nil)
if err != nil {
return err
@@ -253,6 +231,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(200, resp)
return c.JSON(resp)
}
}

View File

@@ -1,8 +0,0 @@
package openai
// Finish reason constants for OpenAI API responses
const (
FinishReasonStop = "stop"
FinishReasonToolCalls = "tool_calls"
FinishReasonFunctionCall = "function_call"
)

View File

@@ -4,11 +4,11 @@ import (
"encoding/json"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
@@ -23,20 +23,20 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/edits [post]
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c echo.Context) error {
return func(c *fiber.Ctx) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
// Opt-in extra usage flag
extraUsage := c.Request().Header.Get("Extra-Usage") != ""
extraUsage := c.Get("Extra-Usage", "") != ""
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
log.Debug().Msgf("Edit Endpoint Input : %+v", input)
@@ -98,6 +98,6 @@ func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(200, resp)
return c.JSON(resp)
}
}

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -13,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
@@ -21,16 +21,16 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/embeddings [post]
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
log.Debug().Msgf("Parameter Config: %+v", config)
@@ -78,6 +78,6 @@ func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(200, resp)
return c.JSON(resp)
}
}

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
@@ -15,13 +14,13 @@ import (
"time"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/backend"
"github.com/gofiber/fiber/v2"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
)
@@ -66,18 +65,18 @@ func downloadFile(url string) (string, error) {
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /v1/images/generations [post]
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
log.Error().Msg("Image Endpoint - Invalid Input")
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
log.Error().Msg("Image Endpoint - Invalid Config")
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
// Process input images (for img2img/inpainting)
@@ -189,7 +188,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
return err
}
baseURL := middleware.BaseURL(c)
baseURL := c.BaseURL()
// Use the first input image as src if available, otherwise use the original src
inputSrc := src
@@ -216,10 +215,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
item.B64JSON = base64.StdEncoding.EncodeToString(data)
} else {
base := filepath.Base(output)
item.URL, err = url.JoinPath(baseURL, "generated-images", base)
if err != nil {
return err
}
item.URL = baseURL + "/generated-images/" + base
}
result = append(result, *item)
@@ -238,7 +234,7 @@ func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfi
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(200, resp)
return c.JSON(resp)
}
}

View File

@@ -1,8 +1,6 @@
package openai
import (
"encoding/json"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
@@ -39,50 +37,8 @@ func ComputeChoices(
audios = append(audios, m.StringAudios...)
}
// Serialize tools and tool_choice to JSON strings
toolsJSON := ""
if len(req.Tools) > 0 {
toolsBytes, err := json.Marshal(req.Tools)
if err == nil {
toolsJSON = string(toolsBytes)
}
}
toolChoiceJSON := ""
if req.ToolsChoice != nil {
toolChoiceBytes, err := json.Marshal(req.ToolsChoice)
if err == nil {
toolChoiceJSON = string(toolChoiceBytes)
}
}
// Extract logprobs from request
// According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position
var logprobs *int
var topLogprobs *int
if req.Logprobs.IsEnabled() {
// If logprobs is enabled, use top_logprobs if provided, otherwise default to 1
if req.TopLogprobs != nil {
topLogprobs = req.TopLogprobs
// For backend compatibility, set logprobs to the top_logprobs value
logprobs = req.TopLogprobs
} else {
// Default to 1 if logprobs is true but top_logprobs not specified
val := 1
logprobs = &val
topLogprobs = &val
}
}
// Extract logit_bias from request
// According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100)
var logitBias map[string]float64
if len(req.LogitBias) > 0 {
logitBias = req.LogitBias
}
// get the model function to call for the result
predFunc, err := backend.ModelInference(
req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias)
predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback)
if err != nil {
return result, backend.TokenUsage{}, err
}
@@ -103,11 +59,6 @@ func ComputeChoices(
finetunedResponse := backend.Finetune(*config, predInput, prediction.Response)
cb(finetunedResponse, &result)
// Add logprobs to the last choice if present
if prediction.Logprobs != nil && len(result) > 0 {
result[len(result)-1].Logprobs = prediction.Logprobs
}
//result = append(result, Choice{Text: prediction})
}

View File

@@ -1,7 +1,7 @@
package openai
import (
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
@@ -12,15 +12,14 @@ import (
// @Summary List and describe the various models available in the API.
// @Success 200 {object} schema.ModelsDataResponse "Response"
// @Router /v1/models [get]
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
// If blank, no filter is applied.
filter := c.QueryParam("filter")
filter := c.Query("filter")
// By default, exclude any loose files that are already referenced by a configuration file.
var policy services.LooseFilePolicy
excludeConfigured := c.QueryParam("excludeConfigured")
if excludeConfigured == "" || excludeConfigured == "true" {
if c.QueryBool("excludeConfigured", true) {
policy = services.SKIP_IF_CONFIGURED
} else {
policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user?
@@ -42,7 +41,7 @@ func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, ap
dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"})
}
return c.JSON(200, schema.ModelsDataResponse{
return c.JSON(schema.ModelsDataResponse{
Object: "list",
Data: dataModels,
})

View File

@@ -1,18 +1,17 @@
package openai
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/templates"
@@ -26,27 +25,24 @@ import (
// @Param request body schema.OpenAIRequest true "query params"
// @Success 200 {object} schema.OpenAIResponse "Response"
// @Router /mcp/v1/completions [post]
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
// We do not support streaming mode (Yet?)
return func(c echo.Context) error {
return func(c *fiber.Ctx) error {
created := int(time.Now().Unix())
ctx := c.Request().Context()
ctx := c.Context()
// Handle Correlation
id := c.Request().Header.Get("X-Correlation-ID")
if id == "" {
id = uuid.New().String()
}
id := c.Get("X-Correlation-ID", uuid.New().String())
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
if config.MCP.Servers == "" && config.MCP.Stdio == "" {
@@ -54,15 +50,12 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
}
// Get MCP config from model config
remote, stdio, err := config.MCP.MCPConfigFromYAML()
if err != nil {
return fmt.Errorf("failed to get MCP config: %w", err)
}
remote, stdio := config.MCP.MCPConfigFromYAML()
// Check if we have tools in cache, or we have to have an initial connection
sessions, err := mcpTools.SessionsFromMCPConfig(config.Name, remote, stdio)
if err != nil {
return fmt.Errorf("failed to get MCP sessions: %w", err)
return err
}
if len(sessions) == 0 {
@@ -80,37 +73,46 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
if appConfig.ApiKeys != nil {
apiKey = appConfig.ApiKeys[0]
}
ctxWithCancellation, cancel := context.WithCancel(ctx)
defer cancel()
// TODO: instead of connecting to the API, we should just wire this internally
// and act like completion.go.
// We can do this as cogito expects an interface and we can create one that
// we satisfy to just call internally ComputeChoices
defaultLLM := cogito.NewOpenAILLM(config.Name, apiKey, "http://127.0.0.1:"+port)
// Build cogito options using the consolidated method
cogitoOpts := config.BuildCogitoOptions()
cogitoOpts = append(
cogitoOpts,
cogito.WithContext(ctxWithCancellation),
cogito.WithMCPs(sessions...),
cogitoOpts := []cogito.Option{
cogito.WithStatusCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Status: %s", config.Name, s)
}),
cogito.WithReasoningCallback(func(s string) {
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
}),
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments)
return true
}),
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments)
}),
)
cogito.WithContext(ctx),
cogito.WithMCPs(sessions...),
cogito.WithIterations(3), // default to 3 iterations
cogito.WithMaxAttempts(3), // default to 3 attempts
cogito.WithForceReasoning(),
}
if config.Agent.EnableReasoning {
cogitoOpts = append(cogitoOpts, cogito.EnableToolReasoner)
}
if config.Agent.EnablePlanning {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan)
}
if config.Agent.EnableMCPPrompts {
cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts)
}
if config.Agent.EnablePlanReEvaluator {
cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator)
}
if config.Agent.MaxIterations != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithIterations(config.Agent.MaxIterations))
}
if config.Agent.MaxAttempts != 0 {
cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(config.Agent.MaxAttempts))
}
f, err := cogito.ExecuteTools(
defaultLLM, fragment,
@@ -137,6 +139,6 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
log.Debug().Msgf("Response: %s", jsonResult)
// Return the prediction in the response body
return c.JSON(200, resp)
return c.JSON(resp)
}
}

View File

@@ -10,11 +10,9 @@ import (
"sync"
"time"
"net/http"
"github.com/go-audio/audio"
"github.com/gorilla/websocket"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
@@ -169,50 +167,32 @@ type Model interface {
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins
},
}
// TODO: Implement ephemeral keys to allow these endpoints to be used
func RealtimeSessions(application *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
return c.NoContent(501)
func RealtimeSessions(application *application.Application) fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.SendStatus(501)
}
}
func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
return c.NoContent(501)
func RealtimeTranscriptionSession(application *application.Application) fiber.Handler {
return func(ctx *fiber.Ctx) error {
return ctx.SendStatus(501)
}
}
func Realtime(application *application.Application) echo.HandlerFunc {
return func(c echo.Context) error {
ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil)
if err != nil {
return err
}
defer ws.Close()
// Extract query parameters from Echo context before passing to websocket handler
model := c.QueryParam("model")
if model == "" {
model = "gpt-4o"
}
intent := c.QueryParam("intent")
registerRealtime(application, model, intent)(ws)
return nil
}
func Realtime(application *application.Application) fiber.Handler {
return websocket.New(registerRealtime(application))
}
func registerRealtime(application *application.Application, model, intent string) func(c *websocket.Conn) {
func registerRealtime(application *application.Application) func(c *websocket.Conn) {
return func(c *websocket.Conn) {
evaluator := application.TemplatesEvaluator()
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
model := c.Query("model", "gpt-4o")
intent := c.Query("intent")
if intent != "transcription" {
sendNotImplemented(c, "Only transcription mode is supported which requires the intent=transcription parameter")
}
@@ -1087,13 +1067,12 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
// For example, the model might return a special token or JSON indicating a function call
/*
predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil, "", "", nil, nil, nil)
predFunc, err := backend.ModelInference(context.Background(), prompt, input.Messages, images, videos, audios, ml, *config, o, nil)
result, tokenUsage, err := ComputeChoices(input, prompt, config, startupOptions, ml, func(s string, c *[]schema.Choice) {
if !shouldUseFn {
// no function is called, just reply and use stop as finish reason
stopReason := FinishReasonStop
*c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
*c = append(*c, schema.Choice{FinishReason: "stop", Index: 0, Message: &schema.Message{Role: "assistant", Content: &s}})
return
}
@@ -1120,8 +1099,7 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
}
if len(input.Tools) > 0 {
toolCallsReason := FinishReasonToolCalls
toolChoice.FinishReason = &toolCallsReason
toolChoice.FinishReason = "tool_calls"
}
for _, ss := range results {
@@ -1142,9 +1120,8 @@ func processTextResponse(config *config.ModelConfig, session *Session, prompt st
)
} else {
// otherwise we return more choices directly
functionCallReason := FinishReasonFunctionCall
*c = append(*c, schema.Choice{
FinishReason: &functionCallReason,
FinishReason: "function_call",
Message: &schema.Message{
Role: "assistant",
Content: &textContentToReturn,

View File

@@ -7,13 +7,13 @@ import (
"path"
"path/filepath"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/schema"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
@@ -24,19 +24,19 @@ import (
// @Param file formData file true "file"
// @Success 200 {object} map[string]string "Response"
// @Router /v1/audio/transcriptions [post]
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || config == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
diarize := c.FormValue("diarize") != "false"
diarize := c.FormValue("diarize", "false") != "false"
// retrieve the file data from the request
file, err := c.FormFile("file")
@@ -76,6 +76,6 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here
return c.JSON(http.StatusOK, tr)
return c.Status(http.StatusOK).JSON(tr)
}
}

View File

@@ -6,7 +6,7 @@ import (
"strconv"
"strings"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/middleware"
@@ -14,24 +14,20 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
return func(c echo.Context) error {
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error {
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input == nil {
return echo.ErrBadRequest
return fiber.ErrBadRequest
}
var raw map[string]interface{}
body := make([]byte, 0)
if c.Request().Body != nil {
c.Request().Body.Read(body)
}
if len(body) > 0 {
if body := c.Body(); len(body) > 0 {
_ = json.Unmarshal(body, &raw)
}
// Build VideoRequest using shared mapper
vr := MapOpenAIToVideo(input, raw)
// Place VideoRequest into context so localai.VideoEndpoint can consume it
c.Set(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
// Place VideoRequest into locals so localai.VideoEndpoint can consume it
c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, vr)
// Delegate to existing localai handler
return localai.VideoEndpoint(cl, ml, appConfig)(c)
}

View File

@@ -1,50 +1,48 @@
package http
import (
"io/fs"
"net/http"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/favicon"
"github.com/gofiber/fiber/v2/middleware/filesystem"
"github.com/mudler/LocalAI/core/explorer"
"github.com/mudler/LocalAI/core/http/middleware"
"github.com/mudler/LocalAI/core/http/routes"
"github.com/rs/zerolog/log"
)
func Explorer(db *explorer.Database) *echo.Echo {
e := echo.New()
func Explorer(db *explorer.Database) *fiber.App {
// Set renderer
e.Renderer = renderEngine()
// Hide banner
e.HideBanner = true
e.Pre(middleware.StripPathPrefix())
routes.RegisterExplorerRoutes(e, db)
// Favicon handler
e.GET("/favicon.svg", func(c echo.Context) error {
data, err := embedDirStatic.ReadFile("static/favicon.svg")
if err != nil {
return c.NoContent(http.StatusNotFound)
}
c.Response().Header().Set("Content-Type", "image/svg+xml")
return c.Blob(http.StatusOK, "image/svg+xml", data)
})
// Static files - use fs.Sub to create a filesystem rooted at "static"
staticFS, err := fs.Sub(embedDirStatic, "static")
if err != nil {
// Log error but continue - static files might not work
log.Error().Err(err).Msg("failed to create static filesystem")
} else {
e.StaticFS("/static", staticFS)
fiberCfg := fiber.Config{
Views: renderEngine(),
// We disable the Fiber startup message as it does not conform to structured logging.
// We register a startup log line with connection information in the OnListen hook to keep things user friendly though
DisableStartupMessage: false,
// Override default error handler
}
app := fiber.New(fiberCfg)
app.Use(middleware.StripPathPrefix())
routes.RegisterExplorerRoutes(app, db)
httpFS := http.FS(embedDirStatic)
app.Use(favicon.New(favicon.Config{
URL: "/favicon.svg",
FileSystem: httpFS,
File: "static/favicon.svg",
}))
app.Use("/static", filesystem.New(filesystem.Config{
Root: httpFS,
PathPrefix: "static",
Browse: true,
}))
// Define a custom 404 handler
// Note: keep this at the bottom!
e.GET("/*", notFoundHandler)
app.Use(notFoundHandler)
return e
return app
}

View File

@@ -3,108 +3,50 @@ package middleware
import (
"crypto/subtle"
"errors"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/dave-gray101/v2keyauth"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/keyauth"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/utils"
"github.com/mudler/LocalAI/core/schema"
)
var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key")
// This file contains the configuration generators and handler functions that are used along with the fiber/keyauth middleware
// Currently this requires an upstream patch - and feature patches are no longer accepted to v2
// Therefore `dave-gray101/v2keyauth` contains the v2 backport of the middleware until v3 stabilizes and we migrate.
// GetKeyAuthConfig returns Echo's KeyAuth middleware configuration
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) {
// Create validator function
validator := getApiKeyValidationFunction(applicationConfig)
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
customLookup, err := v2keyauth.MultipleKeySourceLookup([]string{"header:Authorization", "header:x-api-key", "header:xi-api-key", "cookie:token"}, keyauth.ConfigDefault.AuthScheme)
if err != nil {
return nil, err
}
// Create error handler
errorHandler := getApiKeyErrorHandler(applicationConfig)
// Create Next function (skip middleware for certain requests)
skipper := getApiKeyRequiredFilterFunction(applicationConfig)
// Wrap it with our custom key lookup that checks multiple sources
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if len(applicationConfig.ApiKeys) == 0 {
return next(c)
}
// Skip if skipper says so
if skipper != nil && skipper(c) {
return next(c)
}
// Try to extract key from multiple sources
key, err := extractKeyFromMultipleSources(c)
if err != nil {
return errorHandler(err, c)
}
// Validate the key
valid, err := validator(key, c)
if err != nil || !valid {
return errorHandler(ErrMissingOrMalformedAPIKey, c)
}
// Store key in context for later use
c.Set("api_key", key)
return next(c)
}
return &v2keyauth.Config{
CustomKeyLookup: customLookup,
Next: getApiKeyRequiredFilterFunction(applicationConfig),
Validator: getApiKeyValidationFunction(applicationConfig),
ErrorHandler: getApiKeyErrorHandler(applicationConfig),
AuthScheme: "Bearer",
}, nil
}
// extractKeyFromMultipleSources checks multiple sources for the API key
// in order: Authorization header, x-api-key header, xi-api-key header, token cookie
func extractKeyFromMultipleSources(c echo.Context) (string, error) {
// Check Authorization header first
auth := c.Request().Header.Get("Authorization")
if auth != "" {
// Check for Bearer scheme
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer "), nil
}
// If no Bearer prefix, return as-is (for backward compatibility)
return auth, nil
}
// Check x-api-key header
if key := c.Request().Header.Get("x-api-key"); key != "" {
return key, nil
}
// Check xi-api-key header
if key := c.Request().Header.Get("xi-api-key"); key != "" {
return key, nil
}
// Check token cookie
cookie, err := c.Cookie("token")
if err == nil && cookie != nil && cookie.Value != "" {
return cookie.Value, nil
}
return "", ErrMissingOrMalformedAPIKey
}
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error {
return func(err error, c echo.Context) error {
if errors.Is(err, ErrMissingOrMalformedAPIKey) {
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
return func(ctx *fiber.Ctx, err error) error {
if errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
if len(applicationConfig.ApiKeys) == 0 {
return nil // if no keys are set up, any error we get here is not an error.
return ctx.Next() // if no keys are set up, any error we get here is not an error.
}
c.Response().Header().Set("WWW-Authenticate", "Bearer")
ctx.Set("WWW-Authenticate", "Bearer")
if applicationConfig.OpaqueErrors {
return c.NoContent(http.StatusUnauthorized)
return ctx.SendStatus(401)
}
// Check if the request content type is JSON
contentType := c.Request().Header.Get("Content-Type")
contentType := string(ctx.Context().Request.Header.ContentType())
if strings.Contains(contentType, "application/json") {
return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{
return ctx.Status(401).JSON(schema.ErrorResponse{
Error: &schema.APIError{
Message: "An authentication key is required",
Code: 401,
@@ -113,69 +55,50 @@ func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(err
})
}
return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{
"BaseURL": BaseURL(c),
return ctx.Status(401).Render("views/login", fiber.Map{
"BaseURL": utils.BaseURL(ctx),
})
}
if applicationConfig.OpaqueErrors {
return c.NoContent(http.StatusInternalServerError)
return ctx.SendStatus(500)
}
return err
}
}
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) {
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {
if applicationConfig.UseSubtleKeyComparison {
return func(key string, c echo.Context) (bool, error) {
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 {
if subtle.ConstantTimeCompare([]byte(apiKey), []byte(validKey)) == 1 {
return true, nil
}
}
return false, ErrMissingOrMalformedAPIKey
return false, v2keyauth.ErrMissingOrMalformedAPIKey
}
}
return func(key string, c echo.Context) (bool, error) {
return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
if len(applicationConfig.ApiKeys) == 0 {
return true, nil // If no keys are setup, accept everything
}
for _, validKey := range applicationConfig.ApiKeys {
if key == validKey {
if apiKey == validKey {
return true, nil
}
}
return false, ErrMissingOrMalformedAPIKey
return false, v2keyauth.ErrMissingOrMalformedAPIKey
}
}
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper {
return func(c echo.Context) bool {
path := c.Request().URL.Path
// Always skip authentication for static files
if strings.HasPrefix(path, "/static/") {
return true
}
// Always skip authentication for generated content
if strings.HasPrefix(path, "/generated-audio/") ||
strings.HasPrefix(path, "/generated-images/") ||
strings.HasPrefix(path, "/generated-videos/") {
return true
}
// Skip authentication for favicon
if path == "/favicon.svg" {
return true
}
// Handle GET request exemptions if enabled
if applicationConfig.DisableApiKeyRequirementForHttpGet {
if c.Request().Method != http.MethodGet {
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
if applicationConfig.DisableApiKeyRequirementForHttpGet {
return func(c *fiber.Ctx) bool {
if c.Method() != "GET" {
return false
}
for _, rx := range applicationConfig.HttpGetExemptedEndpoints {
@@ -183,8 +106,8 @@ func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig
return true
}
}
return false
}
return false
}
return func(c *fiber.Ctx) bool { return false }
}

View File

@@ -1,48 +0,0 @@
package middleware
import (
"strings"
"github.com/labstack/echo/v4"
)
// BaseURL returns the base URL for the given HTTP request context.
// It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path.
// The returned URL is guaranteed to end with `/`.
// The method should be used in conjunction with the StripPathPrefix middleware.
func BaseURL(c echo.Context) string {
path := c.Path()
origPath := c.Request().URL.Path
// Check if StripPathPrefix middleware stored the original path
if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" {
origPath = storedPath
}
// Check X-Forwarded-Proto for scheme
scheme := "http"
if c.Request().Header.Get("X-Forwarded-Proto") == "https" {
scheme = "https"
} else if c.Request().TLS != nil {
scheme = "https"
}
// Check X-Forwarded-Host for host
host := c.Request().Host
if forwardedHost := c.Request().Header.Get("X-Forwarded-Host"); forwardedHost != "" {
host = forwardedHost
}
if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 {
prefixLen := len(origPath) - len(path)
if prefixLen > 0 && prefixLen <= len(origPath) {
pathPrefix := origPath[:prefixLen]
if !strings.HasSuffix(pathPrefix, "/") {
pathPrefix += "/"
}
return scheme + "://" + host + pathPrefix
}
}
return scheme + "://" + host + "/"
}

View File

@@ -1,58 +0,0 @@
package middleware
import (
"net/http/httptest"
"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("BaseURL", func() {
Context("without prefix", func() {
It("should return base URL without prefix", func() {
app := echo.New()
actualURL := ""
// Register route - use the actual request path so routing works
routePath := "/hello/world"
app.GET(routePath, func(c echo.Context) error {
actualURL = BaseURL(c)
return nil
})
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualURL).To(Equal("http://example.com/"), "base URL")
})
})
Context("with prefix", func() {
It("should return base URL with prefix", func() {
app := echo.New()
actualURL := ""
// Register route with the stripped path (after middleware removes prefix)
routePath := "/hello/world"
app.GET(routePath, func(c echo.Context) error {
// Simulate what StripPathPrefix middleware does - store original path
c.Set("_original_path", "/myprefix/hello/world")
// Modify the request path to simulate prefix stripping
c.Request().URL.Path = "/hello/world"
actualURL = BaseURL(c)
return nil
})
// Make request with stripped path (middleware would have already processed it)
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL")
})
})
})

View File

@@ -1,13 +0,0 @@
package middleware_test
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestMiddleware(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Middleware test suite")
}

View File

@@ -1,482 +1,458 @@
package middleware
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/rs/zerolog/log"
)
type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
type RequestExtractor struct {
modelConfigLoader *config.ModelConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
}
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
return &RequestExtractor{
modelConfigLoader: modelConfigLoader,
modelLoader: modelLoader,
applicationConfig: applicationConfig,
}
}
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
// TODO: Refactor to not return error if unchanged
func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) {
model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && model != "" {
return
}
model = c.Param("model")
if model == "" {
model = c.QueryParam("model")
}
if model == "" {
// Set model from bearer token, if available
auth := c.Request().Header.Get("Authorization")
bearer := strings.TrimPrefix(auth, "Bearer ")
if bearer != "" && bearer != auth {
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
if err == nil && exists {
model = bearer
}
}
}
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
}
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
re.setModelNameFromRequest(c)
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || localModelName == "" {
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
}
return next(c)
}
}
}
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
re.setModelNameFromRequest(c)
localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if localModelName != "" { // Don't overwrite existing values
return next(c)
}
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
if err != nil {
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
return next(c)
}
if len(modelNames) == 0 {
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
// This is non-fatal - making it so was breaking the case of direct installation of raw models
// return errors.New("this endpoint requires at least one model to be installed")
return next(c)
}
c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
return next(c)
}
}
}
// TODO: If context and cancel above belong on all methods, move that part of above into here!
// Otherwise, it's in its own method below for now
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
input := initializer()
if input == nil {
return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body")
}
if err := c.Bind(input); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err))
}
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
if input.ModelName(nil) == "" {
localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && localModelName != "" {
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
input.ModelName(&localModelName)
}
}
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
if err != nil {
log.Err(err)
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
} else if cfg.Model == "" && input.ModelName(nil) != "" {
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
cfg.Model = input.ModelName(nil)
}
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return next(c)
}
}
}
func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error {
input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return echo.ErrBadRequest
}
cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return echo.ErrBadRequest
}
// Extract or generate the correlation ID
correlationID := c.Request().Header.Get("X-Correlation-ID")
if correlationID == "" {
correlationID = uuid.New().String()
}
c.Response().Header().Set("X-Correlation-ID", correlationID)
// Use the request context directly - Echo properly supports context cancellation!
// No need for workarounds like handleConnectionCancellation
reqCtx := c.Request().Context()
c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Cancel when request context is cancelled (client disconnects)
go func() {
select {
case <-reqCtx.Done():
cancel()
case <-c1.Done():
// Already cancelled
}
}()
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
input.Context = ctxWithCorrelationID
input.Cancel = cancel
err := mergeOpenAIRequestAndModelConfig(cfg, input)
if err != nil {
return err
}
if cfg.Model == "" {
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
cfg.Model = input.Model
}
c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return nil
}
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
if input.ResponseFormat != nil {
switch responseFormat := input.ResponseFormat.(type) {
case string:
config.ResponseFormat = responseFormat
case map[string]interface{}:
config.ResponseFormatMap = responseFormat
}
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0
for i, m := range input.Messages {
nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
textContent := ""
// we will template this at the end
CONTENT:
for _, pp := range c {
switch pp.Type {
case "text":
textContent += pp.Text
//input.Messages[i].StringContent = pp.Text
case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding video: %s", err)
continue CONTENT
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding audio: %s", err)
continue CONTENT
}
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
audioIndex++
nrOfAudiosInMessage++
case "input_audio":
// TODO: make sure that we only return base64 stuff
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
audioIndex++
nrOfAudiosInMessage++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
imgIndex++
nrOfImgsInMessage++
}
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []any:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []any:
tokens := []int{}
inputStrings := []string{}
for _, ii := range i {
switch ii := ii.(type) {
case int:
tokens = append(tokens, ii)
case float64:
tokens = append(tokens, int(ii))
case string:
inputStrings = append(inputStrings, ii)
default:
log.Error().Msgf("Unknown input type: %T", ii)
}
}
config.InputToken = append(config.InputToken, tokens)
config.InputStrings = append(config.InputStrings, inputStrings...)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
// If a quality was defined as number, convert it to step
if input.Quality != "" {
q, err := strconv.Atoi(input.Quality)
if err == nil {
config.Step = q
}
}
if config.Validate() {
return nil
}
return fmt.Errorf("unable to validate configuration after merging")
}
package middleware
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services"
"github.com/mudler/LocalAI/core/templates"
"github.com/mudler/LocalAI/pkg/functions"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)
type correlationIDKeyType string
// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"
type RequestExtractor struct {
modelConfigLoader *config.ModelConfigLoader
modelLoader *model.ModelLoader
applicationConfig *config.ApplicationConfig
}
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
return &RequestExtractor{
modelConfigLoader: modelConfigLoader,
modelLoader: modelLoader,
applicationConfig: applicationConfig,
}
}
const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME"
const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST"
const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG"
// TODO: Refactor to not return error if unchanged
func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && model != "" {
return
}
model = ctx.Params("model")
if (model == "") && ctx.Query("model") != "" {
model = ctx.Query("model")
}
if model == "" {
// Set model from bearer token, if available
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
if bearer != "" {
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
if err == nil && exists {
model = bearer
}
}
}
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model)
}
func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler {
return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if !ok || localModelName == "" {
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName)
log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default")
}
return ctx.Next()
}
}
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
return func(ctx *fiber.Ctx) error {
re.setModelNameFromRequest(ctx)
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if localModelName != "" { // Don't overwrite existing values
return ctx.Next()
}
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
if err != nil {
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
return ctx.Next()
}
if len(modelNames) == 0 {
log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed")
// This is non-fatal - making it so was breaking the case of direct installation of raw models
// return errors.New("this endpoint requires at least one model to be installed")
return ctx.Next()
}
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0])
log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model")
return ctx.Next()
}
}
// TODO: If context and cancel above belong on all methods, move that part of above into here!
// Otherwise, it's in its own method below for now
func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler {
return func(ctx *fiber.Ctx) error {
input := initializer()
if input == nil {
return fmt.Errorf("unable to initialize body")
}
if err := ctx.BodyParser(input); err != nil {
return fmt.Errorf("failed parsing request body: %w", err)
}
// If this request doesn't have an associated model name, fetch it from earlier in the middleware chain
if input.ModelName(nil) == "" {
localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
if ok && localModelName != "" {
log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain")
input.ModelName(&localModelName)
}
}
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
if err != nil {
log.Err(err)
log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil))
} else if cfg.Model == "" && input.ModelName(nil) != "" {
log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input")
cfg.Model = input.ModelName(nil)
}
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return ctx.Next()
}
}
func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
if !ok || input.Model == "" {
return fiber.ErrBadRequest
}
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
if !ok || cfg == nil {
return fiber.ErrBadRequest
}
// Extract or generate the correlation ID
correlationID := ctx.Get("X-Correlation-ID", uuid.New().String())
ctx.Set("X-Correlation-ID", correlationID)
c1, cancel := context.WithCancel(re.applicationConfig.Context)
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID)
input.Context = ctxWithCorrelationID
input.Cancel = cancel
err := mergeOpenAIRequestAndModelConfig(cfg, input)
if err != nil {
return err
}
if cfg.Model == "" {
log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value")
cfg.Model = input.Model
}
ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input)
ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg)
return ctx.Next()
}
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
if input.Echo {
config.Echo = input.Echo
}
if input.TopK != nil {
config.TopK = input.TopK
}
if input.TopP != nil {
config.TopP = input.TopP
}
if input.Backend != "" {
config.Backend = input.Backend
}
if input.ClipSkip != 0 {
config.Diffusers.ClipSkip = input.ClipSkip
}
if input.NegativePromptScale != 0 {
config.NegativePromptScale = input.NegativePromptScale
}
if input.NegativePrompt != "" {
config.NegativePrompt = input.NegativePrompt
}
if input.RopeFreqBase != 0 {
config.RopeFreqBase = input.RopeFreqBase
}
if input.RopeFreqScale != 0 {
config.RopeFreqScale = input.RopeFreqScale
}
if input.Grammar != "" {
config.Grammar = input.Grammar
}
if input.Temperature != nil {
config.Temperature = input.Temperature
}
if input.Maxtokens != nil {
config.Maxtokens = input.Maxtokens
}
if input.ResponseFormat != nil {
switch responseFormat := input.ResponseFormat.(type) {
case string:
config.ResponseFormat = responseFormat
case map[string]interface{}:
config.ResponseFormatMap = responseFormat
}
}
switch stop := input.Stop.(type) {
case string:
if stop != "" {
config.StopWords = append(config.StopWords, stop)
}
case []interface{}:
for _, pp := range stop {
if s, ok := pp.(string); ok {
config.StopWords = append(config.StopWords, s)
}
}
}
if len(input.Tools) > 0 {
for _, tool := range input.Tools {
input.Functions = append(input.Functions, tool.Function)
}
}
if input.ToolsChoice != nil {
var toolChoice functions.Tool
switch content := input.ToolsChoice.(type) {
case string:
_ = json.Unmarshal([]byte(content), &toolChoice)
case map[string]interface{}:
dat, _ := json.Marshal(content)
_ = json.Unmarshal(dat, &toolChoice)
}
input.FunctionCall = map[string]interface{}{
"name": toolChoice.Function.Name,
}
}
// Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0
for i, m := range input.Messages {
nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0
switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)
textContent := ""
// we will template this at the end
CONTENT:
for _, pp := range c {
switch pp.Type {
case "text":
textContent += pp.Text
//input.Messages[i].StringContent = pp.Text
case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding video: %s", err)
continue CONTENT
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding audio: %s", err)
continue CONTENT
}
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
audioIndex++
nrOfAudiosInMessage++
case "input_audio":
// TODO: make sure that we only return base64 stuff
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
audioIndex++
nrOfAudiosInMessage++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
if err != nil {
log.Error().Msgf("Failed encoding image: %s", err)
continue CONTENT
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
imgIndex++
nrOfImgsInMessage++
}
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
}
}
if input.RepeatPenalty != 0 {
config.RepeatPenalty = input.RepeatPenalty
}
if input.FrequencyPenalty != 0 {
config.FrequencyPenalty = input.FrequencyPenalty
}
if input.PresencePenalty != 0 {
config.PresencePenalty = input.PresencePenalty
}
if input.Keep != 0 {
config.Keep = input.Keep
}
if input.Batch != 0 {
config.Batch = input.Batch
}
if input.IgnoreEOS {
config.IgnoreEOS = input.IgnoreEOS
}
if input.Seed != nil {
config.Seed = input.Seed
}
if input.TypicalP != nil {
config.TypicalP = input.TypicalP
}
log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input))
switch inputs := input.Input.(type) {
case string:
if inputs != "" {
config.InputStrings = append(config.InputStrings, inputs)
}
case []any:
for _, pp := range inputs {
switch i := pp.(type) {
case string:
config.InputStrings = append(config.InputStrings, i)
case []any:
tokens := []int{}
inputStrings := []string{}
for _, ii := range i {
switch ii := ii.(type) {
case int:
tokens = append(tokens, ii)
case float64:
tokens = append(tokens, int(ii))
case string:
inputStrings = append(inputStrings, ii)
default:
log.Error().Msgf("Unknown input type: %T", ii)
}
}
config.InputToken = append(config.InputToken, tokens)
config.InputStrings = append(config.InputStrings, inputStrings...)
}
}
}
// Can be either a string or an object
switch fnc := input.FunctionCall.(type) {
case string:
if fnc != "" {
config.SetFunctionCallString(fnc)
}
case map[string]interface{}:
var name string
n, exists := fnc["name"]
if exists {
nn, e := n.(string)
if e {
name = nn
}
}
config.SetFunctionCallNameString(name)
}
switch p := input.Prompt.(type) {
case string:
config.PromptStrings = append(config.PromptStrings, p)
case []interface{}:
for _, pp := range p {
if s, ok := pp.(string); ok {
config.PromptStrings = append(config.PromptStrings, s)
}
}
}
// If a quality was defined as number, convert it to step
if input.Quality != "" {
q, err := strconv.Atoi(input.Quality)
if err == nil {
config.Step = q
}
}
if config.Validate() {
return nil
}
return fmt.Errorf("unable to validate configuration after merging")
}

View File

@@ -3,55 +3,34 @@ package middleware
import (
"strings"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
)
// StripPathPrefix returns middleware that strips a path prefix from the request path.
// StripPathPrefix returns a middleware that strips a path prefix from the request path.
// The path prefix is obtained from the X-Forwarded-Prefix HTTP request header.
// This must be registered as Pre middleware (using e.Pre()) to modify the path before routing.
func StripPathPrefix() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
prefixes := c.Request().Header.Values("X-Forwarded-Prefix")
originalPath := c.Request().URL.Path
func StripPathPrefix() fiber.Handler {
return func(c *fiber.Ctx) error {
for _, prefix := range c.GetReqHeaders()["X-Forwarded-Prefix"] {
if prefix != "" {
path := c.Path()
pos := len(prefix)
for _, prefix := range prefixes {
if prefix != "" {
normalizedPrefix := prefix
if !strings.HasSuffix(prefix, "/") {
normalizedPrefix = prefix + "/"
}
if prefix[pos-1] == '/' {
pos--
} else {
prefix += "/"
}
if strings.HasPrefix(originalPath, normalizedPrefix) {
// Update the request path by stripping the normalized prefix
newPath := originalPath[len(normalizedPrefix):]
if newPath == "" {
newPath = "/"
}
// Ensure path starts with / for proper routing
if !strings.HasPrefix(newPath, "/") {
newPath = "/" + newPath
}
// Update the URL path - Echo's router uses URL.Path for routing
c.Request().URL.Path = newPath
c.Request().URL.RawPath = ""
// Update RequestURI to match the new path (needed for proper routing)
if c.Request().URL.RawQuery != "" {
c.Request().RequestURI = newPath + "?" + c.Request().URL.RawQuery
} else {
c.Request().RequestURI = newPath
}
// Store original path for BaseURL utility
c.Set("_original_path", originalPath)
break
} else if originalPath == prefix || originalPath == prefix+"/" {
// Redirect to prefix with trailing slash (use 302 to match test expectations)
return c.Redirect(302, normalizedPrefix)
}
if strings.HasPrefix(path, prefix) {
c.Path(path[pos:])
break
} else if prefix[:pos] == path {
c.Redirect(prefix)
return nil
}
}
return next(c)
}
return c.Next()
}
}

View File

@@ -2,133 +2,120 @@ package middleware
import (
"net/http/httptest"
"testing"
"github.com/labstack/echo/v4"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/gofiber/fiber/v2"
"github.com/stretchr/testify/require"
)
var _ = Describe("StripPathPrefix", func() {
var app *echo.Echo
func TestStripPathPrefix(t *testing.T) {
var actualPath string
var appInitialized bool
BeforeEach(func() {
actualPath = ""
if !appInitialized {
app = echo.New()
app.Pre(StripPathPrefix())
app := fiber.New()
app.GET("/hello/world", func(c echo.Context) error {
actualPath = c.Request().URL.Path
return nil
})
app.Use(StripPathPrefix())
app.GET("/", func(c echo.Context) error {
actualPath = c.Request().URL.Path
return nil
})
appInitialized = true
}
app.Get("/hello/world", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})
Context("without prefix", func() {
It("should not modify path when no header is present", func() {
req := httptest.NewRequest("GET", "/hello/world", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should not modify root path when no header is present", func() {
req := httptest.NewRequest("GET", "/", nil)
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/"), "rewritten path")
})
It("should not modify path when header does not match", func() {
req := httptest.NewRequest("GET", "/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
app.Get("/", func(c *fiber.Ctx) error {
actualPath = c.Path()
return nil
})
Context("with prefix", func() {
It("should return 404 when prefix does not match header", func() {
req := httptest.NewRequest("GET", "/prefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
for _, tc := range []struct {
name string
path string
prefixHeader []string
expectStatus int
expectPath string
}{
{
name: "without prefix and header",
path: "/hello/world",
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "without prefix and headers on root path",
path: "/",
expectStatus: 200,
expectPath: "/",
},
{
name: "without prefix but header",
path: "/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix but non-matching header",
path: "/prefix/hello/world",
prefixHeader: []string{"/otherprefix/"},
expectStatus: 404,
},
{
name: "with prefix and matching header",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 1st header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix/", "/otherprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and 2nd header matching",
path: "/myprefix/hello/world",
prefixHeader: []string{"/otherprefix/", "/myprefix/"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and header not ending with slash",
path: "/myprefix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 200,
expectPath: "/hello/world",
},
{
name: "with prefix and non-matching header not ending with slash",
path: "/myprefix-suffix/hello/world",
prefixHeader: []string{"/myprefix"},
expectStatus: 404,
},
{
name: "redirect when prefix does not end with a slash",
path: "/myprefix",
prefixHeader: []string{"/myprefix"},
expectStatus: 302,
expectPath: "/myprefix/",
},
} {
t.Run(tc.name, func(t *testing.T) {
actualPath = ""
req := httptest.NewRequest("GET", tc.path, nil)
if tc.prefixHeader != nil {
req.Header["X-Forwarded-Prefix"] = tc.prefixHeader
}
Expect(rec.Code).To(Equal(404), "response status code")
resp, err := app.Test(req, -1)
require.NoError(t, err)
require.Equal(t, tc.expectStatus, resp.StatusCode, "response status code")
if tc.expectStatus == 200 {
require.Equal(t, tc.expectPath, actualPath, "rewritten path")
} else if tc.expectStatus == 302 {
require.Equal(t, tc.expectPath, resp.Header.Get("Location"), "redirect location")
}
})
It("should strip matching prefix from path", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when it matches the first header value", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when it matches the second header value", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/", "/myprefix/"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should strip prefix when header does not end with slash", func() {
req := httptest.NewRequest("GET", "/myprefix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(200), "response status code")
Expect(actualPath).To(Equal("/hello/world"), "rewritten path")
})
It("should return 404 when prefix does not match header without trailing slash", func() {
req := httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(404), "response status code")
})
It("should redirect when prefix does not end with a slash", func() {
req := httptest.NewRequest("GET", "/myprefix", nil)
req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"}
rec := httptest.NewRecorder()
app.ServeHTTP(rec, req)
Expect(rec.Code).To(Equal(302), "response status code")
Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location")
})
})
})
}
}

View File

@@ -17,7 +17,7 @@ import (
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"fmt"
. "github.com/mudler/LocalAI/core/http"
"github.com/labstack/echo/v4"
"github.com/gofiber/fiber/v2"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -62,7 +62,7 @@ func (f *fakeAI) VAD(*pb.VADRequest) (pb.VADResponse, error) { return pb.VADResp
var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
var tmpdir string
var appServer *application.Application
var app *echo.Echo
var app *fiber.App
var ctx context.Context
var cancel context.CancelFunc
@@ -97,9 +97,7 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
AfterEach(func() {
cancel()
if app != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = app.Shutdown(ctx)
_ = app.Shutdown()
}
_ = os.RemoveAll(tmpdir)
})
@@ -108,11 +106,7 @@ var _ = Describe("OpenAI /v1/videos (embedded backend)", func() {
var err error
app, err = API(appServer)
Expect(err).ToNot(HaveOccurred())
go func() {
if err := app.Start("127.0.0.1:9091"); err != nil && err != http.ErrServerClosed {
// Log error if needed
}
}()
go app.Listen("127.0.0.1:9091")
// wait for server
client := &http.Client{Timeout: 5 * time.Second}

Some files were not shown because too many files have changed in this diff Show More