mirror of
https://github.com/mudler/LocalAI.git
synced 2026-04-04 23:14:41 -04:00
feat(backends): add system backend, refactor (#6059)
- Add a system backend path - Refactor and consolidate system information in system state - Use system state in all the components to figure out the system paths to used whenever needed - Refactor BackendConfig -> ModelConfig. This was otherway misleading as now we do have a backend configuration which is not the model config. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
253b7537dc
commit
089efe05fd
@@ -198,7 +198,7 @@ func API(application *application.Application) (*fiber.App, error) {
|
||||
}
|
||||
|
||||
galleryService := services.NewGalleryService(application.ApplicationConfig(), application.ModelLoader())
|
||||
err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader())
|
||||
err = galleryService.Start(application.ApplicationConfig().Context, application.BackendLoader(), application.ApplicationConfig().SystemState)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -320,12 +321,17 @@ var _ = Describe("API test", func() {
|
||||
},
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendPath),
|
||||
system.WithModelPath(modelDir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithGalleries(galleries),
|
||||
config.WithModelPath(modelDir),
|
||||
config.WithBackendsPath(backendPath),
|
||||
config.WithApiKeys([]string{apiKey}),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -523,13 +529,18 @@ var _ = Describe("API test", func() {
|
||||
},
|
||||
}
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendPath),
|
||||
system.WithModelPath(modelDir),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithGeneratedContentDir(tmpdir),
|
||||
config.WithBackendsPath(backendPath),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithGalleries(galleries),
|
||||
config.WithModelPath(modelDir),
|
||||
)...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -729,12 +740,17 @@ var _ = Describe("API test", func() {
|
||||
|
||||
var err error
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendPath),
|
||||
system.WithModelPath(modelPath),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithExternalBackend("transformers", os.Getenv("HUGGINGFACE_GRPC")),
|
||||
config.WithContext(c),
|
||||
config.WithBackendsPath(backendPath),
|
||||
config.WithModelPath(modelPath),
|
||||
config.WithSystemState(systemState),
|
||||
)...)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
app, err = API(application)
|
||||
@@ -960,11 +976,17 @@ var _ = Describe("API test", func() {
|
||||
c, cancel = context.WithCancel(context.Background())
|
||||
|
||||
var err error
|
||||
|
||||
systemState, err := system.GetSystemState(
|
||||
system.WithBackendPath(backendPath),
|
||||
system.WithModelPath(modelPath),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
application, err := application.New(
|
||||
append(commonOpts,
|
||||
config.WithContext(c),
|
||||
config.WithModelPath(modelPath),
|
||||
config.WithBackendsPath(backendPath),
|
||||
config.WithSystemState(systemState),
|
||||
config.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/sound-generation [post]
|
||||
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
|
||||
@@ -23,7 +23,7 @@ func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
// @Param request body schema.TTSRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/text-to-speech/{voice-id} [post]
|
||||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
voiceID := c.Params("voice-id")
|
||||
@@ -27,7 +27,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
// @Param request body schema.JINARerankRequest true "query params"
|
||||
// @Success 200 {object} schema.JINARerankResponse "Response"
|
||||
// @Router /v1/rerank [post]
|
||||
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
|
||||
@@ -25,7 +25,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -11,24 +11,27 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type BackendEndpointService struct {
|
||||
galleries []config.Gallery
|
||||
backendPath string
|
||||
backendApplier *services.GalleryService
|
||||
galleries []config.Gallery
|
||||
backendPath string
|
||||
backendSystemPath string
|
||||
backendApplier *services.GalleryService
|
||||
}
|
||||
|
||||
type GalleryBackend struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, backendPath string, backendApplier *services.GalleryService) BackendEndpointService {
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService {
|
||||
return BackendEndpointService{
|
||||
galleries: galleries,
|
||||
backendPath: backendPath,
|
||||
backendApplier: backendApplier,
|
||||
galleries: galleries,
|
||||
backendPath: systemState.Backend.BackendsPath,
|
||||
backendSystemPath: systemState.Backend.BackendsSystemPath,
|
||||
backendApplier: backendApplier,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,9 +114,9 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
|
||||
// @Summary List all Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends [get]
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint() func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backends, err := gallery.ListSystemBackends(mgs.backendPath)
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -141,9 +144,9 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
|
||||
// @Summary List all available Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends/available [get]
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint() func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backends, err := gallery.AvailableBackends(mgs.galleries, mgs.backendPath)
|
||||
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
// @Param request body schema.DetectionRequest true "query params"
|
||||
// @Success 200 {object} schema.DetectionResponse "Response"
|
||||
// @Router /v1/detection [post]
|
||||
func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
|
||||
@@ -24,7 +24,7 @@ func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, ap
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -26,11 +27,11 @@ type GalleryModel struct {
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
backendGalleries: backendGalleries,
|
||||
modelPath: modelPath,
|
||||
modelPath: systemState.Model.ModelsPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
@@ -115,10 +116,10 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fib
|
||||
// @Summary List installable models.
|
||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
||||
// @Router /models/available [get]
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return err
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/tokenMetrics [get]
|
||||
// @Router /tokenMetrics [get]
|
||||
func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(schema.TokenMetricsRequest)
|
||||
@@ -37,7 +37,7 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)
|
||||
cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(modelFile, appConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
|
||||
@@ -14,14 +14,14 @@ import (
|
||||
// @Param request body schema.TokenizeRequest true "Request"
|
||||
// @Success 200 {object} schema.TokenizeResponse "Response"
|
||||
// @Router /v1/tokenize [post]
|
||||
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -22,14 +22,14 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/audio/speech [post]
|
||||
// @Router /tts [post]
|
||||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -16,14 +16,14 @@ import (
|
||||
// @Param request body schema.VADRequest true "query params"
|
||||
// @Success 200 {object} proto.VADResponse "Response"
|
||||
// @Router /vad [post]
|
||||
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /video [post]
|
||||
func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
|
||||
if !ok || input.Model == "" {
|
||||
@@ -72,7 +72,7 @@ func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Video Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
)
|
||||
|
||||
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
cl *config.BackendConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
|
||||
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
|
||||
for _, m := range backendConfigs {
|
||||
for _, m := range modelConfigs {
|
||||
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -34,7 +34,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
"Version": internal.PrintableVersion(),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"Models": modelsWithoutConfig,
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ApplicationConfig": appConfig,
|
||||
"ProcessingModels": processingModels,
|
||||
|
||||
@@ -27,11 +27,11 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
var id, textContentToReturn string
|
||||
var created int
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
@@ -66,7 +66,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
result := ""
|
||||
_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
@@ -183,7 +183,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
@@ -501,7 +501,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
}
|
||||
}
|
||||
|
||||
func handleQuestion(config *config.BackendConfig, cl *config.BackendConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
|
||||
func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
|
||||
|
||||
if len(funcResults) == 0 && result != "" {
|
||||
log.Debug().Msgf("nothing function results but we had a message from the LLM")
|
||||
|
||||
@@ -27,10 +27,10 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
@@ -73,7 +73,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/edits [post]
|
||||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
@@ -34,7 +34,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -21,14 +21,14 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/embeddings [post]
|
||||
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/images/generations [post]
|
||||
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
@@ -73,7 +73,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Image Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
func ComputeChoices(
|
||||
req *schema.OpenAIRequest,
|
||||
predInput string,
|
||||
config *config.BackendConfig,
|
||||
bcl *config.BackendConfigLoader,
|
||||
config *config.ModelConfig,
|
||||
bcl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
cb func(string, *[]schema.Choice),
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// @Summary List and describe the various models available in the API.
|
||||
// @Success 200 {object} schema.ModelsDataResponse "Response"
|
||||
// @Router /v1/models [get]
|
||||
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// If blank, no filter is applied.
|
||||
filter := c.Query("filter")
|
||||
|
||||
@@ -559,7 +559,7 @@ func sendNotImplemented(c *websocket.Conn, message string) {
|
||||
sendError(c, "not_implemented", message, "", "event_TODO")
|
||||
}
|
||||
|
||||
func updateTransSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
func updateTransSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
sessionLock.Lock()
|
||||
defer sessionLock.Unlock()
|
||||
|
||||
@@ -589,7 +589,7 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi
|
||||
}
|
||||
|
||||
// Function to update session configurations
|
||||
func updateSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
func updateSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
sessionLock.Lock()
|
||||
defer sessionLock.Unlock()
|
||||
|
||||
@@ -628,7 +628,7 @@ func updateSession(session *Session, update *types.ClientSession, cl *config.Bac
|
||||
|
||||
// handleVAD is a goroutine that listens for audio data from the client,
|
||||
// runs VAD on the audio data, and commits utterances to the conversation
|
||||
func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||
func handleVAD(cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||
vadContext, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-done
|
||||
@@ -742,7 +742,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
|
||||
}
|
||||
}
|
||||
|
||||
func commitUtterance(ctx context.Context, utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
|
||||
func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
|
||||
if len(utt) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -853,7 +853,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS
|
||||
|
||||
// TODO: Below needed for normal mode instead of transcription only
|
||||
// Function to generate a response based on the conversation
|
||||
// func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
|
||||
// func generateResponse(config *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
|
||||
//
|
||||
// log.Debug().Msg("Generating realtime response...")
|
||||
//
|
||||
@@ -1067,7 +1067,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS
|
||||
// }
|
||||
|
||||
// Function to process text response and detect function calls
|
||||
func processTextResponse(config *config.BackendConfig, session *Session, prompt string) (string, *FunctionCall, error) {
|
||||
func processTextResponse(config *config.ModelConfig, session *Session, prompt string) (string, *FunctionCall, error) {
|
||||
|
||||
// Placeholder implementation
|
||||
// Replace this with actual model inference logic using session.Model and prompt
|
||||
|
||||
@@ -22,14 +22,14 @@ var (
|
||||
// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
|
||||
// which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS)
|
||||
type wrappedModel struct {
|
||||
TTSConfig *config.BackendConfig
|
||||
TranscriptionConfig *config.BackendConfig
|
||||
LLMConfig *config.BackendConfig
|
||||
TTSConfig *config.ModelConfig
|
||||
TranscriptionConfig *config.ModelConfig
|
||||
LLMConfig *config.ModelConfig
|
||||
TTSClient grpcClient.Backend
|
||||
TranscriptionClient grpcClient.Backend
|
||||
LLMClient grpcClient.Backend
|
||||
|
||||
VADConfig *config.BackendConfig
|
||||
VADConfig *config.ModelConfig
|
||||
VADClient grpcClient.Backend
|
||||
}
|
||||
|
||||
@@ -37,17 +37,17 @@ type wrappedModel struct {
|
||||
// We have to wrap this out as well because we want to load two models one for VAD and one for the actual model.
|
||||
// In the future there could be models that accept continous audio input only so this design will be useful for that
|
||||
type anyToAnyModel struct {
|
||||
LLMConfig *config.BackendConfig
|
||||
LLMConfig *config.ModelConfig
|
||||
LLMClient grpcClient.Backend
|
||||
|
||||
VADConfig *config.BackendConfig
|
||||
VADConfig *config.ModelConfig
|
||||
VADClient grpcClient.Backend
|
||||
}
|
||||
|
||||
type transcriptOnlyModel struct {
|
||||
TranscriptionConfig *config.BackendConfig
|
||||
TranscriptionConfig *config.ModelConfig
|
||||
TranscriptionClient grpcClient.Backend
|
||||
VADConfig *config.BackendConfig
|
||||
VADConfig *config.ModelConfig
|
||||
VADClient grpcClient.Backend
|
||||
}
|
||||
|
||||
@@ -105,8 +105,8 @@ func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOpti
|
||||
return m.LLMClient.PredictStream(ctx, in, f)
|
||||
}
|
||||
|
||||
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.BackendConfig, error) {
|
||||
cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
|
||||
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -122,7 +122,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf
|
||||
return nil, nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||
}
|
||||
|
||||
cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -139,17 +139,17 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf
|
||||
}
|
||||
|
||||
return &transcriptOnlyModel{
|
||||
VADConfig: cfgVAD,
|
||||
VADClient: VADClient,
|
||||
VADConfig: cfgVAD,
|
||||
VADClient: VADClient,
|
||||
TranscriptionConfig: cfgSST,
|
||||
TranscriptionClient: transcriptionClient,
|
||||
}, cfgSST, nil
|
||||
}
|
||||
|
||||
// returns and loads either a wrapped model or a model that support audio-to-audio
|
||||
func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) {
|
||||
func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) {
|
||||
|
||||
cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -166,7 +166,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
}
|
||||
|
||||
// TODO: Do we always need a transcription model? It can be disabled. Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process
|
||||
cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -185,7 +185,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
// TODO: Decide when we have a real any-to-any model
|
||||
if false {
|
||||
|
||||
cfgAnyToAny, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -212,7 +212,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
log.Debug().Msg("Loading a wrapped model")
|
||||
|
||||
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
||||
cfgLLM, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -222,7 +222,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
cfgTTS, err := cl.LoadBackendConfigFileByName(pipeline.TTS, ml.ModelPath)
|
||||
cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -232,7 +232,6 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
|
||||
opts = backend.ModelOptions(*cfgTTS, appConfig)
|
||||
ttsClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -24,14 +24,14 @@ import (
|
||||
// @Param file formData file true "file"
|
||||
// @Success 200 {object} map[string]string "Response"
|
||||
// @Router /v1/audio/transcriptions [post]
|
||||
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -26,16 +26,16 @@ type correlationIDKeyType string
|
||||
const CorrelationIDKey correlationIDKeyType = "correlationID"
|
||||
|
||||
type RequestExtractor struct {
|
||||
backendConfigLoader *config.BackendConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
modelConfigLoader *config.ModelConfigLoader
|
||||
modelLoader *model.ModelLoader
|
||||
applicationConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
||||
func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor {
|
||||
return &RequestExtractor{
|
||||
backendConfigLoader: backendConfigLoader,
|
||||
modelLoader: modelLoader,
|
||||
applicationConfig: applicationConfig,
|
||||
modelConfigLoader: modelConfigLoader,
|
||||
modelLoader: modelLoader,
|
||||
applicationConfig: applicationConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) {
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request.
|
||||
if bearer != "" {
|
||||
exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
|
||||
exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE)
|
||||
if err == nil && exists {
|
||||
model = bearer
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModel
|
||||
}
|
||||
}
|
||||
|
||||
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler {
|
||||
func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
re.setModelNameFromRequest(ctx)
|
||||
localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string)
|
||||
@@ -89,7 +89,7 @@ func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn con
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
|
||||
modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()")
|
||||
return ctx.Next()
|
||||
@@ -129,7 +129,7 @@ func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIR
|
||||
}
|
||||
}
|
||||
|
||||
cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
|
||||
cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
@@ -152,7 +152,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
@@ -168,7 +168,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
|
||||
input.Context = ctxWithCorrelationID
|
||||
input.Cancel = cancel
|
||||
|
||||
err := mergeOpenAIRequestAndBackendConfig(cfg, input)
|
||||
err := mergeOpenAIRequestAndModelConfig(cfg, input)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -184,7 +184,7 @@ func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error {
|
||||
return ctx.Next()
|
||||
}
|
||||
|
||||
func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error {
|
||||
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
|
||||
if input.Echo {
|
||||
config.Echo = input.Echo
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
func RegisterElevenLabsRoutes(app *fiber.App,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.BackendConfigLoader,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
func RegisterJINARoutes(app *fiber.App,
|
||||
re *middleware.RequestExtractor,
|
||||
cl *config.BackendConfigLoader,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig) {
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
|
||||
func RegisterLocalAIRoutes(router *fiber.App,
|
||||
requestExtractor *middleware.RequestExtractor,
|
||||
cl *config.BackendConfigLoader,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
galleryService *services.GalleryService) {
|
||||
@@ -23,20 +23,23 @@ func RegisterLocalAIRoutes(router *fiber.App,
|
||||
|
||||
// LocalAI API endpoints
|
||||
if !appConfig.DisableGalleryEndpoint {
|
||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.ModelPath, galleryService)
|
||||
modelGalleryEndpointService := localai.CreateModelGalleryEndpointService(appConfig.Galleries, appConfig.BackendGalleries, appConfig.SystemState, galleryService)
|
||||
router.Post("/models/apply", modelGalleryEndpointService.ApplyModelGalleryEndpoint())
|
||||
router.Post("/models/delete/:name", modelGalleryEndpointService.DeleteModelGalleryEndpoint())
|
||||
|
||||
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint())
|
||||
router.Get("/models/available", modelGalleryEndpointService.ListModelFromGalleryEndpoint(appConfig.SystemState))
|
||||
router.Get("/models/galleries", modelGalleryEndpointService.ListModelGalleriesEndpoint())
|
||||
router.Get("/models/jobs/:uuid", modelGalleryEndpointService.GetOpStatusEndpoint())
|
||||
router.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint())
|
||||
|
||||
backendGalleryEndpointService := localai.CreateBackendEndpointService(appConfig.BackendGalleries, appConfig.BackendsPath, galleryService)
|
||||
backendGalleryEndpointService := localai.CreateBackendEndpointService(
|
||||
appConfig.BackendGalleries,
|
||||
appConfig.SystemState,
|
||||
galleryService)
|
||||
router.Post("/backends/apply", backendGalleryEndpointService.ApplyBackendEndpoint())
|
||||
router.Post("/backends/delete/:name", backendGalleryEndpointService.DeleteBackendEndpoint())
|
||||
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint())
|
||||
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint())
|
||||
router.Get("/backends", backendGalleryEndpointService.ListBackendsEndpoint(appConfig.SystemState))
|
||||
router.Get("/backends/available", backendGalleryEndpointService.ListAvailableBackendsEndpoint(appConfig.SystemState))
|
||||
router.Get("/backends/galleries", backendGalleryEndpointService.ListBackendGalleriesEndpoint())
|
||||
router.Get("/backends/jobs/:uuid", backendGalleryEndpointService.GetOpStatusEndpoint())
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func RegisterUIRoutes(app *fiber.App,
|
||||
cl *config.BackendConfigLoader,
|
||||
cl *config.ModelConfigLoader,
|
||||
ml *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
galleryService *services.GalleryService) {
|
||||
@@ -65,9 +65,9 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
}
|
||||
|
||||
app.Get("/talk/", func(c *fiber.Ctx) error {
|
||||
backendConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
|
||||
modelConfigs, _ := services.ListModels(cl, ml, config.NoFilterFn, services.SKIP_IF_CONFIGURED)
|
||||
|
||||
if len(backendConfigs) == 0 {
|
||||
if len(modelConfigs) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
}
|
||||
@@ -75,8 +75,8 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Talk",
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"Model": backendConfigs[0],
|
||||
"ModelsConfig": modelConfigs,
|
||||
"Model": modelConfigs[0],
|
||||
|
||||
"Version": internal.PrintableVersion(),
|
||||
}
|
||||
@@ -86,17 +86,17 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/chat/", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(backendConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
}
|
||||
modelThatCanBeUsed := ""
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
|
||||
for _, m := range backendConfigs {
|
||||
for _, m := range modelConfigs {
|
||||
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -106,7 +106,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
|
||||
title := "LocalAI - Chat"
|
||||
|
||||
for _, b := range backendConfigs {
|
||||
for _, b := range modelConfigs {
|
||||
if b.HasUsecases(config.FLAG_CHAT) {
|
||||
modelThatCanBeUsed = b.Name
|
||||
title = "LocalAI - Chat with " + modelThatCanBeUsed
|
||||
@@ -119,7 +119,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"Model": modelThatCanBeUsed,
|
||||
"Version": internal.PrintableVersion(),
|
||||
}
|
||||
@@ -130,12 +130,12 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
|
||||
// Show the Chat page
|
||||
app.Get("/chat/:model", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
|
||||
for _, m := range backendConfigs {
|
||||
for _, m := range modelConfigs {
|
||||
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -146,7 +146,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Chat with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
@@ -158,13 +158,13 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/text2image/:model", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Generate images with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
"Version": internal.PrintableVersion(),
|
||||
@@ -175,10 +175,10 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/text2image/", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(backendConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
}
|
||||
@@ -186,7 +186,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
modelThatCanBeUsed := ""
|
||||
title := "LocalAI - Generate images"
|
||||
|
||||
for _, b := range backendConfigs {
|
||||
for _, b := range modelConfigs {
|
||||
if b.HasUsecases(config.FLAG_IMAGE) {
|
||||
modelThatCanBeUsed = b.Name
|
||||
title = "LocalAI - Generate images with " + modelThatCanBeUsed
|
||||
@@ -197,7 +197,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
summary := fiber.Map{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": modelThatCanBeUsed,
|
||||
"Version": internal.PrintableVersion(),
|
||||
@@ -208,13 +208,13 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/tts/:model", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
summary := fiber.Map{
|
||||
"Title": "LocalAI - Generate images with " + c.Params("model"),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": c.Params("model"),
|
||||
"Version": internal.PrintableVersion(),
|
||||
@@ -225,10 +225,10 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
})
|
||||
|
||||
app.Get("/tts/", func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY)
|
||||
|
||||
if len(backendConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
if len(modelConfigs)+len(modelsWithoutConfig) == 0 {
|
||||
// If no model is available redirect to the index which suggests how to install models
|
||||
return c.Redirect(utils.BaseURL(c))
|
||||
}
|
||||
@@ -236,7 +236,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
modelThatCanBeUsed := ""
|
||||
title := "LocalAI - Generate audio"
|
||||
|
||||
for _, b := range backendConfigs {
|
||||
for _, b := range modelConfigs {
|
||||
if b.HasUsecases(config.FLAG_TTS) {
|
||||
modelThatCanBeUsed = b.Name
|
||||
title = "LocalAI - Generate audio with " + modelThatCanBeUsed
|
||||
@@ -246,7 +246,7 @@ func RegisterUIRoutes(app *fiber.App,
|
||||
summary := fiber.Map{
|
||||
"Title": title,
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"ModelsWithoutConfig": modelsWithoutConfig,
|
||||
"Model": modelThatCanBeUsed,
|
||||
"Version": internal.PrintableVersion(),
|
||||
|
||||
@@ -28,7 +28,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC
|
||||
page := c.Query("page")
|
||||
items := c.Query("items")
|
||||
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.BackendsPath)
|
||||
backends, err := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list backends from galleries")
|
||||
return c.Status(fiber.StatusInternalServerError).Render("views/error", fiber.Map{
|
||||
@@ -129,7 +129,7 @@ func registerBackendGalleryRoutes(app *fiber.App, appConfig *config.ApplicationC
|
||||
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
}
|
||||
|
||||
backends, _ := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.BackendsPath)
|
||||
backends, _ := gallery.AvailableBackends(appConfig.BackendGalleries, appConfig.SystemState)
|
||||
|
||||
if page != "" {
|
||||
// return a subset of the backends
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
func registerGalleryRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) {
|
||||
|
||||
// Show the Models page (all models)
|
||||
app.Get("/browse", func(c *fiber.Ctx) error {
|
||||
@@ -28,7 +28,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo
|
||||
page := c.Query("page")
|
||||
items := c.Query("items")
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
|
||||
models, err := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return c.Status(fiber.StatusInternalServerError).Render("views/error", fiber.Map{
|
||||
@@ -131,7 +131,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo
|
||||
return c.Status(fiber.StatusBadRequest).SendString(bluemonday.StrictPolicy().Sanitize(err.Error()))
|
||||
}
|
||||
|
||||
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.ModelPath)
|
||||
models, _ := gallery.AvailableGalleryModels(appConfig.Galleries, appConfig.SystemState)
|
||||
|
||||
if page != "" {
|
||||
// return a subset of the models
|
||||
@@ -224,7 +224,7 @@ func registerGalleryRoutes(app *fiber.App, cl *config.BackendConfigLoader, appCo
|
||||
}
|
||||
go func() {
|
||||
galleryService.ModelGalleryChannel <- op
|
||||
cl.RemoveBackendConfig(galleryName)
|
||||
cl.RemoveModelConfig(galleryName)
|
||||
}()
|
||||
|
||||
return c.SendString(elements.StartModelProgressBar(uid, "0", "Deletion"))
|
||||
|
||||
Reference in New Issue
Block a user