Files
LocalAI/pkg/mcp/localaitools/fakes_test.go
Richard Palethorpe 6a80e23733 feat(middleware): Model routing, PII filtering, Cloud model proxies (#9802)
Add a routing middleware stack and a cloud-proxy backend.

* cloud-proxy: a Go gRPC backend that forwards OpenAI- and
  Anthropic-shaped chat requests to upstream providers, with an
  optional translate mode (OpenAI request -> Anthropic /v1/messages
  -> OpenAI response) and full tool-calling support.

* routing: admission control, content-aware model routing
  (embedding cache + classifier + rerank + Arch-Router score),
  PII detection/redaction (regex + NER) with streaming filter and
  OpenAI/Anthropic adapters, and a per-user/per-key billing recorder
  backed by GORM or in-memory storage.

* middleware: UsageMiddleware records usage via the billing recorder,
  plus admission, route-model, usage-stamp and trace middlewares.

* observability: BackendTrace ring buffer stores full request bodies
  (capped), MITM proxy emits structured trace events, and router
  classifier decisions surface at /api/router/decide.

* gallery: Arch-Router-1.5B (Q4_K_M and Q8_0).

* UI: cloud-proxy model-editor fields, classifier system-prompt and
  score-normalization config, and a Traces page rendering request
  bodies.

Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-25 09:28:27 +02:00

316 lines
9.2 KiB
Go

package localaitools
import (
"context"
"errors"
"sync"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services/modeladmin"
"github.com/mudler/LocalAI/pkg/vram"
)
// fakeClient is a recording, configurable LocalAIClient for unit tests.
// Each method records the args it was called with and returns whatever the
// matching field on the struct is configured to return. Methods are guarded
// by a mutex so tests can run with -race.
type fakeClient struct {
mu sync.Mutex
// Recorded calls (in order).
calls []fakeCall
// Per-method overrides. Tests set these.
gallerySearch func(GallerySearchQuery) ([]gallery.Metadata, error)
listInstalledModels func(Capability) ([]InstalledModel, error)
listGalleries func() ([]config.Gallery, error)
getJobStatus func(string) (*JobStatus, error)
getModelConfig func(string) (*ModelConfigView, error)
installModel func(InstallModelRequest) (string, error)
importModelURI func(ImportModelURIRequest) (*ImportModelURIResponse, error)
deleteModel func(string) error
editModelConfig func(string, map[string]any) error
reloadModels func() error
listBackends func() ([]Backend, error)
listKnownBackends func() ([]schema.KnownBackend, error)
installBackend func(InstallBackendRequest) (string, error)
upgradeBackend func(string) (string, error)
systemInfo func() (*SystemInfo, error)
listNodes func() ([]Node, error)
vramEstimate func(VRAMEstimateRequest) (*vram.EstimateResult, error)
toggleModelState func(string, modeladmin.Action) error
toggleModelPinned func(string, modeladmin.Action) error
getBranding func() (*Branding, error)
setBranding func(SetBrandingRequest) (*Branding, error)
getUsageStats func(UsageStatsQuery) (*UsageStats, error)
listPIIPatterns func() ([]PIIPattern, error)
getPIIEvents func(PIIEventsQuery) ([]PIIEvent, error)
testPIIRedaction func(PIIRedactTestRequest) (*PIIRedactTestResult, error)
setPIIPatternAction func(PIIPatternActionUpdate) error
getMiddlewareStatus func() (*MiddlewareStatus, error)
getRouterDecisions func(RouterDecisionsQuery) ([]RouterDecision, error)
}
type fakeCall struct {
method string
args any
}
func (f *fakeClient) record(method string, args any) {
f.mu.Lock()
defer f.mu.Unlock()
f.calls = append(f.calls, fakeCall{method: method, args: args})
}
func (f *fakeClient) recorded() []fakeCall {
f.mu.Lock()
defer f.mu.Unlock()
out := make([]fakeCall, len(f.calls))
copy(out, f.calls)
return out
}
var errNotConfigured = errors.New("fakeClient method not configured")
func (f *fakeClient) GallerySearch(_ context.Context, q GallerySearchQuery) ([]gallery.Metadata, error) {
f.record("GallerySearch", q)
if f.gallerySearch != nil {
return f.gallerySearch(q)
}
return nil, nil
}
func (f *fakeClient) ListInstalledModels(_ context.Context, capability Capability) ([]InstalledModel, error) {
f.record("ListInstalledModels", capability)
if f.listInstalledModels != nil {
return f.listInstalledModels(capability)
}
return nil, nil
}
func (f *fakeClient) ListGalleries(_ context.Context) ([]config.Gallery, error) {
f.record("ListGalleries", nil)
if f.listGalleries != nil {
return f.listGalleries()
}
return nil, nil
}
func (f *fakeClient) GetJobStatus(_ context.Context, jobID string) (*JobStatus, error) {
f.record("GetJobStatus", jobID)
if f.getJobStatus != nil {
return f.getJobStatus(jobID)
}
return nil, errNotConfigured
}
func (f *fakeClient) GetModelConfig(_ context.Context, name string) (*ModelConfigView, error) {
f.record("GetModelConfig", name)
if f.getModelConfig != nil {
return f.getModelConfig(name)
}
return nil, errNotConfigured
}
func (f *fakeClient) InstallModel(_ context.Context, req InstallModelRequest) (string, error) {
f.record("InstallModel", req)
if f.installModel != nil {
return f.installModel(req)
}
return "", errNotConfigured
}
func (f *fakeClient) DeleteModel(_ context.Context, name string) error {
f.record("DeleteModel", name)
if f.deleteModel != nil {
return f.deleteModel(name)
}
return nil
}
func (f *fakeClient) ImportModelURI(_ context.Context, req ImportModelURIRequest) (*ImportModelURIResponse, error) {
f.record("ImportModelURI", req)
if f.importModelURI != nil {
return f.importModelURI(req)
}
return &ImportModelURIResponse{JobID: "fake-import-job"}, nil
}
func (f *fakeClient) EditModelConfig(_ context.Context, name string, patch map[string]any) error {
f.record("EditModelConfig", []any{name, patch})
if f.editModelConfig != nil {
return f.editModelConfig(name, patch)
}
return nil
}
func (f *fakeClient) ReloadModels(_ context.Context) error {
f.record("ReloadModels", nil)
if f.reloadModels != nil {
return f.reloadModels()
}
return nil
}
func (f *fakeClient) ListBackends(_ context.Context) ([]Backend, error) {
f.record("ListBackends", nil)
if f.listBackends != nil {
return f.listBackends()
}
return nil, nil
}
func (f *fakeClient) ListKnownBackends(_ context.Context) ([]schema.KnownBackend, error) {
f.record("ListKnownBackends", nil)
if f.listKnownBackends != nil {
return f.listKnownBackends()
}
return nil, nil
}
func (f *fakeClient) InstallBackend(_ context.Context, req InstallBackendRequest) (string, error) {
f.record("InstallBackend", req)
if f.installBackend != nil {
return f.installBackend(req)
}
return "", errNotConfigured
}
func (f *fakeClient) UpgradeBackend(_ context.Context, name string) (string, error) {
f.record("UpgradeBackend", name)
if f.upgradeBackend != nil {
return f.upgradeBackend(name)
}
return "", errNotConfigured
}
func (f *fakeClient) SystemInfo(_ context.Context) (*SystemInfo, error) {
f.record("SystemInfo", nil)
if f.systemInfo != nil {
return f.systemInfo()
}
return &SystemInfo{Version: "test"}, nil
}
func (f *fakeClient) ListNodes(_ context.Context) ([]Node, error) {
f.record("ListNodes", nil)
if f.listNodes != nil {
return f.listNodes()
}
return nil, nil
}
func (f *fakeClient) VRAMEstimate(_ context.Context, req VRAMEstimateRequest) (*vram.EstimateResult, error) {
f.record("VRAMEstimate", req)
if f.vramEstimate != nil {
return f.vramEstimate(req)
}
return nil, errNotConfigured
}
func (f *fakeClient) ToggleModelState(_ context.Context, name string, action modeladmin.Action) error {
f.record("ToggleModelState", []any{name, action})
if f.toggleModelState != nil {
return f.toggleModelState(name, action)
}
return nil
}
func (f *fakeClient) ToggleModelPinned(_ context.Context, name string, action modeladmin.Action) error {
f.record("ToggleModelPinned", []any{name, action})
if f.toggleModelPinned != nil {
return f.toggleModelPinned(name, action)
}
return nil
}
func (f *fakeClient) GetBranding(_ context.Context) (*Branding, error) {
f.record("GetBranding", nil)
if f.getBranding != nil {
return f.getBranding()
}
return &Branding{InstanceName: "LocalAI"}, nil
}
func (f *fakeClient) SetBranding(_ context.Context, req SetBrandingRequest) (*Branding, error) {
f.record("SetBranding", req)
if f.setBranding != nil {
return f.setBranding(req)
}
return &Branding{InstanceName: "LocalAI"}, nil
}
func (f *fakeClient) GetUsageStats(_ context.Context, q UsageStatsQuery) (*UsageStats, error) {
f.record("GetUsageStats", q)
if f.getUsageStats != nil {
return f.getUsageStats(q)
}
return &UsageStats{
Viewer: UsageViewer{ID: "fake-user", Name: "fake", Role: "user"},
Period: "month",
}, nil
}
func (f *fakeClient) ListPIIPatterns(_ context.Context) ([]PIIPattern, error) {
f.record("ListPIIPatterns", nil)
if f.listPIIPatterns != nil {
return f.listPIIPatterns()
}
return []PIIPattern{}, nil
}
func (f *fakeClient) GetPIIEvents(_ context.Context, q PIIEventsQuery) ([]PIIEvent, error) {
f.record("GetPIIEvents", q)
if f.getPIIEvents != nil {
return f.getPIIEvents(q)
}
return []PIIEvent{}, nil
}
func (f *fakeClient) TestPIIRedaction(_ context.Context, req PIIRedactTestRequest) (*PIIRedactTestResult, error) {
f.record("TestPIIRedaction", req)
if f.testPIIRedaction != nil {
return f.testPIIRedaction(req)
}
return &PIIRedactTestResult{Redacted: req.Text}, nil
}
func (f *fakeClient) SetPIIPatternAction(_ context.Context, req PIIPatternActionUpdate) error {
f.record("SetPIIPatternAction", req)
if f.setPIIPatternAction != nil {
return f.setPIIPatternAction(req)
}
return nil
}
func (f *fakeClient) PersistPIIPatterns(_ context.Context) error {
f.record("PersistPIIPatterns", nil)
return nil
}
func (f *fakeClient) GetRouterDecisions(_ context.Context, q RouterDecisionsQuery) ([]RouterDecision, error) {
f.record("GetRouterDecisions", q)
if f.getRouterDecisions != nil {
return f.getRouterDecisions(q)
}
return []RouterDecision{}, nil
}
func (f *fakeClient) GetMiddlewareStatus(_ context.Context) (*MiddlewareStatus, error) {
f.record("GetMiddlewareStatus", nil)
if f.getMiddlewareStatus != nil {
return f.getMiddlewareStatus()
}
return &MiddlewareStatus{
PII: MiddlewarePIIStatus{
EnabledGlobally: true,
Patterns: []PIIPattern{},
Models: []MiddlewarePIIModel{},
},
Router: MiddlewareRouterStatus{Configured: false, Models: []string{}},
}, nil
}