Files
LocalAI/backend/go/cloud-proxy/provider_openai_test.go
Richard Palethorpe 6a80e23733 feat(middleware): Model routing, PII filtering, Cloud model proxies (#9802)
Add a routing middleware stack and a cloud-proxy backend.

* cloud-proxy: a Go gRPC backend that forwards OpenAI- and
  Anthropic-shaped chat requests to upstream providers, with an
  optional translate mode (OpenAI request -> Anthropic /v1/messages
  -> OpenAI response) and full tool-calling support.

* routing: admission control, content-aware model routing
  (embedding cache + classifier + rerank + Arch-Router score),
  PII detection/redaction (regex + NER) with streaming filter and
  OpenAI/Anthropic adapters, and a per-user/per-key billing recorder
  backed by GORM or in-memory storage.

* middleware: UsageMiddleware records usage via the billing recorder,
  plus admission, route-model, usage-stamp and trace middlewares.

* observability: BackendTrace ring buffer stores full request bodies
  (capped), MITM proxy emits structured trace events, and router
  classifier decisions surface at /api/router/decide.

* gallery: Arch-Router-1.5B (Q4_K_M and Q8_0).

* UI: cloud-proxy model-editor fields, classifier system-prompt and
  score-normalization config, and a Traces page rendering request
  bodies.

Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
2026-05-25 09:28:27 +02:00

171 lines
5.4 KiB
Go

package main
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
. "github.com/onsi/gomega"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)
// fakeOpenAIUpstream starts an httptest.Server that stands in for an
// OpenAI-compatible upstream. For every inbound request it decodes the
// body into an openAIRequest, passes that to handler, and replies with
// the status, body and Content-Type the handler returns.
//
// The returned *openAIRequest points at the most recently decoded
// request so the test can assert on what was sent upstream. The caller
// owns the returned server and is responsible for closing it.
//
// A request body that cannot be read or decoded is reported through
// t.Errorf instead of being silently ignored; the handler is still
// invoked so the client receives a response.
func fakeOpenAIUpstream(t *testing.T, handler func(req openAIRequest) (status int, body string, contentType string)) (*httptest.Server, *openAIRequest) {
	t.Helper()
	var captured openAIRequest
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Errorf (unlike Fatalf/FailNow) may be called from a goroutine
		// other than the one running the test, which is where this
		// handler executes.
		raw, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("fake upstream: reading request body: %v", err)
		}
		// Decode into a fresh value: json.Unmarshal leaves fields absent
		// from the payload untouched, so reusing the previous value would
		// let fields of an earlier request leak into this one.
		captured = openAIRequest{}
		if err := json.Unmarshal(raw, &captured); err != nil {
			t.Errorf("fake upstream: decoding request body %q: %v", raw, err)
		}
		status, body, ct := handler(captured)
		w.Header().Set("Content-Type", ct)
		w.WriteHeader(status)
		// Best effort: a failed write only means the client went away.
		_, _ = io.WriteString(w, body)
	}))
	return srv, &captured
}
// newTranslateCloudProxy returns a CloudProxy loaded in translate mode
// with the OpenAI provider and upstreamURL as its upstream. The API key
// is supplied through the CLOUD_PROXY_OPENAI_FAKE environment variable,
// which is set for the duration of the test. The helper asserts that
// Load succeeds.
func newTranslateCloudProxy(t *testing.T, upstreamURL string) *CloudProxy {
	t.Helper()

	const keyEnv = "CLOUD_PROXY_OPENAI_FAKE"
	t.Setenv(keyEnv, "sk-fake-openai")

	proxyOpts := &pb.ProxyOptions{
		UpstreamUrl:   upstreamURL,
		Mode:          modeTranslate,
		Provider:      providerOpenAI,
		ApiKeyEnv:     keyEnv,
		UpstreamModel: "gpt-4o",
	}

	proxy := NewCloudProxy()
	loadErr := proxy.Load(&pb.ModelOptions{
		Model: "gpt-4o-local",
		Proxy: proxyOpts,
	})
	NewWithT(t).Expect(loadErr).NotTo(HaveOccurred())
	return proxy
}
// TestPredict_OpenAI_BasicChat exercises the non-streaming translate
// path: Predict must return the assistant content from the upstream
// reply, and the request observed by the fake upstream must carry the
// expected model, message roles, temperature, max tokens and a false
// stream flag.
func TestPredict_OpenAI_BasicChat(t *testing.T) {
	g := NewWithT(t)

	const upstreamReply = `{"id":"resp-1","choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}],"usage":{"prompt_tokens":5,"completion_tokens":2,"total_tokens":7}}`
	srv, sent := fakeOpenAIUpstream(t, func(openAIRequest) (int, string, string) {
		return 200, upstreamReply, "application/json"
	})
	defer srv.Close()

	opts := &pb.PredictOptions{
		Messages: []*pb.Message{
			{Role: "system", Content: "be brief"},
			{Role: "user", Content: "hello"},
		},
		Temperature: 0.5,
		TopP:        0.9,
		Tokens:      32,
	}
	reply, err := newTranslateCloudProxy(t, srv.URL).Predict(opts)
	g.Expect(err).NotTo(HaveOccurred())
	g.Expect(reply).To(Equal("hi there"))

	// The upstream must have received a correctly translated request.
	g.Expect(sent.Model).To(Equal("gpt-4o"))
	g.Expect(sent.Messages).To(HaveLen(2))
	for i, role := range []string{"system", "user"} {
		g.Expect(sent.Messages[i].Role).To(Equal(role))
	}

	// Check the pointers are non-nil before dereferencing them.
	g.Expect(sent.Temperature).NotTo(BeNil())
	g.Expect(*sent.Temperature).To(Equal(0.5))
	g.Expect(sent.MaxTokens).NotTo(BeNil())
	g.Expect(*sent.MaxTokens).To(Equal(int32(32)))

	g.Expect(sent.Stream).To(BeFalse())
}
// TestPredict_OpenAI_PromptFallback covers requests that carry only a
// Prompt and no Messages: the upstream is expected to receive exactly
// one user message whose content is the prompt text, so non-chat
// clients still work through translate mode.
func TestPredict_OpenAI_PromptFallback(t *testing.T) {
	g := NewWithT(t)

	const prompt = "what time is it?"
	reply := func(openAIRequest) (int, string, string) {
		return 200, `{"choices":[{"message":{"role":"assistant","content":"ok"}}]}`, "application/json"
	}
	srv, sent := fakeOpenAIUpstream(t, reply)
	defer srv.Close()

	proxy := newTranslateCloudProxy(t, srv.URL)
	_, err := proxy.Predict(&pb.PredictOptions{Prompt: prompt})
	g.Expect(err).NotTo(HaveOccurred())

	g.Expect(sent.Messages).To(HaveLen(1))
	only := sent.Messages[0]
	g.Expect(only.Role).To(Equal("user"))
	g.Expect(only.Content).To(Equal(prompt))
}
// TestPredict_OpenAI_UpstreamError verifies that a 401 reply from the
// upstream makes Predict return an error whose message mentions the
// status code.
func TestPredict_OpenAI_UpstreamError(t *testing.T) {
	g := NewWithT(t)

	unauthorized := func(openAIRequest) (int, string, string) {
		return 401, `{"error":{"message":"bad key"}}`, "application/json"
	}
	srv, _ := fakeOpenAIUpstream(t, unauthorized)
	defer srv.Close()

	opts := &pb.PredictOptions{
		Messages: []*pb.Message{{Role: "user", Content: "x"}},
	}
	_, err := newTranslateCloudProxy(t, srv.URL).Predict(opts)
	g.Expect(err).To(HaveOccurred())
	g.Expect(err.Error()).To(ContainSubstring("401"))
}
// TestPredictStream_OpenAI_StreamsContent feeds PredictStream a
// server-sent-event body made of a role delta, three content deltas, a
// finish chunk and the [DONE] marker. It checks that the pieces
// delivered on the results channel concatenate to "hello world" and
// that the upstream request had streaming enabled.
func TestPredictStream_OpenAI_StreamsContent(t *testing.T) {
	g := NewWithT(t)

	events := []string{
		`{"choices":[{"index":0,"delta":{"role":"assistant"}}]}`,
		`{"choices":[{"index":0,"delta":{"content":"hello"}}]}`,
		`{"choices":[{"index":0,"delta":{"content":" "}}]}`,
		`{"choices":[{"index":0,"delta":{"content":"world"}}]}`,
		`{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`,
	}
	// Frame each JSON chunk as one SSE "data:" event.
	var sse strings.Builder
	for _, ev := range events {
		sse.WriteString("data: ")
		sse.WriteString(ev)
		sse.WriteString("\n\n")
	}
	sse.WriteString("data: [DONE]\n\n")
	body := sse.String()

	srv, sent := fakeOpenAIUpstream(t, func(openAIRequest) (int, string, string) {
		return 200, body, "text/event-stream"
	})
	defer srv.Close()

	proxy := newTranslateCloudProxy(t, srv.URL)
	tokens := make(chan string, 8)
	errCh := make(chan error, 1)
	go func() {
		errCh <- proxy.PredictStream(&pb.PredictOptions{
			Messages: []*pb.Message{{Role: "user", Content: "hi"}},
		}, tokens)
	}()

	// Drain the stream in arrival order. This loop only ends once the
	// channel is closed — presumably by PredictStream; confirm against
	// the implementation.
	var streamed strings.Builder
	for tok := range tokens {
		streamed.WriteString(tok)
	}

	g.Expect(<-errCh).NotTo(HaveOccurred())
	g.Expect(streamed.String()).To(Equal("hello world"))
	g.Expect(sent.Stream).To(BeTrue())
}
// TestPredict_RejectedInPassthroughMode loads the proxy in passthrough
// mode and checks that Predict fails with an error stating it is only
// valid in translate mode.
func TestPredict_RejectedInPassthroughMode(t *testing.T) {
	g := NewWithT(t)

	const keyEnv = "CLOUD_PROXY_FAKE"
	t.Setenv(keyEnv, "k")

	proxy := NewCloudProxy()
	loadErr := proxy.Load(&pb.ModelOptions{
		Proxy: &pb.ProxyOptions{
			UpstreamUrl: "https://example.com",
			Mode:        modePassthrough,
			ApiKeyEnv:   keyEnv,
		},
	})
	g.Expect(loadErr).NotTo(HaveOccurred())

	_, predictErr := proxy.Predict(&pb.PredictOptions{})
	g.Expect(predictErr).To(HaveOccurred())
	g.Expect(predictErr.Error()).To(ContainSubstring("only valid in translate"))
}