mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-18 21:58:58 -04:00
feat(pii): NER tier engine — privacy-filter.cpp backend + NER-centric PII filter (#10360)
Squashed feat/pii-ner-tier-engine rebased onto master (was 45 commits; see backup/pii-ner-tier-engine-prerebase). Net change: - privacy-filter.cpp: standalone GGML engine for the openai-privacy-filter PII/NER token classifier, wired as a LocalAI gRPC backend (CPU/CUDA/Vulkan). TokenClassify moves off the patched llama.cpp path onto this backend. - PII filter reworked to be NER-centric (encoder/NER detection tier scanning whole conversations as one document), with a recreated bounded restricted-regex secret-matching pattern detector tier alongside it (per-model pii_detection.builtins / .patterns + core/services/routing/piipattern). - Detection labelled by source (ner vs pattern); backend trace / confidence / debug observability; analyze/redact exposed as a synchronous API. - Instance-wide default detector policy + per-usecase default-on; request filtering extended to completions, embeddings, edits & Ollama. - React UI: NER-centric PII editor, detector-models table, pattern/builtins editor, middleware default-policy UI. - Gallery: privacy-filter-multilingual token-classify model + NER install filter; token_classify known_usecase; batch sized to context for NER models. privacy-filter backend registered in the backend gallery (cpu/vulkan/cuda-13 meta + image entries with a capabilities map) matching its CI matrix jobs, and an /import-model auto-detect importer (PrivacyFilterImporter, narrow privacy-filter GGUF detection) replacing the prior pref-only registration. Reconciled against master's independent evolution: - Dropped master's PIIPatternOverrides feature (global-pattern runtime overrides + /api/pii/patterns API + runtime_settings.json persistence). The per-model NER + pattern-detector design supersedes it; it was built on the global redactor pattern set this branch replaced.
- Reverted the llama.cpp Score carry-patch (0006-server-task-type-score): removed the patch and restored master's grpc-server.cpp Score RPC (direct llama_decode, slot-loop bypass) and LLAMA_VERSION pin, plus master's model_config validation forbidding score + chat/completion/embeddings on llama-cpp. token_classify is unaffected (it runs on the privacy-filter backend, not llama-cpp). Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
committed by
GitHub
parent
c133ca39dc
commit
3fa7b2955c
@@ -10,8 +10,6 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
corebackend "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -19,41 +17,14 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// BuildStreamFilter constructs the per-request streaming PII filter
|
||||
// for a cloud-proxy forward. Returns nil when the request isn't
|
||||
// streaming, PII is disabled for this model, or no redactor is wired
|
||||
// up — callers pass the result through unchanged. correlationID is
|
||||
// caller-supplied because the OpenAI and Anthropic endpoints read it
|
||||
// from different headers.
|
||||
func BuildStreamFilter(c echo.Context, cfg *config.ModelConfig, isStream bool, piiRedactor *pii.Redactor, piiEvents pii.EventStore, correlationID string) *pii.StreamFilter {
|
||||
if !isStream || piiRedactor == nil || !cfg.PIIIsEnabled() {
|
||||
return nil
|
||||
}
|
||||
userID := ""
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
userID = u.ID
|
||||
}
|
||||
var overrides map[string]pii.Action
|
||||
if raw := cfg.PIIPatternOverrides(); len(raw) > 0 {
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionAllow:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
}
|
||||
return pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID)
|
||||
}
|
||||
|
||||
// ForwardViaBackend loads the cloud-proxy gRPC backend, ships the
|
||||
// request via the Forward RPC, and pumps the response back to the
|
||||
// client through the SSE-aware PII pipeline.
|
||||
// client. PII redaction runs request-side (the NER middleware + MITM
|
||||
// input path); the response is forwarded unmodified.
|
||||
func ForwardViaBackend(
|
||||
c echo.Context,
|
||||
cfg *config.ModelConfig,
|
||||
body []byte,
|
||||
filter *pii.StreamFilter,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
) (resultErr error) {
|
||||
@@ -176,7 +147,7 @@ func ForwardViaBackend(
|
||||
return passthroughError(c, statusCode, contentType, bodyReader)
|
||||
}
|
||||
if isStream {
|
||||
return forwardStream(c, bodyReader, cfg.Proxy.Provider, filter)
|
||||
return forwardStream(c, bodyReader)
|
||||
}
|
||||
return forwardBuffered(c, statusCode, contentType, bodyReader)
|
||||
}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
package cloudproxy
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("BuildStreamFilter", func() {
|
||||
var (
|
||||
c echo.Context
|
||||
cfg *config.ModelConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
piiOn := true
|
||||
cfg = &config.ModelConfig{
|
||||
Backend: "cloud-proxy",
|
||||
PII: config.PIIConfig{Enabled: &piiOn},
|
||||
}
|
||||
})
|
||||
|
||||
// Three guards must each independently force a nil return — proves
|
||||
// the gate is a logical AND, not an order-dependent short-circuit
|
||||
// that silently activates one branch.
|
||||
It("returns nil when isStream is false", func() {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
r := pii.NewRedactor(patterns)
|
||||
Expect(BuildStreamFilter(c, cfg, false, r, nil, "corr-1")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil when piiRedactor is nil", func() {
|
||||
Expect(BuildStreamFilter(c, cfg, true, nil, nil, "corr-1")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil when the model has PII disabled", func() {
|
||||
piiOff := false
|
||||
cfg.PII.Enabled = &piiOff
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
r := pii.NewRedactor(patterns)
|
||||
Expect(BuildStreamFilter(c, cfg, true, r, nil, "corr-1")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns a configured filter when all preconditions hold", func() {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
r := pii.NewRedactor(patterns)
|
||||
store := pii.NewMemoryEventStore(8)
|
||||
filter := BuildStreamFilter(c, cfg, true, r, store, "corr-xyz")
|
||||
Expect(filter).NotTo(BeNil())
|
||||
})
|
||||
|
||||
// Empty correlationID is allowed — some entry points don't have one.
|
||||
// The filter must still construct so the stream can flow.
|
||||
It("constructs a filter even when correlationID is empty", func() {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
r := pii.NewRedactor(patterns)
|
||||
Expect(BuildStreamFilter(c, cfg, true, r, nil, "")).NotTo(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piiadapter"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
@@ -24,8 +23,14 @@ import (
|
||||
|
||||
// PIIHandlerOptions configures NewPIIHandler.
|
||||
type PIIHandlerOptions struct {
|
||||
// Redactor is the regex PII redactor. nil disables redaction.
|
||||
Redactor *pii.Redactor
|
||||
// DetectorsByHost maps an intercepted host (lower-cased) to the NER
|
||||
// detector configs that should scan request bodies bound for it. The
|
||||
// configs are resolved at listener-start from each host's owning
|
||||
// model's pii.detectors + the detector models' pii_detection policy
|
||||
// (a model-config edit needs a MITM restart, as hosts already do). A
|
||||
// host absent from the map (or with an empty slice) is forwarded
|
||||
// unredacted. Detector errors at request time fail closed.
|
||||
DetectorsByHost map[string][]pii.NERConfig
|
||||
|
||||
// EventStore receives PIIEvent rows. nil discards events.
|
||||
EventStore pii.EventStore
|
||||
@@ -42,13 +47,6 @@ type PIIHandlerOptions struct {
|
||||
// upstream URL. Identity by default; tests inject a httptest
|
||||
// listener address.
|
||||
DialHost func(host string) string
|
||||
|
||||
// HostsWithPIIDisabled lists destination hosts whose request
|
||||
// bodies should NOT run through the redactor. TLS termination,
|
||||
// upstream forwarding, and audit events still happen — only the
|
||||
// regex pass is bypassed. Useful for telemetry/probe endpoints
|
||||
// whose bodies aren't PII-shaped.
|
||||
HostsWithPIIDisabled []string
|
||||
}
|
||||
|
||||
func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
|
||||
@@ -76,16 +74,9 @@ func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
|
||||
dialHost = func(h string) string { return h }
|
||||
}
|
||||
|
||||
patternAction := map[string]pii.Action{}
|
||||
if opts.Redactor != nil {
|
||||
for _, p := range opts.Redactor.Patterns() {
|
||||
patternAction[p.ID] = p.Action
|
||||
}
|
||||
}
|
||||
|
||||
piiDisabled := make(map[string]bool, len(opts.HostsWithPIIDisabled))
|
||||
for _, h := range opts.HostsWithPIIDisabled {
|
||||
piiDisabled[strings.ToLower(strings.TrimSpace(h))] = true
|
||||
detectorsByHost := make(map[string][]pii.NERConfig, len(opts.DetectorsByHost))
|
||||
for h, cfgs := range opts.DetectorsByHost {
|
||||
detectorsByHost[strings.ToLower(strings.TrimSpace(h))] = cfgs
|
||||
}
|
||||
|
||||
d := &piiDispatcher{
|
||||
@@ -96,26 +87,22 @@ func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
|
||||
// API keys such as Anthropic's x-api-key, which Go does NOT
|
||||
// strip on cross-host redirects — to an unvetted host. Surface
|
||||
// it as an error (handled as a 502) instead.
|
||||
client: httpclient.New(httpclient.WithTransport(transport)),
|
||||
redactor: opts.Redactor,
|
||||
store: opts.EventStore,
|
||||
patternAction: patternAction,
|
||||
corrHeader: corrHeader,
|
||||
dialHost: dialHost,
|
||||
piiDisabled: piiDisabled,
|
||||
client: httpclient.New(httpclient.WithTransport(transport)),
|
||||
detectorsByHost: detectorsByHost,
|
||||
store: opts.EventStore,
|
||||
corrHeader: corrHeader,
|
||||
dialHost: dialHost,
|
||||
}
|
||||
return d.serve
|
||||
}
|
||||
|
||||
type piiDispatcher struct {
|
||||
client *http.Client
|
||||
redactor *pii.Redactor
|
||||
store pii.EventStore
|
||||
patternAction map[string]pii.Action
|
||||
corrHeader string
|
||||
dialHost func(host string) string
|
||||
piiDisabled map[string]bool
|
||||
eventSeq atomic.Uint64
|
||||
client *http.Client
|
||||
detectorsByHost map[string][]pii.NERConfig
|
||||
store pii.EventStore
|
||||
corrHeader string
|
||||
dialHost func(host string) string
|
||||
eventSeq atomic.Uint64
|
||||
}
|
||||
|
||||
func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host string) {
|
||||
@@ -144,11 +131,17 @@ func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host strin
|
||||
}
|
||||
|
||||
shape := classifyRequestShape(host, r.URL.Path)
|
||||
if d.redactor != nil && shape != shapeUnknown && !d.piiDisabled[strings.ToLower(host)] {
|
||||
redacted, blocked, err := d.redactRequest(body, shape, correlationID)
|
||||
cfgs := d.detectorsByHost[strings.ToLower(host)]
|
||||
if len(cfgs) > 0 && shape != shapeUnknown {
|
||||
redacted, blocked, err := d.redactRequest(r.Context(), body, shape, cfgs, correlationID)
|
||||
switch {
|
||||
case err != nil:
|
||||
xlog.Debug("mitm: redact request failed; forwarding unchanged", "host", host, "path", r.URL.Path, "error", err)
|
||||
// Fail closed: a detector outage must not silently forward the
|
||||
// request unredacted — the operator configured this host's
|
||||
// model with detectors precisely to catch this PII.
|
||||
xlog.Error("mitm: NER redaction failed; blocking request (fail-closed)", "host", host, "path", r.URL.Path, "error", err)
|
||||
writePIIBlocked(w, correlationID)
|
||||
return
|
||||
case blocked:
|
||||
writePIIBlocked(w, correlationID)
|
||||
return
|
||||
@@ -185,12 +178,10 @@ func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host strin
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// Response/output redaction is out of scope for now — the MITM proxy
|
||||
// only scans request bodies (input). SSE responses pass through
|
||||
// unmodified.
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if shape != shapeUnknown && d.redactor != nil && isSSE(contentType) {
|
||||
d.streamWithPII(w, resp.Body, shape, correlationID)
|
||||
return
|
||||
}
|
||||
|
||||
if isSSE(contentType) {
|
||||
flusher, _ := w.(http.Flusher)
|
||||
buf := make([]byte, 32*1024)
|
||||
@@ -232,7 +223,7 @@ func classifyRequestShape(host, path string) requestShape {
|
||||
return shapeUnknown
|
||||
}
|
||||
|
||||
func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlationID string) ([]byte, bool, error) {
|
||||
func (d *piiDispatcher) redactRequest(ctx context.Context, body []byte, shape requestShape, cfgs []pii.NERConfig, correlationID string) ([]byte, bool, error) {
|
||||
var parsed any
|
||||
var adapter pii.Adapter
|
||||
switch shape {
|
||||
@@ -259,13 +250,21 @@ func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlati
|
||||
return body, false, nil
|
||||
}
|
||||
|
||||
// One scan over the joined messages so the NER tier keeps
|
||||
// conversational context (see pii.RedactNERSegments); results map
|
||||
// back per message with local offsets.
|
||||
segTexts := make([]string, len(texts))
|
||||
for i, st := range texts {
|
||||
segTexts[i] = st.Text
|
||||
}
|
||||
results, err := pii.RedactNERSegments(ctx, segTexts, cfgs)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("ner detect: %w", err)
|
||||
}
|
||||
|
||||
updates := make([]pii.ScannedText, 0, len(texts))
|
||||
blocked := false
|
||||
for _, st := range texts {
|
||||
if st.Text == "" {
|
||||
continue
|
||||
}
|
||||
res := d.redactor.RedactWithOverrides(st.Text, nil)
|
||||
for i, res := range results {
|
||||
if len(res.Spans) == 0 {
|
||||
continue
|
||||
}
|
||||
@@ -273,7 +272,7 @@ func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlati
|
||||
if res.Blocked {
|
||||
blocked = true
|
||||
}
|
||||
updates = append(updates, pii.ScannedText{Index: st.Index, Text: res.Redacted})
|
||||
updates = append(updates, pii.ScannedText{Index: texts[i].Index, Text: res.Redacted})
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
@@ -295,13 +294,14 @@ func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) {
|
||||
ev := pii.PIIEvent{
|
||||
ID: fmt.Sprintf("mitm_%s_%d", correlationID, d.eventSeq.Add(1)),
|
||||
Kind: pii.KindPII,
|
||||
Origin: pii.OriginProxy,
|
||||
CorrelationID: correlationID,
|
||||
Direction: pii.DirectionIn,
|
||||
PatternID: span.Pattern,
|
||||
ByteOffset: span.Start,
|
||||
Length: span.End - span.Start,
|
||||
HashPrefix: span.HashPrefix,
|
||||
Action: d.patternAction[span.Pattern],
|
||||
Action: span.Action,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := d.store.Record(context.Background(), ev); err != nil {
|
||||
@@ -310,49 +310,6 @@ func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *piiDispatcher) streamWithPII(w http.ResponseWriter, src io.Reader, shape requestShape, correlationID string) {
|
||||
flusher, _ := w.(http.Flusher)
|
||||
filter := pii.NewStreamFilter(d.redactor, nil, d.store, correlationID, "")
|
||||
|
||||
provider := ssewire.OpenAI
|
||||
if shape == shapeAnthropicMessages {
|
||||
provider = ssewire.Anthropic
|
||||
}
|
||||
|
||||
emit := func(s string) {
|
||||
_, _ = w.Write([]byte(s))
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
scanner := ssewire.NewScanner(src)
|
||||
for scanner.Scan() {
|
||||
ev := scanner.Event()
|
||||
if ssewire.IsTerminalMarker(ev.DataLine, provider) {
|
||||
if residual := filter.Drain(); residual != "" {
|
||||
emit(ssewire.SynthResidualEvent(provider, residual))
|
||||
}
|
||||
emit(ev.Raw)
|
||||
continue
|
||||
}
|
||||
out := ev.Raw
|
||||
if ev.DataLine != "" {
|
||||
rewritten, drop := ssewire.RewritePayload(ev.DataLine, provider, filter)
|
||||
if drop {
|
||||
continue
|
||||
}
|
||||
if rewritten != ev.DataLine {
|
||||
out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1)
|
||||
}
|
||||
}
|
||||
emit(out)
|
||||
}
|
||||
if residual := filter.Drain(); residual != "" {
|
||||
emit(ssewire.SynthResidualEvent(provider, residual))
|
||||
}
|
||||
}
|
||||
|
||||
func writePIIBlocked(w http.ResponseWriter, correlationID string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
@@ -19,34 +19,58 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// startPIITestRig is the same shape as startMITMTestRig but plugs
|
||||
// in the production PII handler instead of the passthrough fixture.
|
||||
// The "host" the client thinks it's reaching is forced to
|
||||
// api.anthropic.com so the request shape classifier matches.
|
||||
// substringDetector is a deterministic pii.NERDetector for tests: it
|
||||
// reports an entity for every occurrence of each configured substring,
|
||||
// with byte offsets into the scanned text. Lets the MITM tests drive
|
||||
// request redaction without a real token-classification backend.
|
||||
type substringDetector struct{ groups map[string]string } // substring -> entity group
|
||||
|
||||
func (d substringDetector) Detect(_ context.Context, text string) ([]pii.NEREntity, error) {
|
||||
var out []pii.NEREntity
|
||||
for sub, group := range d.groups {
|
||||
for idx := 0; ; {
|
||||
i := strings.Index(text[idx:], sub)
|
||||
if i < 0 {
|
||||
break
|
||||
}
|
||||
start := idx + i
|
||||
out = append(out, pii.NEREntity{Group: group, Start: start, End: start + len(sub), Score: 1})
|
||||
idx = start + len(sub)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// testDetectorCfg flags emails (mask) and a known secret token (block).
|
||||
func testDetectorCfg() pii.NERConfig {
|
||||
return pii.NERConfig{
|
||||
Detector: substringDetector{groups: map[string]string{
|
||||
"alice@example.com": "EMAIL",
|
||||
"bob@example.org": "EMAIL",
|
||||
"sk-abcdefghijklmnopqrstuvwxyz1234": "PASSWORD",
|
||||
}},
|
||||
EntityActions: map[string]pii.Action{"EMAIL": pii.ActionMask, "PASSWORD": pii.ActionBlock},
|
||||
}
|
||||
}
|
||||
|
||||
// startPIITestRig plugs the production PII handler into a CONNECT proxy,
|
||||
// with the upstream playing the role of api.anthropic.com. Request
|
||||
// bodies bound for api.anthropic.com run through the NER detector above.
|
||||
func startPIITestRig(upstream http.Handler) (*http.Client, string, *fakeStore, func()) {
|
||||
// Upstream fake — plays the role of api.anthropic.com.
|
||||
ts := httptest.NewTLSServer(upstream)
|
||||
upstreamCertPool := x509.NewCertPool()
|
||||
upstreamCertPool.AddCert(ts.Certificate())
|
||||
upstreamURL, _ := url.Parse(ts.URL)
|
||||
|
||||
// Compiled patterns required for the redactor to actually fire
|
||||
// (DefaultPatterns alone returns Pattern structs without regex).
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
redactor := pii.NewRedactor(patterns)
|
||||
store := &fakeStore{}
|
||||
|
||||
ca, err := NewInMemoryCA()
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
|
||||
// DialHost remaps the upstream dial target to the httptest
|
||||
// fake while leaving the classifier-facing host
|
||||
// ("api.anthropic.com") untouched. ServerName=example.com is
|
||||
// what httptest.NewTLSServer issues its cert for.
|
||||
upstreamHost := upstreamURL.Host
|
||||
prodHandler := NewPIIHandler(PIIHandlerOptions{
|
||||
Redactor: redactor,
|
||||
DetectorsByHost: map[string][]pii.NERConfig{
|
||||
"api.anthropic.com": {testDetectorCfg()},
|
||||
},
|
||||
EventStore: store,
|
||||
UpstreamTLS: &tls.Config{
|
||||
RootCAs: upstreamCertPool,
|
||||
@@ -79,8 +103,6 @@ func startPIITestRig(upstream http.Handler) (*http.Client, string, *fakeStore, f
|
||||
srv.Stop()
|
||||
ts.Close()
|
||||
}
|
||||
// We point requests at api.anthropic.com so classifyRequestShape
|
||||
// matches; the wrappedHandler retargets to the upstream fake.
|
||||
return client, "https://api.anthropic.com", store, cleanup
|
||||
}
|
||||
|
||||
@@ -101,7 +123,7 @@ func (s *fakeStore) Close() error { return nil }
|
||||
func (s *fakeStore) recorded() int { return len(s.events) }
|
||||
|
||||
var _ = Describe("PIIHandler", func() {
|
||||
It("redacts request email", func() {
|
||||
It("redacts request email via NER", func() {
|
||||
var receivedBody []byte
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedBody, _ = io.ReadAll(r.Body)
|
||||
@@ -119,15 +141,11 @@ var _ = Describe("PIIHandler", func() {
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
Expect(string(receivedBody)).NotTo(ContainSubstring("alice@example.com"), "upstream received unredacted body")
|
||||
Expect(string(receivedBody)).To(ContainSubstring("[REDACTED:email]"), "upstream did not see redaction marker")
|
||||
Expect(string(receivedBody)).To(ContainSubstring("[REDACTED:ner:EMAIL]"), "upstream did not see redaction marker")
|
||||
Expect(store.recorded()).NotTo(BeZero(), "no PIIEvent recorded for the email match")
|
||||
})
|
||||
|
||||
It("refuses to follow an upstream redirect", func() {
|
||||
// A 3xx from the upstream would otherwise be followed, replaying
|
||||
// the request (and its provider API key, e.g. Anthropic's
|
||||
// x-api-key which Go does NOT strip on cross-host redirects) to
|
||||
// the Location host. The refused redirect surfaces as a 502.
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "https://evil.example.com/steal", http.StatusFound)
|
||||
})
|
||||
@@ -142,7 +160,7 @@ var _ = Describe("PIIHandler", func() {
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusBadGateway), "refused redirect must surface as 502, not be followed")
|
||||
})
|
||||
|
||||
It("blocks api key in request", func() {
|
||||
It("blocks a detected secret in the request", func() {
|
||||
upstreamCalled := false
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamCalled = true
|
||||
@@ -156,46 +174,13 @@ var _ = Describe("PIIHandler", func() {
|
||||
resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body))
|
||||
Expect(err).NotTo(HaveOccurred(), "client.Post")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
Expect(resp.StatusCode).To(Equal(400), "api_key_prefix has Block default")
|
||||
Expect(resp.StatusCode).To(Equal(400), "PASSWORD entity action is block")
|
||||
Expect(upstreamCalled).To(BeFalse(), "upstream was called despite block — proxy should short-circuit")
|
||||
body2, _ := io.ReadAll(resp.Body)
|
||||
Expect(string(body2)).To(ContainSubstring("pii_blocked"))
|
||||
})
|
||||
|
||||
It("streaming redaction", func() {
|
||||
// Anthropic-shape SSE; "alice@" + "example.com" splits the
|
||||
// email across chunks so the StreamFilter has to buffer.
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
flusher := w.(http.Flusher)
|
||||
chunks := []string{
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"contact me at alice@"}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"example.com any time"}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
}
|
||||
for _, c := range chunks {
|
||||
_, _ = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", "content_block_delta", c)
|
||||
flusher.Flush()
|
||||
}
|
||||
})
|
||||
|
||||
client, base, _, cleanup := startPIITestRig(upstream)
|
||||
defer cleanup()
|
||||
|
||||
body := `{"model":"claude-3-5-sonnet","max_tokens":100,"stream":true,"messages":[{"role":"user","content":"hi"}]}`
|
||||
resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body))
|
||||
Expect(err).NotTo(HaveOccurred(), "Post")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
out, _ := io.ReadAll(resp.Body)
|
||||
outStr := string(out)
|
||||
Expect(outStr).NotTo(ContainSubstring("alice@example.com"), "email leaked through MITM stream")
|
||||
Expect(outStr).To(ContainSubstring("[REDACTED:email]"), "redaction marker missing from MITM stream")
|
||||
})
|
||||
|
||||
It("non-chat path passes through", func() {
|
||||
// A path the classifier doesn't recognise (e.g. an OAuth
|
||||
// callback) must forward the body verbatim, no PII parsing.
|
||||
var receivedBody []byte
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedBody, _ = io.ReadAll(r.Body)
|
||||
@@ -216,14 +201,12 @@ var _ = Describe("PIIHandler", func() {
|
||||
|
||||
var _ = Describe("redactRequest", func() {
|
||||
It("handles anthropic shape", func() {
|
||||
patterns, _ := pii.Compile(pii.DefaultPatterns())
|
||||
r := pii.NewRedactor(patterns)
|
||||
body := []byte(`{"model":"claude","max_tokens":10,"messages":[{"role":"user","content":"reach me at bob@example.org"}]}`)
|
||||
|
||||
d := &piiDispatcher{redactor: r, patternAction: map[string]pii.Action{}}
|
||||
out, blocked, err := d.redactRequest(body, shapeAnthropicMessages, "corr-1")
|
||||
d := &piiDispatcher{}
|
||||
out, blocked, err := d.redactRequest(context.Background(), body, shapeAnthropicMessages, []pii.NERConfig{testDetectorCfg()}, "corr-1")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(blocked).To(BeFalse(), "email is mask, not block — blocked should be false")
|
||||
Expect(blocked).To(BeFalse(), "EMAIL is mask, not block — blocked should be false")
|
||||
var parsed map[string]any
|
||||
Expect(json.Unmarshal(out, &parsed)).To(Succeed())
|
||||
msgs := parsed["messages"].([]any)
|
||||
@@ -273,9 +256,6 @@ var _ = Describe("Proxy events", func() {
|
||||
})
|
||||
|
||||
It("tunneled host emits connect event only", func() {
|
||||
// A non-allowlisted CONNECT must record a proxy_connect with
|
||||
// Intercepted=false and NOT a proxy_traffic event (tunneled
|
||||
// bytes never reach the dispatcher).
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprint(w, "passthrough")
|
||||
})
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// Package cloudproxy stitches the cloud-proxy gRPC backend to the
|
||||
// HTTP edge: model rewrite, body shaping, and SSE-aware PII filtering
|
||||
// on the response. The outbound HTTP request itself lives inside the
|
||||
// cloud-proxy backend binary (backend/go/cloud-proxy), not here — this
|
||||
// package is the core-side glue.
|
||||
// HTTP edge: model rewrite and body shaping. The outbound HTTP request
|
||||
// itself lives inside the cloud-proxy backend binary
|
||||
// (backend/go/cloud-proxy), not here — this package is the core-side
|
||||
// glue. PII redaction runs request-side (the NER middleware + MITM
|
||||
// input path); response/output is forwarded unmodified.
|
||||
package cloudproxy
|
||||
|
||||
import (
|
||||
@@ -10,11 +11,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -61,65 +59,30 @@ func forwardBuffered(c echo.Context, statusCode int, contentType string, body io
|
||||
return err
|
||||
}
|
||||
|
||||
// forwardStream applies SSE-aware PII rewriting as the response flows
|
||||
// to the client. provider selects the dialect (openai vs anthropic);
|
||||
// it comes from cfg.Proxy.Provider on the cloud-proxy backend.
|
||||
func forwardStream(c echo.Context, body io.Reader, provider string, filter *pii.StreamFilter) error {
|
||||
// forwardStream relays the upstream SSE response to the client,
|
||||
// flushing per read so events arrive in real time. Response/output PII
|
||||
// redaction is out of scope for now, so the stream is forwarded
|
||||
// unmodified.
|
||||
func forwardStream(c echo.Context, body io.Reader) error {
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
|
||||
emit := func(line string) error {
|
||||
_, err := fmt.Fprint(c.Response().Writer, line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
flushResidual := func() {
|
||||
if filter == nil {
|
||||
return
|
||||
}
|
||||
residual := filter.Drain()
|
||||
if residual == "" {
|
||||
return
|
||||
}
|
||||
if line := ssewire.SynthResidualEvent(ssewire.Provider(provider), residual); line != "" {
|
||||
_ = emit(line)
|
||||
}
|
||||
}
|
||||
|
||||
prov := ssewire.Provider(provider)
|
||||
scanner := ssewire.NewScanner(body)
|
||||
for scanner.Scan() {
|
||||
ev := scanner.Event()
|
||||
if ssewire.IsTerminalMarker(ev.DataLine, prov) {
|
||||
flushResidual()
|
||||
_ = emit(ev.Raw)
|
||||
continue
|
||||
}
|
||||
out := ev.Raw
|
||||
if filter != nil && ev.DataLine != "" {
|
||||
rewritten, drop := ssewire.RewritePayload(ev.DataLine, prov, filter)
|
||||
if drop {
|
||||
continue
|
||||
}
|
||||
if rewritten != ev.DataLine {
|
||||
// strings.Replace with n=1 touches only the data line,
|
||||
// preserving any "event:"/"id:" preamble.
|
||||
out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1)
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, rErr := body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, wErr := c.Response().Writer.Write(buf[:n]); wErr != nil {
|
||||
return nil
|
||||
}
|
||||
c.Response().Flush()
|
||||
}
|
||||
if err := emit(out); err != nil {
|
||||
if rErr != nil {
|
||||
if rErr != io.EOF {
|
||||
xlog.Debug("cloudproxy: stream read error", "error", rErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
xlog.Debug("cloudproxy: stream read error", "error", err)
|
||||
}
|
||||
flushResidual()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
// Package ssewire holds the SSE-format helpers shared between
|
||||
// the request-shape cloud proxy (core/services/cloudproxy) and the
|
||||
// TLS-terminating MITM proxy (core/services/cloudproxy/mitm). Both
|
||||
// run a pii.StreamFilter over per-token text extracted from
|
||||
// provider-specific JSON chunks; this package owns the JSON shapes
|
||||
// so a future provider addition is one edit, not two.
|
||||
package ssewire
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// Provider names the upstream wire format that an SSE stream follows.
// The helpers in this package use it to pick which JSON shape to parse
// or to synthesise.
type Provider string

// Wire formats understood by this package.
const (
	// OpenAI selects the "chat.completion.chunk" style payloads.
	OpenAI Provider = "openai"
	// Anthropic selects the "content_block_delta" style payloads.
	Anthropic Provider = "anthropic"
)
|
||||
|
||||
// Event is a single SSE event.
//
// Raw holds every byte read for the event, including the blank line that
// terminates it, so an unmodified event can be forwarded exactly as it was
// received. DataLine is the payload of the first data: line that carries a
// non-empty payload, with the "data:" prefix, one optional leading space and
// the line ending removed. It is empty when the event has no such line.
type Event struct {
	Raw      string
	DataLine string
}

// Scanner reads SSE events one at a time from an upstream body.
//
// Usage follows the bufio.Scanner convention: call Scan until it returns
// false, fetch the current event with Event, then inspect Err.
type Scanner struct {
	r   *bufio.Reader
	ev  Event
	err error
}

// NewScanner returns a Scanner that reads from r through a 64 KiB buffer.
func NewScanner(r io.Reader) *Scanner {
	return &Scanner{r: bufio.NewReaderSize(r, 64*1024)}
}

// Scan advances to the next event and reports whether one is available.
//
// An event ends at the first blank line that follows at least one non-blank
// line. Blank lines that precede any event content are discarded rather than
// surfaced as empty events. If the reader fails in the middle of an event,
// the bytes gathered so far are still surfaced as a final, truncated event.
//
// A read error (including io.EOF) is terminal: it is recorded for Err, and
// every later call returns false without touching the reader again.
func (s *Scanner) Scan() bool {
	// Never read past a recorded error; doing so could resurrect a stream
	// that already failed and would overwrite the error the caller is
	// about to inspect.
	if s.err != nil {
		return false
	}

	var raw strings.Builder
	var dataLine string
	for {
		line, err := s.r.ReadString('\n')
		if line != "" {
			raw.WriteString(line)
			trimmed := strings.TrimRight(line, "\r\n")
			switch {
			case trimmed != "":
				// Only the first non-empty data: payload is kept.
				if dataLine == "" && strings.HasPrefix(trimmed, "data:") {
					payload := strings.TrimPrefix(trimmed, "data:")
					dataLine = strings.TrimPrefix(payload, " ")
				}
			case raw.Len() == len(line):
				// Blank line with nothing before it: drop it, but still
				// fall through to the error check below so an error that
				// arrived with this line is not swallowed.
				raw.Reset()
			default:
				// Blank line after content terminates the event.
				s.ev = Event{Raw: raw.String(), DataLine: dataLine}
				return true
			}
		}
		if err != nil {
			s.err = err
			if raw.Len() > 0 {
				// Surface the truncated event; the caller decides what to
				// do with an upstream that ended mid-event.
				s.ev = Event{Raw: raw.String(), DataLine: dataLine}
				return true
			}
			return false
		}
	}
}

// Event returns the event produced by the most recent successful Scan.
func (s *Scanner) Event() Event { return s.ev }

// Err returns the read error that stopped scanning. It is whatever the
// underlying reader reported, so a cleanly finished stream yields io.EOF.
func (s *Scanner) Err() error { return s.err }
|
||||
|
||||
// IsTerminalMarker reports whether the data line is the per-provider
|
||||
// end-of-stream sentinel. The streaming PII filter must drain its
|
||||
// residue before the caller forwards a terminal marker — clients
|
||||
// stop reading after it.
|
||||
func IsTerminalMarker(dataLine string, provider Provider) bool {
|
||||
if dataLine == "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(dataLine) == "[DONE]" {
|
||||
return true
|
||||
}
|
||||
if provider == Anthropic {
|
||||
var probe struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(dataLine), &probe); err == nil {
|
||||
return probe.Type == "message_stop"
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RewritePayload runs the data line's content-bearing field through
|
||||
// the streaming filter. drop=true tells the caller to suppress the
|
||||
// SSE event entirely (the filter buffered the whole token while
|
||||
// disambiguating a pattern boundary).
|
||||
func RewritePayload(dataLine string, provider Provider, filter *pii.StreamFilter) (rewritten string, drop bool) {
|
||||
if strings.TrimSpace(dataLine) == "[DONE]" {
|
||||
return dataLine, false
|
||||
}
|
||||
switch provider {
|
||||
case Anthropic:
|
||||
return rewriteAnthropic(dataLine, filter)
|
||||
default:
|
||||
return rewriteOpenAI(dataLine, filter)
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteOpenAI(dataLine string, filter *pii.StreamFilter) (string, bool) {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal([]byte(dataLine), &m); err != nil {
|
||||
return dataLine, false
|
||||
}
|
||||
choices, ok := m["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
return dataLine, false
|
||||
}
|
||||
first, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
return dataLine, false
|
||||
}
|
||||
delta, ok := first["delta"].(map[string]any)
|
||||
if !ok {
|
||||
return dataLine, false
|
||||
}
|
||||
content, ok := delta["content"].(string)
|
||||
if !ok || content == "" {
|
||||
return dataLine, false
|
||||
}
|
||||
rewritten := filter.Push(content)
|
||||
if rewritten == "" {
|
||||
return "", true
|
||||
}
|
||||
if rewritten == content {
|
||||
return dataLine, false
|
||||
}
|
||||
delta["content"] = rewritten
|
||||
out, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return dataLine, false
|
||||
}
|
||||
return string(out), false
|
||||
}
|
||||
|
||||
func rewriteAnthropic(dataLine string, filter *pii.StreamFilter) (string, bool) {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal([]byte(dataLine), &m); err != nil {
|
||||
return dataLine, false
|
||||
}
|
||||
if t, _ := m["type"].(string); t != "content_block_delta" {
|
||||
return dataLine, false
|
||||
}
|
||||
delta, ok := m["delta"].(map[string]any)
|
||||
if !ok {
|
||||
return dataLine, false
|
||||
}
|
||||
if dt, _ := delta["type"].(string); dt != "text_delta" {
|
||||
return dataLine, false
|
||||
}
|
||||
text, ok := delta["text"].(string)
|
||||
if !ok || text == "" {
|
||||
return dataLine, false
|
||||
}
|
||||
rewritten := filter.Push(text)
|
||||
if rewritten == "" {
|
||||
return "", true
|
||||
}
|
||||
if rewritten == text {
|
||||
return dataLine, false
|
||||
}
|
||||
delta["text"] = rewritten
|
||||
out, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return dataLine, false
|
||||
}
|
||||
return string(out), false
|
||||
}
|
||||
|
||||
// SynthResidualEvent builds a provider-shaped SSE event carrying
|
||||
// the streaming filter's drained tail so the response body remains
|
||||
// a valid event stream after the proxy splices in held-back text.
|
||||
func SynthResidualEvent(provider Provider, text string) string {
|
||||
switch provider {
|
||||
case Anthropic:
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return "event: content_block_delta\ndata: " + string(b) + "\n\n"
|
||||
default:
|
||||
payload := map[string]any{
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": []map[string]any{
|
||||
{"index": 0, "delta": map[string]string{"content": text}},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return "data: " + string(b) + "\n\n"
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package ssewire
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestSsewire(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "ssewire test suite")
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package ssewire
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Scanner contract: returns one Event per double-newline-terminated
|
||||
// SSE block, preserving the raw bytes (so unmodified events round-trip
|
||||
// exactly) and extracting the first data: payload as DataLine.
|
||||
|
||||
var _ = Describe("Scanner", func() {
|
||||
It("scans a basic event", func() {
|
||||
in := "event: foo\ndata: hello\n\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false on a well-formed event; err=%v", s.Err())
|
||||
ev := s.Event()
|
||||
Expect(ev.Raw).To(Equal(in))
|
||||
Expect(ev.DataLine).To(Equal("hello"))
|
||||
Expect(s.Scan()).To(BeFalse(), "Scan should return false after the only event")
|
||||
})
|
||||
|
||||
It("handles CRLF", func() {
|
||||
// Some upstreams emit CRLF instead of LF. The scanner trims
|
||||
// trailing \r off the data line so DataLine carries the same
|
||||
// bytes whichever line ending the producer chose.
|
||||
in := "event: foo\r\ndata: hello\r\n\r\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false on CRLF event; err=%v", s.Err())
|
||||
Expect(s.Event().DataLine).To(Equal("hello"))
|
||||
})
|
||||
|
||||
It("scans multiple events", func() {
|
||||
in := "data: one\n\ndata: two\n\ndata: three\n\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
got := []string{}
|
||||
for s.Scan() {
|
||||
got = append(got, s.Event().DataLine)
|
||||
}
|
||||
Expect(got).To(Equal([]string{"one", "two", "three"}))
|
||||
})
|
||||
|
||||
It("handles empty data payload", func() {
|
||||
// "data:" with no payload is valid SSE — DataLine should be empty
|
||||
// and Scan should still surface the event so callers can decide.
|
||||
in := "data:\n\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false on empty data payload; err=%v", s.Err())
|
||||
Expect(s.Event().DataLine).To(Equal(""))
|
||||
})
|
||||
|
||||
It("skips leading blank lines", func() {
|
||||
// A producer that prints a blank "keep-alive" before the first
|
||||
// real event must not produce a phantom event.
|
||||
in := "\n\n\ndata: real\n\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false; err=%v", s.Err())
|
||||
Expect(s.Event().DataLine).To(Equal("real"))
|
||||
})
|
||||
|
||||
It("handles mid-event EOF", func() {
|
||||
// EOF mid-event still surfaces the partial event with whatever
|
||||
// data was extracted — the StreamFilter+caller decides how to
|
||||
// handle a truncated upstream rather than silently dropping it.
|
||||
in := "data: half"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false on partial event")
|
||||
ev := s.Event()
|
||||
Expect(ev.DataLine).To(Equal("half"))
|
||||
Expect(s.Scan()).To(BeFalse(), "Scan should not surface a second event after EOF")
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("IsTerminalMarker", func() {
|
||||
cases := []struct {
|
||||
name string
|
||||
dataLine string
|
||||
provider Provider
|
||||
want bool
|
||||
}{
|
||||
{"openai DONE", "[DONE]", OpenAI, true},
|
||||
{"openai DONE with whitespace", " [DONE] ", OpenAI, true},
|
||||
{"anthropic DONE also recognised", "[DONE]", Anthropic, true},
|
||||
{"anthropic message_stop", `{"type":"message_stop"}`, Anthropic, true},
|
||||
{"anthropic content_block_delta is not terminal", `{"type":"content_block_delta"}`, Anthropic, false},
|
||||
{"openai chat.completion.chunk is not terminal", `{"object":"chat.completion.chunk"}`, OpenAI, false},
|
||||
{"openai message_stop is not terminal (wrong provider)", `{"type":"message_stop"}`, OpenAI, false},
|
||||
{"empty data", "", OpenAI, false},
|
||||
{"non-json garbage", "garbage", Anthropic, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
It(c.name, func() {
|
||||
Expect(IsTerminalMarker(c.dataLine, c.provider)).To(Equal(c.want))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
var _ = Describe("SynthResidualEvent", func() {
|
||||
It("anthropic", func() {
|
||||
got := SynthResidualEvent(Anthropic, "tail")
|
||||
Expect(strings.HasPrefix(got, "event: content_block_delta\ndata:")).To(BeTrue(), "Anthropic residual event missing event/data lines: %q", got)
|
||||
Expect(strings.HasSuffix(got, "\n\n")).To(BeTrue(), "Anthropic residual event missing trailing blank line: %q", got)
|
||||
Expect(got).To(ContainSubstring(`"text":"tail"`))
|
||||
})
|
||||
|
||||
It("openai", func() {
|
||||
got := SynthResidualEvent(OpenAI, "tail")
|
||||
Expect(strings.HasPrefix(got, "data: ")).To(BeTrue(), "OpenAI residual event missing data: prefix: %q", got)
|
||||
Expect(strings.HasSuffix(got, "\n\n")).To(BeTrue(), "OpenAI residual event missing trailing blank line: %q", got)
|
||||
Expect(got).To(ContainSubstring(`"content":"tail"`))
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user