mirror of
https://github.com/mudler/LocalAI.git
synced 2026-05-18 05:33:09 -04:00
feat(backends): add system backend, refactor (#6059)
- Add a system backend path - Refactor and consolidate system information in system state - Use system state in all the components to figure out the system paths to used whenever needed - Refactor BackendConfig -> ModelConfig. This was otherway misleading as now we do have a backend configuration which is not the model config. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
253b7537dc
commit
089efe05fd
@@ -15,7 +15,7 @@ import (
|
||||
// @Param request body schema.ElevenLabsSoundGenerationRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/sound-generation [post]
|
||||
func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest)
|
||||
@@ -23,7 +23,7 @@ func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoad
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
// @Param request body schema.TTSRequest true "query params"
|
||||
// @Success 200 {string} binary "Response"
|
||||
// @Router /v1/text-to-speech/{voice-id} [post]
|
||||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
voiceID := c.Params("voice-id")
|
||||
@@ -27,7 +27,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
// @Param request body schema.JINARerankRequest true "query params"
|
||||
// @Success 200 {object} schema.JINARerankResponse "Response"
|
||||
// @Router /v1/rerank [post]
|
||||
func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest)
|
||||
@@ -25,7 +25,7 @@ func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -11,24 +11,27 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type BackendEndpointService struct {
|
||||
galleries []config.Gallery
|
||||
backendPath string
|
||||
backendApplier *services.GalleryService
|
||||
galleries []config.Gallery
|
||||
backendPath string
|
||||
backendSystemPath string
|
||||
backendApplier *services.GalleryService
|
||||
}
|
||||
|
||||
type GalleryBackend struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, backendPath string, backendApplier *services.GalleryService) BackendEndpointService {
|
||||
func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService {
|
||||
return BackendEndpointService{
|
||||
galleries: galleries,
|
||||
backendPath: backendPath,
|
||||
backendApplier: backendApplier,
|
||||
galleries: galleries,
|
||||
backendPath: systemState.Backend.BackendsPath,
|
||||
backendSystemPath: systemState.Backend.BackendsSystemPath,
|
||||
backendApplier: backendApplier,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,9 +114,9 @@ func (mgs *BackendEndpointService) DeleteBackendEndpoint() func(c *fiber.Ctx) er
|
||||
// @Summary List all Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends [get]
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint() func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backends, err := gallery.ListSystemBackends(mgs.backendPath)
|
||||
backends, err := gallery.ListSystemBackends(systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -141,9 +144,9 @@ func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() func(c *fiber.
|
||||
// @Summary List all available Backends
|
||||
// @Success 200 {object} []gallery.GalleryBackend "Response"
|
||||
// @Router /backends/available [get]
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint() func(c *fiber.Ctx) error {
|
||||
func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backends, err := gallery.AvailableBackends(mgs.galleries, mgs.backendPath)
|
||||
backends, err := gallery.AvailableBackends(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
// @Param request body schema.DetectionRequest true "query params"
|
||||
// @Success 200 {object} schema.DetectionResponse "Response"
|
||||
// @Router /v1/detection [post]
|
||||
func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest)
|
||||
@@ -24,7 +24,7 @@ func DetectionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, ap
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/utils"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -26,11 +27,11 @@ type GalleryModel struct {
|
||||
gallery.GalleryModel
|
||||
}
|
||||
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, modelPath string, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService) ModelGalleryEndpointService {
|
||||
return ModelGalleryEndpointService{
|
||||
galleries: galleries,
|
||||
backendGalleries: backendGalleries,
|
||||
modelPath: modelPath,
|
||||
modelPath: systemState.Model.ModelsPath,
|
||||
galleryApplier: galleryApplier,
|
||||
}
|
||||
}
|
||||
@@ -115,10 +116,10 @@ func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() func(c *fib
|
||||
// @Summary List installable models.
|
||||
// @Success 200 {object} []gallery.GalleryModel "Response"
|
||||
// @Router /models/available [get]
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint() func(c *fiber.Ctx) error {
|
||||
func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, mgs.modelPath)
|
||||
models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not list models from galleries")
|
||||
return err
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/tokenMetrics [get]
|
||||
// @Router /tokenMetrics [get]
|
||||
func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
input := new(schema.TokenMetricsRequest)
|
||||
@@ -37,7 +37,7 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader,
|
||||
log.Warn().Msgf("Model not found in context: %s", input.Model)
|
||||
}
|
||||
|
||||
cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig)
|
||||
cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(modelFile, appConfig)
|
||||
|
||||
if err != nil {
|
||||
log.Err(err)
|
||||
|
||||
@@ -14,14 +14,14 @@ import (
|
||||
// @Param request body schema.TokenizeRequest true "Request"
|
||||
// @Success 200 {object} schema.TokenizeResponse "Response"
|
||||
// @Router /v1/tokenize [post]
|
||||
func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -22,14 +22,14 @@ import (
|
||||
// @Success 200 {string} binary "generated audio/wav file"
|
||||
// @Router /v1/audio/speech [post]
|
||||
// @Router /tts [post]
|
||||
func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -16,14 +16,14 @@ import (
|
||||
// @Param request body schema.VADRequest true "query params"
|
||||
// @Success 200 {object} proto.VADResponse "Response"
|
||||
// @Router /vad [post]
|
||||
func VADEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || cfg == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /video [post]
|
||||
func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest)
|
||||
if !ok || input.Model == "" {
|
||||
@@ -72,7 +72,7 @@ func VideoEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Video Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
|
||||
@@ -11,12 +11,12 @@ import (
|
||||
)
|
||||
|
||||
func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
cl *config.BackendConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
|
||||
cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) func(*fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
backendConfigs := cl.GetAllBackendConfigs()
|
||||
modelConfigs := cl.GetAllModelsConfigs()
|
||||
galleryConfigs := map[string]*gallery.ModelConfig{}
|
||||
|
||||
for _, m := range backendConfigs {
|
||||
for _, m := range modelConfigs {
|
||||
cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name)
|
||||
if err != nil {
|
||||
continue
|
||||
@@ -34,7 +34,7 @@ func WelcomeEndpoint(appConfig *config.ApplicationConfig,
|
||||
"Version": internal.PrintableVersion(),
|
||||
"BaseURL": utils.BaseURL(c),
|
||||
"Models": modelsWithoutConfig,
|
||||
"ModelsConfig": backendConfigs,
|
||||
"ModelsConfig": modelConfigs,
|
||||
"GalleryConfig": galleryConfigs,
|
||||
"ApplicationConfig": appConfig,
|
||||
"ProcessingModels": processingModels,
|
||||
|
||||
@@ -27,11 +27,11 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
var id, textContentToReturn string
|
||||
var created int
|
||||
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
initialMessage := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
@@ -66,7 +66,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
})
|
||||
close(responses)
|
||||
}
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
result := ""
|
||||
_, tokenUsage, _ := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool {
|
||||
result += s
|
||||
@@ -183,7 +183,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
@@ -501,7 +501,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
}
|
||||
}
|
||||
|
||||
func handleQuestion(config *config.BackendConfig, cl *config.BackendConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
|
||||
func handleQuestion(config *config.ModelConfig, cl *config.ModelConfigLoader, input *schema.OpenAIRequest, ml *model.ModelLoader, o *config.ApplicationConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
|
||||
|
||||
if len(funcResults) == 0 && result != "" {
|
||||
log.Debug().Msgf("nothing function results but we had a message from the LLM")
|
||||
|
||||
@@ -27,10 +27,10 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
usage := schema.OpenAIUsage{
|
||||
PromptTokens: tokenUsage.Prompt,
|
||||
@@ -73,7 +73,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, e
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/edits [post]
|
||||
func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
|
||||
return func(c *fiber.Ctx) error {
|
||||
|
||||
@@ -34,7 +34,7 @@ func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, evaluat
|
||||
// Opt-in extra usage flag
|
||||
extraUsage := c.Get("Extra-Usage", "") != ""
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -21,14 +21,14 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/embeddings [post]
|
||||
func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func downloadFile(url string) (string, error) {
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/images/generations [post]
|
||||
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
@@ -73,7 +73,7 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
log.Error().Msg("Image Endpoint - Invalid Config")
|
||||
return fiber.ErrBadRequest
|
||||
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
func ComputeChoices(
|
||||
req *schema.OpenAIRequest,
|
||||
predInput string,
|
||||
config *config.BackendConfig,
|
||||
bcl *config.BackendConfigLoader,
|
||||
config *config.ModelConfig,
|
||||
bcl *config.ModelConfigLoader,
|
||||
o *config.ApplicationConfig,
|
||||
loader *model.ModelLoader,
|
||||
cb func(string, *[]schema.Choice),
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// @Summary List and describe the various models available in the API.
|
||||
// @Success 200 {object} schema.ModelsDataResponse "Response"
|
||||
// @Router /v1/models [get]
|
||||
func ListModelsEndpoint(bcl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
|
||||
func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(ctx *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
// If blank, no filter is applied.
|
||||
filter := c.Query("filter")
|
||||
|
||||
@@ -559,7 +559,7 @@ func sendNotImplemented(c *websocket.Conn, message string) {
|
||||
sendError(c, "not_implemented", message, "", "event_TODO")
|
||||
}
|
||||
|
||||
func updateTransSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
func updateTransSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
sessionLock.Lock()
|
||||
defer sessionLock.Unlock()
|
||||
|
||||
@@ -589,7 +589,7 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi
|
||||
}
|
||||
|
||||
// Function to update session configurations
|
||||
func updateSession(session *Session, update *types.ClientSession, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
func updateSession(session *Session, update *types.ClientSession, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
sessionLock.Lock()
|
||||
defer sessionLock.Unlock()
|
||||
|
||||
@@ -628,7 +628,7 @@ func updateSession(session *Session, update *types.ClientSession, cl *config.Bac
|
||||
|
||||
// handleVAD is a goroutine that listens for audio data from the client,
|
||||
// runs VAD on the audio data, and commits utterances to the conversation
|
||||
func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||
func handleVAD(cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn, done chan struct{}) {
|
||||
vadContext, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
<-done
|
||||
@@ -742,7 +742,7 @@ func handleVAD(cfg *config.BackendConfig, evaluator *templates.Evaluator, sessio
|
||||
}
|
||||
}
|
||||
|
||||
func commitUtterance(ctx context.Context, utt []byte, cfg *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
|
||||
func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conv *Conversation, c *websocket.Conn) {
|
||||
if len(utt) == 0 {
|
||||
return
|
||||
}
|
||||
@@ -853,7 +853,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS
|
||||
|
||||
// TODO: Below needed for normal mode instead of transcription only
|
||||
// Function to generate a response based on the conversation
|
||||
// func generateResponse(config *config.BackendConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
|
||||
// func generateResponse(config *config.ModelConfig, evaluator *templates.Evaluator, session *Session, conversation *Conversation, responseCreate ResponseCreate, c *websocket.Conn, mt int) {
|
||||
//
|
||||
// log.Debug().Msg("Generating realtime response...")
|
||||
//
|
||||
@@ -1067,7 +1067,7 @@ func runVAD(ctx context.Context, session *Session, adata []int16) ([]*proto.VADS
|
||||
// }
|
||||
|
||||
// Function to process text response and detect function calls
|
||||
func processTextResponse(config *config.BackendConfig, session *Session, prompt string) (string, *FunctionCall, error) {
|
||||
func processTextResponse(config *config.ModelConfig, session *Session, prompt string) (string, *FunctionCall, error) {
|
||||
|
||||
// Placeholder implementation
|
||||
// Replace this with actual model inference logic using session.Model and prompt
|
||||
|
||||
@@ -22,14 +22,14 @@ var (
|
||||
// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
|
||||
// which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS)
|
||||
type wrappedModel struct {
|
||||
TTSConfig *config.BackendConfig
|
||||
TranscriptionConfig *config.BackendConfig
|
||||
LLMConfig *config.BackendConfig
|
||||
TTSConfig *config.ModelConfig
|
||||
TranscriptionConfig *config.ModelConfig
|
||||
LLMConfig *config.ModelConfig
|
||||
TTSClient grpcClient.Backend
|
||||
TranscriptionClient grpcClient.Backend
|
||||
LLMClient grpcClient.Backend
|
||||
|
||||
VADConfig *config.BackendConfig
|
||||
VADConfig *config.ModelConfig
|
||||
VADClient grpcClient.Backend
|
||||
}
|
||||
|
||||
@@ -37,17 +37,17 @@ type wrappedModel struct {
|
||||
// We have to wrap this out as well because we want to load two models one for VAD and one for the actual model.
|
||||
// In the future there could be models that accept continous audio input only so this design will be useful for that
|
||||
type anyToAnyModel struct {
|
||||
LLMConfig *config.BackendConfig
|
||||
LLMConfig *config.ModelConfig
|
||||
LLMClient grpcClient.Backend
|
||||
|
||||
VADConfig *config.BackendConfig
|
||||
VADConfig *config.ModelConfig
|
||||
VADClient grpcClient.Backend
|
||||
}
|
||||
|
||||
type transcriptOnlyModel struct {
|
||||
TranscriptionConfig *config.BackendConfig
|
||||
TranscriptionConfig *config.ModelConfig
|
||||
TranscriptionClient grpcClient.Backend
|
||||
VADConfig *config.BackendConfig
|
||||
VADConfig *config.ModelConfig
|
||||
VADClient grpcClient.Backend
|
||||
}
|
||||
|
||||
@@ -105,8 +105,8 @@ func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOpti
|
||||
return m.LLMClient.PredictStream(ctx, in, f)
|
||||
}
|
||||
|
||||
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.BackendConfig, error) {
|
||||
cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) {
|
||||
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -122,7 +122,7 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf
|
||||
return nil, nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||
}
|
||||
|
||||
cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -139,17 +139,17 @@ func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.BackendConf
|
||||
}
|
||||
|
||||
return &transcriptOnlyModel{
|
||||
VADConfig: cfgVAD,
|
||||
VADClient: VADClient,
|
||||
VADConfig: cfgVAD,
|
||||
VADClient: VADClient,
|
||||
TranscriptionConfig: cfgSST,
|
||||
TranscriptionClient: transcriptionClient,
|
||||
}, cfgSST, nil
|
||||
}
|
||||
|
||||
// returns and loads either a wrapped model or a model that support audio-to-audio
|
||||
func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) {
|
||||
func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, error) {
|
||||
|
||||
cfgVAD, err := cl.LoadBackendConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -166,7 +166,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
}
|
||||
|
||||
// TODO: Do we always need a transcription model? It can be disabled. Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process
|
||||
cfgSST, err := cl.LoadBackendConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -185,7 +185,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
// TODO: Decide when we have a real any-to-any model
|
||||
if false {
|
||||
|
||||
cfgAnyToAny, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -212,7 +212,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
log.Debug().Msg("Loading a wrapped model")
|
||||
|
||||
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
||||
cfgLLM, err := cl.LoadBackendConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -222,7 +222,7 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
cfgTTS, err := cl.LoadBackendConfigFileByName(pipeline.TTS, ml.ModelPath)
|
||||
cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath)
|
||||
if err != nil {
|
||||
|
||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||
@@ -232,7 +232,6 @@ func newModel(pipeline *config.Pipeline, cl *config.BackendConfigLoader, ml *mod
|
||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||
}
|
||||
|
||||
|
||||
opts = backend.ModelOptions(*cfgTTS, appConfig)
|
||||
ttsClient, err := ml.Load(opts...)
|
||||
if err != nil {
|
||||
|
||||
@@ -24,14 +24,14 @@ import (
|
||||
// @Param file formData file true "file"
|
||||
// @Success 200 {object} map[string]string "Response"
|
||||
// @Router /v1/audio/transcriptions [post]
|
||||
func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
||||
return func(c *fiber.Ctx) error {
|
||||
input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
if !ok || input.Model == "" {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig)
|
||||
config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
|
||||
if !ok || config == nil {
|
||||
return fiber.ErrBadRequest
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user