mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-14 11:49:33 -04:00
Add a routing middleware stack and a cloud-proxy backend. * cloud-proxy: a Go gRPC backend that forwards OpenAI- and Anthropic-shaped chat requests to upstream providers, with an optional translate mode (OpenAI request -> Anthropic /v1/messages -> OpenAI response) and full tool-calling support. * routing: admission control, content-aware model routing (embedding cache + classifier + rerank + Arch-Router score), PII detection/redaction (regex + NER) with streaming filter and OpenAI/Anthropic adapters, and a per-user/per-key billing recorder backed by GORM or in-memory storage. * middleware: UsageMiddleware records usage via the billing recorder, plus admission, route-model, usage-stamp and trace middlewares. * observability: BackendTrace ring buffer stores full request bodies (capped), MITM proxy emits structured trace events, and router classifier decisions surface at /api/router/decide. * gallery: Arch-Router-1.5B (Q4_K_M and Q8_0). * UI: cloud-proxy model-editor fields, classifier system-prompt and score-normalization config, and a Traces page rendering request bodies. Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash] Signed-off-by: Richard Palethorpe <io@richiejp.com>
316 lines
9.2 KiB
Go
316 lines
9.2 KiB
Go
package localaitools
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
|
|
"github.com/mudler/LocalAI/core/config"
|
|
"github.com/mudler/LocalAI/core/gallery"
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/core/services/modeladmin"
|
|
"github.com/mudler/LocalAI/pkg/vram"
|
|
)
|
|
|
|
// fakeClient is a recording, configurable LocalAIClient for unit tests.
|
|
// Each method records the args it was called with and returns whatever the
|
|
// matching field on the struct is configured to return. Methods are guarded
|
|
// by a mutex so tests can run with -race.
|
|
type fakeClient struct {
|
|
mu sync.Mutex
|
|
|
|
// Recorded calls (in order).
|
|
calls []fakeCall
|
|
|
|
// Per-method overrides. Tests set these.
|
|
gallerySearch func(GallerySearchQuery) ([]gallery.Metadata, error)
|
|
listInstalledModels func(Capability) ([]InstalledModel, error)
|
|
listGalleries func() ([]config.Gallery, error)
|
|
getJobStatus func(string) (*JobStatus, error)
|
|
getModelConfig func(string) (*ModelConfigView, error)
|
|
installModel func(InstallModelRequest) (string, error)
|
|
importModelURI func(ImportModelURIRequest) (*ImportModelURIResponse, error)
|
|
deleteModel func(string) error
|
|
editModelConfig func(string, map[string]any) error
|
|
reloadModels func() error
|
|
listBackends func() ([]Backend, error)
|
|
listKnownBackends func() ([]schema.KnownBackend, error)
|
|
installBackend func(InstallBackendRequest) (string, error)
|
|
upgradeBackend func(string) (string, error)
|
|
systemInfo func() (*SystemInfo, error)
|
|
listNodes func() ([]Node, error)
|
|
vramEstimate func(VRAMEstimateRequest) (*vram.EstimateResult, error)
|
|
toggleModelState func(string, modeladmin.Action) error
|
|
toggleModelPinned func(string, modeladmin.Action) error
|
|
getBranding func() (*Branding, error)
|
|
setBranding func(SetBrandingRequest) (*Branding, error)
|
|
getUsageStats func(UsageStatsQuery) (*UsageStats, error)
|
|
listPIIPatterns func() ([]PIIPattern, error)
|
|
getPIIEvents func(PIIEventsQuery) ([]PIIEvent, error)
|
|
testPIIRedaction func(PIIRedactTestRequest) (*PIIRedactTestResult, error)
|
|
setPIIPatternAction func(PIIPatternActionUpdate) error
|
|
getMiddlewareStatus func() (*MiddlewareStatus, error)
|
|
getRouterDecisions func(RouterDecisionsQuery) ([]RouterDecision, error)
|
|
}
|
|
|
|
// fakeCall is one entry in fakeClient's call log.
type fakeCall struct {
	// method is the name of the fakeClient method that was invoked.
	method string
	// args holds whatever the method passed to record: the single argument,
	// a []any of several arguments, or nil for argument-less methods.
	args any
}
|
|
|
|
func (f *fakeClient) record(method string, args any) {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
f.calls = append(f.calls, fakeCall{method: method, args: args})
|
|
}
|
|
|
|
func (f *fakeClient) recorded() []fakeCall {
|
|
f.mu.Lock()
|
|
defer f.mu.Unlock()
|
|
out := make([]fakeCall, len(f.calls))
|
|
copy(out, f.calls)
|
|
return out
|
|
}
|
|
|
|
// errNotConfigured is the error returned by fakeClient methods that have no
// canned default (GetJobStatus, GetModelConfig, InstallModel, InstallBackend,
// UpgradeBackend, VRAMEstimate) when their hook field is nil.
var errNotConfigured = errors.New("fakeClient method not configured")
|
|
|
|
func (f *fakeClient) GallerySearch(_ context.Context, q GallerySearchQuery) ([]gallery.Metadata, error) {
|
|
f.record("GallerySearch", q)
|
|
if f.gallerySearch != nil {
|
|
return f.gallerySearch(q)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeClient) ListInstalledModels(_ context.Context, capability Capability) ([]InstalledModel, error) {
|
|
f.record("ListInstalledModels", capability)
|
|
if f.listInstalledModels != nil {
|
|
return f.listInstalledModels(capability)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeClient) ListGalleries(_ context.Context) ([]config.Gallery, error) {
|
|
f.record("ListGalleries", nil)
|
|
if f.listGalleries != nil {
|
|
return f.listGalleries()
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeClient) GetJobStatus(_ context.Context, jobID string) (*JobStatus, error) {
|
|
f.record("GetJobStatus", jobID)
|
|
if f.getJobStatus != nil {
|
|
return f.getJobStatus(jobID)
|
|
}
|
|
return nil, errNotConfigured
|
|
}
|
|
|
|
func (f *fakeClient) GetModelConfig(_ context.Context, name string) (*ModelConfigView, error) {
|
|
f.record("GetModelConfig", name)
|
|
if f.getModelConfig != nil {
|
|
return f.getModelConfig(name)
|
|
}
|
|
return nil, errNotConfigured
|
|
}
|
|
|
|
func (f *fakeClient) InstallModel(_ context.Context, req InstallModelRequest) (string, error) {
|
|
f.record("InstallModel", req)
|
|
if f.installModel != nil {
|
|
return f.installModel(req)
|
|
}
|
|
return "", errNotConfigured
|
|
}
|
|
|
|
func (f *fakeClient) DeleteModel(_ context.Context, name string) error {
|
|
f.record("DeleteModel", name)
|
|
if f.deleteModel != nil {
|
|
return f.deleteModel(name)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeClient) ImportModelURI(_ context.Context, req ImportModelURIRequest) (*ImportModelURIResponse, error) {
|
|
f.record("ImportModelURI", req)
|
|
if f.importModelURI != nil {
|
|
return f.importModelURI(req)
|
|
}
|
|
return &ImportModelURIResponse{JobID: "fake-import-job"}, nil
|
|
}
|
|
|
|
func (f *fakeClient) EditModelConfig(_ context.Context, name string, patch map[string]any) error {
|
|
f.record("EditModelConfig", []any{name, patch})
|
|
if f.editModelConfig != nil {
|
|
return f.editModelConfig(name, patch)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeClient) ReloadModels(_ context.Context) error {
|
|
f.record("ReloadModels", nil)
|
|
if f.reloadModels != nil {
|
|
return f.reloadModels()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeClient) ListBackends(_ context.Context) ([]Backend, error) {
|
|
f.record("ListBackends", nil)
|
|
if f.listBackends != nil {
|
|
return f.listBackends()
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeClient) ListKnownBackends(_ context.Context) ([]schema.KnownBackend, error) {
|
|
f.record("ListKnownBackends", nil)
|
|
if f.listKnownBackends != nil {
|
|
return f.listKnownBackends()
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeClient) InstallBackend(_ context.Context, req InstallBackendRequest) (string, error) {
|
|
f.record("InstallBackend", req)
|
|
if f.installBackend != nil {
|
|
return f.installBackend(req)
|
|
}
|
|
return "", errNotConfigured
|
|
}
|
|
|
|
func (f *fakeClient) UpgradeBackend(_ context.Context, name string) (string, error) {
|
|
f.record("UpgradeBackend", name)
|
|
if f.upgradeBackend != nil {
|
|
return f.upgradeBackend(name)
|
|
}
|
|
return "", errNotConfigured
|
|
}
|
|
|
|
func (f *fakeClient) SystemInfo(_ context.Context) (*SystemInfo, error) {
|
|
f.record("SystemInfo", nil)
|
|
if f.systemInfo != nil {
|
|
return f.systemInfo()
|
|
}
|
|
return &SystemInfo{Version: "test"}, nil
|
|
}
|
|
|
|
func (f *fakeClient) ListNodes(_ context.Context) ([]Node, error) {
|
|
f.record("ListNodes", nil)
|
|
if f.listNodes != nil {
|
|
return f.listNodes()
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *fakeClient) VRAMEstimate(_ context.Context, req VRAMEstimateRequest) (*vram.EstimateResult, error) {
|
|
f.record("VRAMEstimate", req)
|
|
if f.vramEstimate != nil {
|
|
return f.vramEstimate(req)
|
|
}
|
|
return nil, errNotConfigured
|
|
}
|
|
|
|
func (f *fakeClient) ToggleModelState(_ context.Context, name string, action modeladmin.Action) error {
|
|
f.record("ToggleModelState", []any{name, action})
|
|
if f.toggleModelState != nil {
|
|
return f.toggleModelState(name, action)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeClient) ToggleModelPinned(_ context.Context, name string, action modeladmin.Action) error {
|
|
f.record("ToggleModelPinned", []any{name, action})
|
|
if f.toggleModelPinned != nil {
|
|
return f.toggleModelPinned(name, action)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeClient) GetBranding(_ context.Context) (*Branding, error) {
|
|
f.record("GetBranding", nil)
|
|
if f.getBranding != nil {
|
|
return f.getBranding()
|
|
}
|
|
return &Branding{InstanceName: "LocalAI"}, nil
|
|
}
|
|
|
|
func (f *fakeClient) SetBranding(_ context.Context, req SetBrandingRequest) (*Branding, error) {
|
|
f.record("SetBranding", req)
|
|
if f.setBranding != nil {
|
|
return f.setBranding(req)
|
|
}
|
|
return &Branding{InstanceName: "LocalAI"}, nil
|
|
}
|
|
|
|
func (f *fakeClient) GetUsageStats(_ context.Context, q UsageStatsQuery) (*UsageStats, error) {
|
|
f.record("GetUsageStats", q)
|
|
if f.getUsageStats != nil {
|
|
return f.getUsageStats(q)
|
|
}
|
|
return &UsageStats{
|
|
Viewer: UsageViewer{ID: "fake-user", Name: "fake", Role: "user"},
|
|
Period: "month",
|
|
}, nil
|
|
}
|
|
|
|
func (f *fakeClient) ListPIIPatterns(_ context.Context) ([]PIIPattern, error) {
|
|
f.record("ListPIIPatterns", nil)
|
|
if f.listPIIPatterns != nil {
|
|
return f.listPIIPatterns()
|
|
}
|
|
return []PIIPattern{}, nil
|
|
}
|
|
|
|
func (f *fakeClient) GetPIIEvents(_ context.Context, q PIIEventsQuery) ([]PIIEvent, error) {
|
|
f.record("GetPIIEvents", q)
|
|
if f.getPIIEvents != nil {
|
|
return f.getPIIEvents(q)
|
|
}
|
|
return []PIIEvent{}, nil
|
|
}
|
|
|
|
func (f *fakeClient) TestPIIRedaction(_ context.Context, req PIIRedactTestRequest) (*PIIRedactTestResult, error) {
|
|
f.record("TestPIIRedaction", req)
|
|
if f.testPIIRedaction != nil {
|
|
return f.testPIIRedaction(req)
|
|
}
|
|
return &PIIRedactTestResult{Redacted: req.Text}, nil
|
|
}
|
|
|
|
func (f *fakeClient) SetPIIPatternAction(_ context.Context, req PIIPatternActionUpdate) error {
|
|
f.record("SetPIIPatternAction", req)
|
|
if f.setPIIPatternAction != nil {
|
|
return f.setPIIPatternAction(req)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeClient) PersistPIIPatterns(_ context.Context) error {
|
|
f.record("PersistPIIPatterns", nil)
|
|
return nil
|
|
}
|
|
|
|
func (f *fakeClient) GetRouterDecisions(_ context.Context, q RouterDecisionsQuery) ([]RouterDecision, error) {
|
|
f.record("GetRouterDecisions", q)
|
|
if f.getRouterDecisions != nil {
|
|
return f.getRouterDecisions(q)
|
|
}
|
|
return []RouterDecision{}, nil
|
|
}
|
|
|
|
func (f *fakeClient) GetMiddlewareStatus(_ context.Context) (*MiddlewareStatus, error) {
|
|
f.record("GetMiddlewareStatus", nil)
|
|
if f.getMiddlewareStatus != nil {
|
|
return f.getMiddlewareStatus()
|
|
}
|
|
return &MiddlewareStatus{
|
|
PII: MiddlewarePIIStatus{
|
|
EnabledGlobally: true,
|
|
Patterns: []PIIPattern{},
|
|
Models: []MiddlewarePIIModel{},
|
|
},
|
|
Router: MiddlewareRouterStatus{Configured: false, Models: []string{}},
|
|
}, nil
|
|
}
|
|
|