feat(pii): NER tier engine — privacy-filter.cpp backend + NER-centric PII filter (#10360)

Squashed feat/pii-ner-tier-engine rebased onto master (was 45 commits; see
backup/pii-ner-tier-engine-prerebase). Net change:

- privacy-filter.cpp: standalone GGML engine for the openai-privacy-filter
  PII/NER token classifier, wired as a LocalAI gRPC backend (CPU/CUDA/Vulkan).
  TokenClassify moves off the patched llama.cpp path onto this backend.
- PII filter reworked to be NER-centric (encoder/NER detection tier scanning
  whole conversations as one document), with a recreated bounded restricted-
  regex secret-matching pattern detector tier alongside it (per-model
  pii_detection.builtins / .patterns + core/services/routing/piipattern).
- Detection labelled by source (ner vs pattern); backend trace / confidence /
  debug observability; analyze/redact exposed as a synchronous API.
- Instance-wide default detector policy + per-usecase default-on; request
  filtering extended to completions, embeddings, edits & Ollama.
- React UI: NER-centric PII editor, detector-models table, pattern/builtins
  editor, middleware default-policy UI.
- Gallery: privacy-filter-multilingual token-classify model + NER install
  filter; token_classify known_usecase; batch sized to context for NER models.
  privacy-filter backend registered in the backend gallery (cpu/vulkan/cuda-13
  meta + image entries with a capabilities map) matching its CI matrix jobs,
  and an /import-model auto-detect importer (PrivacyFilterImporter, narrow
  privacy-filter GGUF detection) replacing the prior pref-only registration.

Reconciled against master's independent evolution:

- Dropped master's PIIPatternOverrides feature (global-pattern runtime
  overrides + /api/pii/patterns API + runtime_settings.json persistence). The
  per-model NER + pattern-detector design supersedes it; it was built on the
  global redactor pattern set this branch replaced.
- Reverted the llama.cpp Score carry-patch (0006-server-task-type-score):
  removed the patch and restored master's grpc-server.cpp Score RPC (direct
  llama_decode, slot-loop bypass) and LLAMA_VERSION pin, plus master's
  model_config validation forbidding score + chat/completion/embeddings on
  llama-cpp. token_classify is unaffected (it runs on the privacy-filter
  backend, not llama-cpp).

Assisted-by: Claude:claude-opus-4-8 [Claude Code]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
Richard Palethorpe
2026-06-18 11:45:22 +01:00
committed by GitHub
parent c133ca39dc
commit 3fa7b2955c
134 changed files with 6671 additions and 4223 deletions

View File

@@ -10,8 +10,6 @@ import (
"github.com/labstack/echo/v4"
corebackend "github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/http/auth"
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/trace"
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
@@ -19,41 +17,14 @@ import (
"github.com/mudler/xlog"
)
// BuildStreamFilter constructs the per-request streaming PII filter
// for a cloud-proxy forward. Returns nil when the request isn't
// streaming, PII is disabled for this model, or no redactor is wired
// up — callers pass the result through unchanged. correlationID is
// caller-supplied because the OpenAI and Anthropic endpoints read it
// from different headers.
func BuildStreamFilter(c echo.Context, cfg *config.ModelConfig, isStream bool, piiRedactor *pii.Redactor, piiEvents pii.EventStore, correlationID string) *pii.StreamFilter {
if !isStream || piiRedactor == nil || !cfg.PIIIsEnabled() {
return nil
}
userID := ""
if u := auth.GetUser(c); u != nil {
userID = u.ID
}
var overrides map[string]pii.Action
if raw := cfg.PIIPatternOverrides(); len(raw) > 0 {
overrides = make(map[string]pii.Action, len(raw))
for ovid, action := range raw {
switch pii.Action(action) {
case pii.ActionMask, pii.ActionBlock, pii.ActionAllow:
overrides[ovid] = pii.Action(action)
}
}
}
return pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID)
}
// ForwardViaBackend loads the cloud-proxy gRPC backend, ships the
// request via the Forward RPC, and pumps the response back to the
// client through the SSE-aware PII pipeline.
// client. PII redaction runs request-side (the NER middleware + MITM
// input path); the response is forwarded unmodified.
func ForwardViaBackend(
c echo.Context,
cfg *config.ModelConfig,
body []byte,
filter *pii.StreamFilter,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
) (resultErr error) {
@@ -176,7 +147,7 @@ func ForwardViaBackend(
return passthroughError(c, statusCode, contentType, bodyReader)
}
if isStream {
return forwardStream(c, bodyReader, cfg.Proxy.Provider, filter)
return forwardStream(c, bodyReader)
}
return forwardBuffered(c, statusCode, contentType, bodyReader)
}

View File

@@ -1,72 +0,0 @@
package cloudproxy
import (
"net/http/httptest"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/routing/pii"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("BuildStreamFilter", func() {
var (
c echo.Context
cfg *config.ModelConfig
)
BeforeEach(func() {
e := echo.New()
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
rec := httptest.NewRecorder()
c = e.NewContext(req, rec)
piiOn := true
cfg = &config.ModelConfig{
Backend: "cloud-proxy",
PII: config.PIIConfig{Enabled: &piiOn},
}
})
// Three guards must each independently force a nil return — proves
// the gate is a logical AND, not an order-dependent short-circuit
// that silently activates one branch.
It("returns nil when isStream is false", func() {
patterns, err := pii.Compile(pii.DefaultPatterns())
Expect(err).NotTo(HaveOccurred())
r := pii.NewRedactor(patterns)
Expect(BuildStreamFilter(c, cfg, false, r, nil, "corr-1")).To(BeNil())
})
It("returns nil when piiRedactor is nil", func() {
Expect(BuildStreamFilter(c, cfg, true, nil, nil, "corr-1")).To(BeNil())
})
It("returns nil when the model has PII disabled", func() {
piiOff := false
cfg.PII.Enabled = &piiOff
patterns, err := pii.Compile(pii.DefaultPatterns())
Expect(err).NotTo(HaveOccurred())
r := pii.NewRedactor(patterns)
Expect(BuildStreamFilter(c, cfg, true, r, nil, "corr-1")).To(BeNil())
})
It("returns a configured filter when all preconditions hold", func() {
patterns, err := pii.Compile(pii.DefaultPatterns())
Expect(err).NotTo(HaveOccurred())
r := pii.NewRedactor(patterns)
store := pii.NewMemoryEventStore(8)
filter := BuildStreamFilter(c, cfg, true, r, store, "corr-xyz")
Expect(filter).NotTo(BeNil())
})
// Empty correlationID is allowed — some entry points don't have one.
// The filter must still construct so the stream can flow.
It("constructs a filter even when correlationID is empty", func() {
patterns, err := pii.Compile(pii.DefaultPatterns())
Expect(err).NotTo(HaveOccurred())
r := pii.NewRedactor(patterns)
Expect(BuildStreamFilter(c, cfg, true, r, nil, "")).NotTo(BeNil())
})
})

View File

@@ -16,7 +16,6 @@ import (
"golang.org/x/net/http2"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/services/routing/piiadapter"
"github.com/mudler/LocalAI/pkg/httpclient"
@@ -24,8 +23,14 @@ import (
// PIIHandlerOptions configures NewPIIHandler.
type PIIHandlerOptions struct {
// Redactor is the regex PII redactor. nil disables redaction.
Redactor *pii.Redactor
// DetectorsByHost maps an intercepted host (lower-cased) to the NER
// detector configs that should scan request bodies bound for it. The
// configs are resolved at listener-start from each host's owning
// model's pii.detectors + the detector models' pii_detection policy
// (a model-config edit needs a MITM restart, as hosts already do). A
// host absent from the map (or with an empty slice) is forwarded
// unredacted. Detector errors at request time fail closed.
DetectorsByHost map[string][]pii.NERConfig
// EventStore receives PIIEvent rows. nil discards events.
EventStore pii.EventStore
@@ -42,13 +47,6 @@ type PIIHandlerOptions struct {
// upstream URL. Identity by default; tests inject a httptest
// listener address.
DialHost func(host string) string
// HostsWithPIIDisabled lists destination hosts whose request
// bodies should NOT run through the redactor. TLS termination,
// upstream forwarding, and audit events still happen — only the
// regex pass is bypassed. Useful for telemetry/probe endpoints
// whose bodies aren't PII-shaped.
HostsWithPIIDisabled []string
}
func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
@@ -76,16 +74,9 @@ func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
dialHost = func(h string) string { return h }
}
patternAction := map[string]pii.Action{}
if opts.Redactor != nil {
for _, p := range opts.Redactor.Patterns() {
patternAction[p.ID] = p.Action
}
}
piiDisabled := make(map[string]bool, len(opts.HostsWithPIIDisabled))
for _, h := range opts.HostsWithPIIDisabled {
piiDisabled[strings.ToLower(strings.TrimSpace(h))] = true
detectorsByHost := make(map[string][]pii.NERConfig, len(opts.DetectorsByHost))
for h, cfgs := range opts.DetectorsByHost {
detectorsByHost[strings.ToLower(strings.TrimSpace(h))] = cfgs
}
d := &piiDispatcher{
@@ -96,26 +87,22 @@ func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
// API keys such as Anthropic's x-api-key, which Go does NOT
// strip on cross-host redirects — to an unvetted host. Surface
// it as an error (handled as a 502) instead.
client: httpclient.New(httpclient.WithTransport(transport)),
redactor: opts.Redactor,
store: opts.EventStore,
patternAction: patternAction,
corrHeader: corrHeader,
dialHost: dialHost,
piiDisabled: piiDisabled,
client: httpclient.New(httpclient.WithTransport(transport)),
detectorsByHost: detectorsByHost,
store: opts.EventStore,
corrHeader: corrHeader,
dialHost: dialHost,
}
return d.serve
}
type piiDispatcher struct {
client *http.Client
redactor *pii.Redactor
store pii.EventStore
patternAction map[string]pii.Action
corrHeader string
dialHost func(host string) string
piiDisabled map[string]bool
eventSeq atomic.Uint64
client *http.Client
detectorsByHost map[string][]pii.NERConfig
store pii.EventStore
corrHeader string
dialHost func(host string) string
eventSeq atomic.Uint64
}
func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host string) {
@@ -144,11 +131,17 @@ func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host strin
}
shape := classifyRequestShape(host, r.URL.Path)
if d.redactor != nil && shape != shapeUnknown && !d.piiDisabled[strings.ToLower(host)] {
redacted, blocked, err := d.redactRequest(body, shape, correlationID)
cfgs := d.detectorsByHost[strings.ToLower(host)]
if len(cfgs) > 0 && shape != shapeUnknown {
redacted, blocked, err := d.redactRequest(r.Context(), body, shape, cfgs, correlationID)
switch {
case err != nil:
xlog.Debug("mitm: redact request failed; forwarding unchanged", "host", host, "path", r.URL.Path, "error", err)
// Fail closed: a detector outage must not silently forward the
// request unredacted — the operator configured this host's
// model with detectors precisely to catch this PII.
xlog.Error("mitm: NER redaction failed; blocking request (fail-closed)", "host", host, "path", r.URL.Path, "error", err)
writePIIBlocked(w, correlationID)
return
case blocked:
writePIIBlocked(w, correlationID)
return
@@ -185,12 +178,10 @@ func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host strin
}
w.WriteHeader(resp.StatusCode)
// Response/output redaction is out of scope for now — the MITM proxy
// only scans request bodies (input). SSE responses pass through
// unmodified.
contentType := resp.Header.Get("Content-Type")
if shape != shapeUnknown && d.redactor != nil && isSSE(contentType) {
d.streamWithPII(w, resp.Body, shape, correlationID)
return
}
if isSSE(contentType) {
flusher, _ := w.(http.Flusher)
buf := make([]byte, 32*1024)
@@ -232,7 +223,7 @@ func classifyRequestShape(host, path string) requestShape {
return shapeUnknown
}
func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlationID string) ([]byte, bool, error) {
func (d *piiDispatcher) redactRequest(ctx context.Context, body []byte, shape requestShape, cfgs []pii.NERConfig, correlationID string) ([]byte, bool, error) {
var parsed any
var adapter pii.Adapter
switch shape {
@@ -259,13 +250,21 @@ func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlati
return body, false, nil
}
// One scan over the joined messages so the NER tier keeps
// conversational context (see pii.RedactNERSegments); results map
// back per message with local offsets.
segTexts := make([]string, len(texts))
for i, st := range texts {
segTexts[i] = st.Text
}
results, err := pii.RedactNERSegments(ctx, segTexts, cfgs)
if err != nil {
return nil, false, fmt.Errorf("ner detect: %w", err)
}
updates := make([]pii.ScannedText, 0, len(texts))
blocked := false
for _, st := range texts {
if st.Text == "" {
continue
}
res := d.redactor.RedactWithOverrides(st.Text, nil)
for i, res := range results {
if len(res.Spans) == 0 {
continue
}
@@ -273,7 +272,7 @@ func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlati
if res.Blocked {
blocked = true
}
updates = append(updates, pii.ScannedText{Index: st.Index, Text: res.Redacted})
updates = append(updates, pii.ScannedText{Index: texts[i].Index, Text: res.Redacted})
}
if len(updates) > 0 {
@@ -295,13 +294,14 @@ func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) {
ev := pii.PIIEvent{
ID: fmt.Sprintf("mitm_%s_%d", correlationID, d.eventSeq.Add(1)),
Kind: pii.KindPII,
Origin: pii.OriginProxy,
CorrelationID: correlationID,
Direction: pii.DirectionIn,
PatternID: span.Pattern,
ByteOffset: span.Start,
Length: span.End - span.Start,
HashPrefix: span.HashPrefix,
Action: d.patternAction[span.Pattern],
Action: span.Action,
CreatedAt: time.Now(),
}
if err := d.store.Record(context.Background(), ev); err != nil {
@@ -310,49 +310,6 @@ func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) {
}
}
func (d *piiDispatcher) streamWithPII(w http.ResponseWriter, src io.Reader, shape requestShape, correlationID string) {
flusher, _ := w.(http.Flusher)
filter := pii.NewStreamFilter(d.redactor, nil, d.store, correlationID, "")
provider := ssewire.OpenAI
if shape == shapeAnthropicMessages {
provider = ssewire.Anthropic
}
emit := func(s string) {
_, _ = w.Write([]byte(s))
if flusher != nil {
flusher.Flush()
}
}
scanner := ssewire.NewScanner(src)
for scanner.Scan() {
ev := scanner.Event()
if ssewire.IsTerminalMarker(ev.DataLine, provider) {
if residual := filter.Drain(); residual != "" {
emit(ssewire.SynthResidualEvent(provider, residual))
}
emit(ev.Raw)
continue
}
out := ev.Raw
if ev.DataLine != "" {
rewritten, drop := ssewire.RewritePayload(ev.DataLine, provider, filter)
if drop {
continue
}
if rewritten != ev.DataLine {
out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1)
}
}
emit(out)
}
if residual := filter.Drain(); residual != "" {
emit(ssewire.SynthResidualEvent(provider, residual))
}
}
func writePIIBlocked(w http.ResponseWriter, correlationID string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)

View File

@@ -19,34 +19,58 @@ import (
. "github.com/onsi/gomega"
)
// startPIITestRig is the same shape as startMITMTestRig but plugs
// in the production PII handler instead of the passthrough fixture.
// The "host" the client thinks it's reaching is forced to
// api.anthropic.com so the request shape classifier matches.
// substringDetector is a deterministic pii.NERDetector for tests: it
// reports an entity for every occurrence of each configured substring,
// with byte offsets into the scanned text. Lets the MITM tests drive
// request redaction without a real token-classification backend.
type substringDetector struct{ groups map[string]string } // substring -> entity group
func (d substringDetector) Detect(_ context.Context, text string) ([]pii.NEREntity, error) {
var out []pii.NEREntity
for sub, group := range d.groups {
for idx := 0; ; {
i := strings.Index(text[idx:], sub)
if i < 0 {
break
}
start := idx + i
out = append(out, pii.NEREntity{Group: group, Start: start, End: start + len(sub), Score: 1})
idx = start + len(sub)
}
}
return out, nil
}
// testDetectorCfg flags emails (mask) and a known secret token (block).
func testDetectorCfg() pii.NERConfig {
return pii.NERConfig{
Detector: substringDetector{groups: map[string]string{
"alice@example.com": "EMAIL",
"bob@example.org": "EMAIL",
"sk-abcdefghijklmnopqrstuvwxyz1234": "PASSWORD",
}},
EntityActions: map[string]pii.Action{"EMAIL": pii.ActionMask, "PASSWORD": pii.ActionBlock},
}
}
// startPIITestRig plugs the production PII handler into a CONNECT proxy,
// with the upstream playing the role of api.anthropic.com. Request
// bodies bound for api.anthropic.com run through the NER detector above.
func startPIITestRig(upstream http.Handler) (*http.Client, string, *fakeStore, func()) {
// Upstream fake — plays the role of api.anthropic.com.
ts := httptest.NewTLSServer(upstream)
upstreamCertPool := x509.NewCertPool()
upstreamCertPool.AddCert(ts.Certificate())
upstreamURL, _ := url.Parse(ts.URL)
// Compiled patterns required for the redactor to actually fire
// (DefaultPatterns alone returns Pattern structs without regex).
patterns, err := pii.Compile(pii.DefaultPatterns())
ExpectWithOffset(1, err).NotTo(HaveOccurred())
redactor := pii.NewRedactor(patterns)
store := &fakeStore{}
ca, err := NewInMemoryCA()
ExpectWithOffset(1, err).NotTo(HaveOccurred())
// DialHost remaps the upstream dial target to the httptest
// fake while leaving the classifier-facing host
// ("api.anthropic.com") untouched. ServerName=example.com is
// what httptest.NewTLSServer issues its cert for.
upstreamHost := upstreamURL.Host
prodHandler := NewPIIHandler(PIIHandlerOptions{
Redactor: redactor,
DetectorsByHost: map[string][]pii.NERConfig{
"api.anthropic.com": {testDetectorCfg()},
},
EventStore: store,
UpstreamTLS: &tls.Config{
RootCAs: upstreamCertPool,
@@ -79,8 +103,6 @@ func startPIITestRig(upstream http.Handler) (*http.Client, string, *fakeStore, f
srv.Stop()
ts.Close()
}
// We point requests at api.anthropic.com so classifyRequestShape
// matches; the wrappedHandler retargets to the upstream fake.
return client, "https://api.anthropic.com", store, cleanup
}
@@ -101,7 +123,7 @@ func (s *fakeStore) Close() error { return nil }
func (s *fakeStore) recorded() int { return len(s.events) }
var _ = Describe("PIIHandler", func() {
It("redacts request email", func() {
It("redacts request email via NER", func() {
var receivedBody []byte
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedBody, _ = io.ReadAll(r.Body)
@@ -119,15 +141,11 @@ var _ = Describe("PIIHandler", func() {
Expect(resp.StatusCode).To(Equal(200))
Expect(string(receivedBody)).NotTo(ContainSubstring("alice@example.com"), "upstream received unredacted body")
Expect(string(receivedBody)).To(ContainSubstring("[REDACTED:email]"), "upstream did not see redaction marker")
Expect(string(receivedBody)).To(ContainSubstring("[REDACTED:ner:EMAIL]"), "upstream did not see redaction marker")
Expect(store.recorded()).NotTo(BeZero(), "no PIIEvent recorded for the email match")
})
It("refuses to follow an upstream redirect", func() {
// A 3xx from the upstream would otherwise be followed, replaying
// the request (and its provider API key, e.g. Anthropic's
// x-api-key which Go does NOT strip on cross-host redirects) to
// the Location host. The refused redirect surfaces as a 502.
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "https://evil.example.com/steal", http.StatusFound)
})
@@ -142,7 +160,7 @@ var _ = Describe("PIIHandler", func() {
Expect(resp.StatusCode).To(Equal(http.StatusBadGateway), "refused redirect must surface as 502, not be followed")
})
It("blocks api key in request", func() {
It("blocks a detected secret in the request", func() {
upstreamCalled := false
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
upstreamCalled = true
@@ -156,46 +174,13 @@ var _ = Describe("PIIHandler", func() {
resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body))
Expect(err).NotTo(HaveOccurred(), "client.Post")
defer func() { _ = resp.Body.Close() }()
Expect(resp.StatusCode).To(Equal(400), "api_key_prefix has Block default")
Expect(resp.StatusCode).To(Equal(400), "PASSWORD entity action is block")
Expect(upstreamCalled).To(BeFalse(), "upstream was called despite block — proxy should short-circuit")
body2, _ := io.ReadAll(resp.Body)
Expect(string(body2)).To(ContainSubstring("pii_blocked"))
})
It("streaming redaction", func() {
// Anthropic-shape SSE; "alice@" + "example.com" splits the
// email across chunks so the StreamFilter has to buffer.
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(200)
flusher := w.(http.Flusher)
chunks := []string{
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"contact me at alice@"}}`,
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"example.com any time"}}`,
`{"type":"message_stop"}`,
}
for _, c := range chunks {
_, _ = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", "content_block_delta", c)
flusher.Flush()
}
})
client, base, _, cleanup := startPIITestRig(upstream)
defer cleanup()
body := `{"model":"claude-3-5-sonnet","max_tokens":100,"stream":true,"messages":[{"role":"user","content":"hi"}]}`
resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body))
Expect(err).NotTo(HaveOccurred(), "Post")
defer func() { _ = resp.Body.Close() }()
out, _ := io.ReadAll(resp.Body)
outStr := string(out)
Expect(outStr).NotTo(ContainSubstring("alice@example.com"), "email leaked through MITM stream")
Expect(outStr).To(ContainSubstring("[REDACTED:email]"), "redaction marker missing from MITM stream")
})
It("non-chat path passes through", func() {
// A path the classifier doesn't recognise (e.g. an OAuth
// callback) must forward the body verbatim, no PII parsing.
var receivedBody []byte
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedBody, _ = io.ReadAll(r.Body)
@@ -216,14 +201,12 @@ var _ = Describe("PIIHandler", func() {
var _ = Describe("redactRequest", func() {
It("handles anthropic shape", func() {
patterns, _ := pii.Compile(pii.DefaultPatterns())
r := pii.NewRedactor(patterns)
body := []byte(`{"model":"claude","max_tokens":10,"messages":[{"role":"user","content":"reach me at bob@example.org"}]}`)
d := &piiDispatcher{redactor: r, patternAction: map[string]pii.Action{}}
out, blocked, err := d.redactRequest(body, shapeAnthropicMessages, "corr-1")
d := &piiDispatcher{}
out, blocked, err := d.redactRequest(context.Background(), body, shapeAnthropicMessages, []pii.NERConfig{testDetectorCfg()}, "corr-1")
Expect(err).NotTo(HaveOccurred())
Expect(blocked).To(BeFalse(), "email is mask, not block — blocked should be false")
Expect(blocked).To(BeFalse(), "EMAIL is mask, not block — blocked should be false")
var parsed map[string]any
Expect(json.Unmarshal(out, &parsed)).To(Succeed())
msgs := parsed["messages"].([]any)
@@ -273,9 +256,6 @@ var _ = Describe("Proxy events", func() {
})
It("tunneled host emits connect event only", func() {
// A non-allowlisted CONNECT must record a proxy_connect with
// Intercepted=false and NOT a proxy_traffic event (tunneled
// bytes never reach the dispatcher).
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprint(w, "passthrough")
})

View File

@@ -1,8 +1,9 @@
// Package cloudproxy stitches the cloud-proxy gRPC backend to the
// HTTP edge: model rewrite, body shaping, and SSE-aware PII filtering
// on the response. The outbound HTTP request itself lives inside the
// cloud-proxy backend binary (backend/go/cloud-proxy), not here — this
// package is the core-side glue.
// HTTP edge: model rewrite and body shaping. The outbound HTTP request
// itself lives inside the cloud-proxy backend binary
// (backend/go/cloud-proxy), not here — this package is the core-side
// glue. PII redaction runs request-side (the NER middleware + MITM
// input path); response/output is forwarded unmodified.
package cloudproxy
import (
@@ -10,11 +11,8 @@ import (
"fmt"
"io"
"net/http"
"strings"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/xlog"
)
@@ -61,65 +59,30 @@ func forwardBuffered(c echo.Context, statusCode int, contentType string, body io
return err
}
// forwardStream applies SSE-aware PII rewriting as the response flows
// to the client. provider selects the dialect (openai vs anthropic);
// it comes from cfg.Proxy.Provider on the cloud-proxy backend.
func forwardStream(c echo.Context, body io.Reader, provider string, filter *pii.StreamFilter) error {
// forwardStream relays the upstream SSE response to the client,
// flushing per read so events arrive in real time. Response/output PII
// redaction is out of scope for now, so the stream is forwarded
// unmodified.
func forwardStream(c echo.Context, body io.Reader) error {
c.Response().Header().Set("Content-Type", "text/event-stream")
c.Response().Header().Set("Cache-Control", "no-cache")
c.Response().Header().Set("Connection", "keep-alive")
c.Response().WriteHeader(http.StatusOK)
emit := func(line string) error {
_, err := fmt.Fprint(c.Response().Writer, line)
if err != nil {
return err
}
c.Response().Flush()
return nil
}
flushResidual := func() {
if filter == nil {
return
}
residual := filter.Drain()
if residual == "" {
return
}
if line := ssewire.SynthResidualEvent(ssewire.Provider(provider), residual); line != "" {
_ = emit(line)
}
}
prov := ssewire.Provider(provider)
scanner := ssewire.NewScanner(body)
for scanner.Scan() {
ev := scanner.Event()
if ssewire.IsTerminalMarker(ev.DataLine, prov) {
flushResidual()
_ = emit(ev.Raw)
continue
}
out := ev.Raw
if filter != nil && ev.DataLine != "" {
rewritten, drop := ssewire.RewritePayload(ev.DataLine, prov, filter)
if drop {
continue
}
if rewritten != ev.DataLine {
// strings.Replace with n=1 touches only the data line,
// preserving any "event:"/"id:" preamble.
out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1)
buf := make([]byte, 32*1024)
for {
n, rErr := body.Read(buf)
if n > 0 {
if _, wErr := c.Response().Writer.Write(buf[:n]); wErr != nil {
return nil
}
c.Response().Flush()
}
if err := emit(out); err != nil {
if rErr != nil {
if rErr != io.EOF {
xlog.Debug("cloudproxy: stream read error", "error", rErr)
}
return nil
}
}
if err := scanner.Err(); err != nil && err != io.EOF {
xlog.Debug("cloudproxy: stream read error", "error", err)
}
flushResidual()
return nil
}

View File

@@ -1,218 +0,0 @@
// Package ssewire holds the SSE-format helpers shared between
// the request-shape cloud proxy (core/services/cloudproxy) and the
// TLS-terminating MITM proxy (core/services/cloudproxy/mitm). Both
// run a pii.StreamFilter over per-token text extracted from
// provider-specific JSON chunks; this package owns the JSON shapes
// so a future provider addition is one edit, not two.
package ssewire
import (
"bufio"
"encoding/json"
"io"
"strings"
"github.com/mudler/LocalAI/core/services/routing/pii"
)
// Provider is the upstream wire format an SSE stream conforms to.
type Provider string
const (
OpenAI Provider = "openai"
Anthropic Provider = "anthropic"
)
// Event is one SSE event with its exact wire bytes preserved in
// Raw (so unmodified events round-trip byte-for-byte) and the
// extracted JSON payload from the data: line in DataLine.
type Event struct {
Raw string
DataLine string
}
// Scanner reads SSE events one at a time from an upstream body.
type Scanner struct {
r *bufio.Reader
ev Event
err error
}
func NewScanner(r io.Reader) *Scanner {
return &Scanner{r: bufio.NewReaderSize(r, 64*1024)}
}
func (s *Scanner) Scan() bool {
var raw strings.Builder
var dataLine string
for {
line, err := s.r.ReadString('\n')
if line != "" {
raw.WriteString(line)
trimmed := strings.TrimRight(line, "\r\n")
if trimmed == "" {
if raw.Len() == len(line) {
raw.Reset()
continue
}
s.ev = Event{Raw: raw.String(), DataLine: dataLine}
return true
}
if strings.HasPrefix(trimmed, "data:") && dataLine == "" {
payload := strings.TrimPrefix(trimmed, "data:")
payload = strings.TrimPrefix(payload, " ")
dataLine = payload
}
}
if err != nil {
s.err = err
if raw.Len() > 0 {
s.ev = Event{Raw: raw.String(), DataLine: dataLine}
return true
}
return false
}
}
}
func (s *Scanner) Event() Event { return s.ev }
func (s *Scanner) Err() error { return s.err }
// IsTerminalMarker reports whether the data line is the per-provider
// end-of-stream sentinel. The streaming PII filter must drain its
// residue before the caller forwards a terminal marker — clients
// stop reading after it.
func IsTerminalMarker(dataLine string, provider Provider) bool {
if dataLine == "" {
return false
}
if strings.TrimSpace(dataLine) == "[DONE]" {
return true
}
if provider == Anthropic {
var probe struct {
Type string `json:"type"`
}
if err := json.Unmarshal([]byte(dataLine), &probe); err == nil {
return probe.Type == "message_stop"
}
}
return false
}
// RewritePayload runs the data line's content-bearing field through
// the streaming filter. drop=true tells the caller to suppress the
// SSE event entirely (the filter buffered the whole token while
// disambiguating a pattern boundary).
func RewritePayload(dataLine string, provider Provider, filter *pii.StreamFilter) (rewritten string, drop bool) {
if strings.TrimSpace(dataLine) == "[DONE]" {
return dataLine, false
}
switch provider {
case Anthropic:
return rewriteAnthropic(dataLine, filter)
default:
return rewriteOpenAI(dataLine, filter)
}
}
func rewriteOpenAI(dataLine string, filter *pii.StreamFilter) (string, bool) {
var m map[string]any
if err := json.Unmarshal([]byte(dataLine), &m); err != nil {
return dataLine, false
}
choices, ok := m["choices"].([]any)
if !ok || len(choices) == 0 {
return dataLine, false
}
first, ok := choices[0].(map[string]any)
if !ok {
return dataLine, false
}
delta, ok := first["delta"].(map[string]any)
if !ok {
return dataLine, false
}
content, ok := delta["content"].(string)
if !ok || content == "" {
return dataLine, false
}
rewritten := filter.Push(content)
if rewritten == "" {
return "", true
}
if rewritten == content {
return dataLine, false
}
delta["content"] = rewritten
out, err := json.Marshal(m)
if err != nil {
return dataLine, false
}
return string(out), false
}
func rewriteAnthropic(dataLine string, filter *pii.StreamFilter) (string, bool) {
var m map[string]any
if err := json.Unmarshal([]byte(dataLine), &m); err != nil {
return dataLine, false
}
if t, _ := m["type"].(string); t != "content_block_delta" {
return dataLine, false
}
delta, ok := m["delta"].(map[string]any)
if !ok {
return dataLine, false
}
if dt, _ := delta["type"].(string); dt != "text_delta" {
return dataLine, false
}
text, ok := delta["text"].(string)
if !ok || text == "" {
return dataLine, false
}
rewritten := filter.Push(text)
if rewritten == "" {
return "", true
}
if rewritten == text {
return dataLine, false
}
delta["text"] = rewritten
out, err := json.Marshal(m)
if err != nil {
return dataLine, false
}
return string(out), false
}
// SynthResidualEvent builds a provider-shaped SSE event carrying
// the streaming filter's drained tail so the response body remains
// a valid event stream after the proxy splices in held-back text.
func SynthResidualEvent(provider Provider, text string) string {
switch provider {
case Anthropic:
payload := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]string{"type": "text_delta", "text": text},
}
b, err := json.Marshal(payload)
if err != nil {
return ""
}
return "event: content_block_delta\ndata: " + string(b) + "\n\n"
default:
payload := map[string]any{
"object": "chat.completion.chunk",
"choices": []map[string]any{
{"index": 0, "delta": map[string]string{"content": text}},
},
}
b, err := json.Marshal(payload)
if err != nil {
return ""
}
return "data: " + string(b) + "\n\n"
}
}

View File

@@ -1,13 +0,0 @@
package ssewire
import (
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestSsewire(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "ssewire test suite")
}

View File

@@ -1,114 +0,0 @@
package ssewire
import (
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Scanner contract: returns one Event per double-newline-terminated
// SSE block, preserving the raw bytes (so unmodified events round-trip
// exactly) and extracting the first data: payload as DataLine.
var _ = Describe("Scanner", func() {
It("scans a basic event", func() {
in := "event: foo\ndata: hello\n\n"
s := NewScanner(strings.NewReader(in))
Expect(s.Scan()).To(BeTrue(), "Scan returned false on a well-formed event; err=%v", s.Err())
ev := s.Event()
Expect(ev.Raw).To(Equal(in))
Expect(ev.DataLine).To(Equal("hello"))
Expect(s.Scan()).To(BeFalse(), "Scan should return false after the only event")
})
It("handles CRLF", func() {
// Some upstreams emit CRLF instead of LF. The scanner trims
// trailing \r off the data line so DataLine carries the same
// bytes whichever line ending the producer chose.
in := "event: foo\r\ndata: hello\r\n\r\n"
s := NewScanner(strings.NewReader(in))
Expect(s.Scan()).To(BeTrue(), "Scan returned false on CRLF event; err=%v", s.Err())
Expect(s.Event().DataLine).To(Equal("hello"))
})
It("scans multiple events", func() {
in := "data: one\n\ndata: two\n\ndata: three\n\n"
s := NewScanner(strings.NewReader(in))
got := []string{}
for s.Scan() {
got = append(got, s.Event().DataLine)
}
Expect(got).To(Equal([]string{"one", "two", "three"}))
})
It("handles empty data payload", func() {
// "data:" with no payload is valid SSE — DataLine should be empty
// and Scan should still surface the event so callers can decide.
in := "data:\n\n"
s := NewScanner(strings.NewReader(in))
Expect(s.Scan()).To(BeTrue(), "Scan returned false on empty data payload; err=%v", s.Err())
Expect(s.Event().DataLine).To(Equal(""))
})
It("skips leading blank lines", func() {
// A producer that prints a blank "keep-alive" before the first
// real event must not produce a phantom event.
in := "\n\n\ndata: real\n\n"
s := NewScanner(strings.NewReader(in))
Expect(s.Scan()).To(BeTrue(), "Scan returned false; err=%v", s.Err())
Expect(s.Event().DataLine).To(Equal("real"))
})
It("handles mid-event EOF", func() {
// EOF mid-event still surfaces the partial event with whatever
// data was extracted — the StreamFilter+caller decides how to
// handle a truncated upstream rather than silently dropping it.
in := "data: half"
s := NewScanner(strings.NewReader(in))
Expect(s.Scan()).To(BeTrue(), "Scan returned false on partial event")
ev := s.Event()
Expect(ev.DataLine).To(Equal("half"))
Expect(s.Scan()).To(BeFalse(), "Scan should not surface a second event after EOF")
})
})
var _ = Describe("IsTerminalMarker", func() {
cases := []struct {
name string
dataLine string
provider Provider
want bool
}{
{"openai DONE", "[DONE]", OpenAI, true},
{"openai DONE with whitespace", " [DONE] ", OpenAI, true},
{"anthropic DONE also recognised", "[DONE]", Anthropic, true},
{"anthropic message_stop", `{"type":"message_stop"}`, Anthropic, true},
{"anthropic content_block_delta is not terminal", `{"type":"content_block_delta"}`, Anthropic, false},
{"openai chat.completion.chunk is not terminal", `{"object":"chat.completion.chunk"}`, OpenAI, false},
{"openai message_stop is not terminal (wrong provider)", `{"type":"message_stop"}`, OpenAI, false},
{"empty data", "", OpenAI, false},
{"non-json garbage", "garbage", Anthropic, false},
}
for _, c := range cases {
It(c.name, func() {
Expect(IsTerminalMarker(c.dataLine, c.provider)).To(Equal(c.want))
})
}
})
var _ = Describe("SynthResidualEvent", func() {
It("anthropic", func() {
got := SynthResidualEvent(Anthropic, "tail")
Expect(strings.HasPrefix(got, "event: content_block_delta\ndata:")).To(BeTrue(), "Anthropic residual event missing event/data lines: %q", got)
Expect(strings.HasSuffix(got, "\n\n")).To(BeTrue(), "Anthropic residual event missing trailing blank line: %q", got)
Expect(got).To(ContainSubstring(`"text":"tail"`))
})
It("openai", func() {
got := SynthResidualEvent(OpenAI, "tail")
Expect(strings.HasPrefix(got, "data: ")).To(BeTrue(), "OpenAI residual event missing data: prefix: %q", got)
Expect(strings.HasSuffix(got, "\n\n")).To(BeTrue(), "OpenAI residual event missing trailing blank line: %q", got)
Expect(got).To(ContainSubstring(`"content":"tail"`))
})
})

View File

@@ -6,12 +6,13 @@ import (
"fmt"
"os"
"path/filepath"
"reflect"
"strings"
"dario.cat/mergo"
"gopkg.in/yaml.v3"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/config/meta"
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/model"
"github.com/mudler/LocalAI/pkg/utils"
@@ -114,9 +115,7 @@ func (s *ConfigService) PatchConfig(_ context.Context, name string, patch map[st
if existingMap == nil {
existingMap = map[string]any{}
}
if err := mergo.Merge(&existingMap, patch, mergo.WithOverride); err != nil {
return nil, fmt.Errorf("merge configs: %w", err)
}
patchMerge(existingMap, patch, mapLeafFieldPaths(), "")
yamlData, err := yaml.Marshal(existingMap)
if err != nil {
return nil, fmt.Errorf("marshal merged YAML: %w", err)
@@ -142,6 +141,55 @@ func (s *ConfigService) PatchConfig(_ context.Context, name string, patch map[st
return &updated, nil
}
// mapLeafFieldPaths returns the set of dotted config paths whose schema type is
// a map that the editor edits as one complete value (e.g.
// pii_detection.entity_actions, roles, engine_args). A PATCH must REPLACE these
// wholesale rather than union them: the deep-merge only adds and overrides
// keys, so a map entry the admin deleted in the editor would otherwise silently
// survive. Derived from the config schema so it stays correct as map fields are
// added. (UIType comes from reflection, independent of any registry override.)
func mapLeafFieldPaths() map[string]struct{} {
md := meta.BuildConfigMetadata(reflect.TypeFor[config.ModelConfig]())
out := make(map[string]struct{})
for _, f := range md.Fields {
if f.UIType == "map" {
out[f.Path] = struct{}{}
}
}
return out
}
// patchMerge deep-merges src into dst with the same shape as the previous
// mergo.WithOverride behaviour — scalars and slices replace; nested
// struct-maps (e.g. pii_detection, parameters) recurse so unknown sibling keys
// the editor doesn't model survive — EXCEPT that any path in mapLeaves is
// replaced wholesale, and removed when the patch sets it empty, so deletions
// inside a map field persist to disk.
func patchMerge(dst, src map[string]any, mapLeaves map[string]struct{}, prefix string) {
for k, sv := range src {
path := k
if prefix != "" {
path = prefix + "." + k
}
if _, isLeaf := mapLeaves[path]; isLeaf {
if m, ok := sv.(map[string]any); ok && len(m) == 0 {
delete(dst, k) // emptied map field -> drop it from the YAML
} else {
dst[k] = sv
}
continue
}
// Recurse into struct-like nesting so dst-only sibling keys survive.
if sm, ok := sv.(map[string]any); ok {
if dm, ok2 := dst[k].(map[string]any); ok2 {
patchMerge(dm, sm, mapLeaves, path)
continue
}
}
dst[k] = sv
}
}
// EditYAML replaces the YAML for an installed model, with optional rename
// support. ml may be nil; when set, EditYAML calls ml.ShutdownModel(oldName)
// after a successful write so the next inference picks up the new config.

View File

@@ -107,6 +107,64 @@ var _ = Describe("ConfigService", func() {
_, err := svc.PatchConfig(ctx, "qwen", map[string]any{})
Expect(err).To(MatchError(ErrEmptyBody))
})
It("replaces a map field wholesale so deleted entries do not survive", func() {
// A detector model with a populated entity_actions map. The editor
// removes SSN and re-sends the remaining map; a naive deep-merge
// would re-add SSN (it only adds/overrides keys, never deletes).
writeModelYAML(svc, dir, "ner", map[string]any{
"backend": "llama-cpp",
"known_usecases": []any{"token_classify"},
"pii_detection": map[string]any{
"default_action": "mask",
"entity_actions": map[string]any{"SSN": "block", "EMAIL": "mask"},
},
})
_, err := svc.PatchConfig(ctx, "ner", map[string]any{
"pii_detection": map[string]any{
"default_action": "mask",
"entity_actions": map[string]any{"EMAIL": "mask"},
},
})
Expect(err).ToNot(HaveOccurred())
raw, err := os.ReadFile(filepath.Join(dir, "ner.yaml"))
Expect(err).ToNot(HaveOccurred())
var got map[string]any
Expect(yaml.Unmarshal(raw, &got)).To(Succeed())
pii := got["pii_detection"].(map[string]any)
ea := pii["entity_actions"].(map[string]any)
Expect(ea).To(HaveKeyWithValue("EMAIL", "mask"))
Expect(ea).NotTo(HaveKey("SSN"), "deleted map entry must not survive the patch")
// The scalar sibling in the same nested block is still preserved.
Expect(pii).To(HaveKeyWithValue("default_action", "mask"))
})
It("drops a map field entirely when the patch empties it", func() {
writeModelYAML(svc, dir, "ner", map[string]any{
"backend": "llama-cpp",
"known_usecases": []any{"token_classify"},
"pii_detection": map[string]any{
"default_action": "mask",
"entity_actions": map[string]any{"SSN": "block"},
},
})
_, err := svc.PatchConfig(ctx, "ner", map[string]any{
"pii_detection": map[string]any{
"entity_actions": map[string]any{},
},
})
Expect(err).ToNot(HaveOccurred())
raw, err := os.ReadFile(filepath.Join(dir, "ner.yaml"))
Expect(err).ToNot(HaveOccurred())
var got map[string]any
Expect(yaml.Unmarshal(raw, &got)).To(Succeed())
pii := got["pii_detection"].(map[string]any)
Expect(pii).NotTo(HaveKey("entity_actions"))
})
})
Describe("EditYAML", func() {

View File

@@ -1,71 +0,0 @@
package pii
import (
"fmt"
"os"
"gopkg.in/yaml.v3"
)
// FileConfig is the on-disk schema for pii.yaml. Each Pattern entry
// overrides the matching default by ID; missing fields fall back to
// the default. Unknown IDs are rejected at load time so an admin who
// fat-fingers a pattern name gets a clear error rather than a silent
// no-op.
type FileConfig struct {
Patterns []FilePattern `yaml:"patterns"`
}
type FilePattern struct {
ID string `yaml:"id"`
Action Action `yaml:"action"`
}
// LoadConfig reads pii.yaml from path and merges it on top of
// DefaultPatterns(). path == "" returns the defaults compiled and
// ready. The returned slice is already Compile()'d, so callers can
// pass it straight to NewRedactor.
func LoadConfig(path string) ([]Pattern, error) {
defaults := DefaultPatterns()
if path == "" {
return Compile(defaults)
}
raw, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("pii: read config %q: %w", path, err)
}
var cfg FileConfig
if err := yaml.Unmarshal(raw, &cfg); err != nil {
return nil, fmt.Errorf("pii: parse config %q: %w", path, err)
}
overrides := make(map[string]Action, len(cfg.Patterns))
known := make(map[string]bool, len(defaults))
for _, d := range defaults {
known[d.ID] = true
}
for _, p := range cfg.Patterns {
if !known[p.ID] {
return nil, fmt.Errorf("pii: unknown pattern id %q in %q", p.ID, path)
}
if p.Action == "" {
continue
}
switch p.Action {
case ActionMask, ActionBlock, ActionAllow:
overrides[p.ID] = p.Action
default:
return nil, fmt.Errorf("pii: invalid action %q for pattern %q", p.Action, p.ID)
}
}
merged := make([]Pattern, len(defaults))
for i, d := range defaults {
if a, ok := overrides[d.ID]; ok {
d.Action = a
}
merged[i] = d
}
return Compile(merged)
}

View File

@@ -1,56 +0,0 @@
package pii
import (
"os"
"path/filepath"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("LoadConfig", func() {
It("returns defaults when no path given", func() {
patterns, err := LoadConfig("")
Expect(err).NotTo(HaveOccurred())
Expect(patterns).To(HaveLen(len(DefaultPatterns())))
})
It("overrides action", func() {
dir := GinkgoT().TempDir()
path := filepath.Join(dir, "pii.yaml")
body := []byte(`patterns:
- id: email
action: block
- id: ssn
action: allow
`)
Expect(os.WriteFile(path, body, 0o600)).To(Succeed())
patterns, err := LoadConfig(path)
Expect(err).NotTo(HaveOccurred())
got := map[string]Action{}
for _, p := range patterns {
got[p.ID] = p.Action
}
Expect(got["email"]).To(Equal(ActionBlock))
Expect(got["ssn"]).To(Equal(ActionAllow))
// Unmentioned patterns keep their default action.
Expect(got["credit_card"]).To(Equal(ActionMask), "credit_card default action lost")
})
It("rejects unknown id", func() {
dir := GinkgoT().TempDir()
path := filepath.Join(dir, "pii.yaml")
Expect(os.WriteFile(path, []byte("patterns:\n - id: nonsense\n action: mask\n"), 0o600)).To(Succeed())
_, err := LoadConfig(path)
Expect(err).To(HaveOccurred(), "expected error on unknown pattern id")
})
It("rejects invalid action", func() {
dir := GinkgoT().TempDir()
path := filepath.Join(dir, "pii.yaml")
Expect(os.WriteFile(path, []byte("patterns:\n - id: email\n action: lolwhat\n"), 0o600)).To(Succeed())
_, err := LoadConfig(path)
Expect(err).To(HaveOccurred(), "expected error on invalid action")
})
})

View File

@@ -19,28 +19,71 @@ import (
// drag the http/middleware package into pii's import graph and create
// a cycle (http/middleware will import this one).
const (
ctxKeyCorrelationID = "routing.correlation_id"
ctxKeyPIIEventID = "routing.pii_event_id"
ctxKeyCorrelationID = "routing.correlation_id"
ctxKeyPIIEventID = "routing.pii_event_id"
// Must match the constants in core/http/middleware/request.go.
// Echoing them across packages would create an import cycle
// (http/middleware imports this package). Drift is caught by
// integration tests against the chat route.
ctxKeyParsedRequest = "LOCALAI_REQUEST"
ctxKeyModelConfig = "MODEL_CONFIG"
ctxKeyParsedRequest = "LOCALAI_REQUEST"
ctxKeyModelConfig = "MODEL_CONFIG"
)
// ModelPIIConfig is the duck-typed view this middleware needs of the
// per-model PII configuration carried on the echo context. *config.ModelConfig
// satisfies it via PIIIsEnabled / PIIPatternOverrides; the indirection
// keeps the pii package from importing core/config.
// per-model PII configuration carried on the echo context.
// *config.ModelConfig satisfies it via PIIIsEnabled / PIIDetectors; the
// indirection keeps the pii package from importing core/config.
//
// Consumers of the override map: the action returned from PIIPatternOverrides
// is the raw YAML string (e.g. "block"). Validation against the canonical
// ActionMask/Block/Allow constants happens here, so a typo in a model
// YAML logs and is ignored rather than panicking.
// PIIDetectors lists the token-classification models whose detections
// drive redaction for this (consuming) model. The detection policy lives
// on each named detector model — resolved via NERDetectorResolver — so
// this consuming view carries no per-entity actions of its own.
type ModelPIIConfig interface {
PIIIsEnabled() bool
PIIPatternOverrides() map[string]string
PIIDetectors() []string
}
// NERDetectorResolver resolves a detector model name to a ready-to-use
// NERConfig — the detector plus the policy (min score, entity→action
// map, default action) read from that model's own pii_detection block.
// ok is false when the name can't supply a detector (unknown model, not
// a token_classify model, or load failure); the middleware fails closed
// in that case. Supplied by the application layer, which owns the model
// loader and the core/backend dependency, keeping the pii package free of
// both. A nil resolver (or the option being unset) disables the NER tier.
type NERDetectorResolver func(modelName string) (NERConfig, bool)
// Option configures optional RequestMiddleware behaviour. Threaded as
// variadic options so adding the NER tier doesn't break the existing
// four-argument call sites (routes and tests).
type Option func(*mwOptions)
type mwOptions struct {
nerResolver NERDetectorResolver
policyResolver PolicyResolver
}
// PolicyResolver returns the effective (enabled, detectors) for the model
// carried on the request context, layering instance-wide PII defaults over the
// per-model config. Supplied by the application layer (which owns core/config),
// keeping this package decoupled from it — the middleware passes the raw
// context value through as `any`. When unset, the middleware falls back to the
// duck-typed ModelPIIConfig (explicit per-model config only, no global default).
type PolicyResolver func(modelCfg any) (enabled bool, detectors []string)
// WithPolicyResolver overrides how the middleware decides enablement and the
// detector list, so the instance-wide default detector / default-on usecases
// apply. Without it the middleware reads ModelPIIConfig off the context.
func WithPolicyResolver(r PolicyResolver) Option {
return func(o *mwOptions) { o.policyResolver = r }
}
// WithNERResolver enables the NER tier. When a request's model lists
// pii.detectors, the middleware resolves each to a NERConfig and runs
// RedactNER (the union of all detectors' hits, merged). Without this
// option, or when a model lists no detectors, redaction is a no-op.
func WithNERResolver(r NERDetectorResolver) Option {
return func(o *mwOptions) { o.nerResolver = r }
}
// ScannedText is one piece of user text from the request. Index is
@@ -84,30 +127,32 @@ type Adapter struct {
// no-auth identity. The middleware writes ctxKeyPIIEventID on the echo
// context so the usage middleware can later cross-reference the event
// with the UsageRecord.
func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fallbackUser *auth.User) echo.MiddlewareFunc {
func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fallbackUser *auth.User, opts ...Option) echo.MiddlewareFunc {
var o mwOptions
for _, opt := range opts {
opt(&o)
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if redactor == nil || len(redactor.Patterns()) == 0 || adapter.Scan == nil {
if redactor == nil || adapter.Scan == nil {
return next(c)
}
// Per-model gating: redaction is opt-in per model. If the
// resolved config disables PII for this model (the default
// for non-proxy backends), pass through immediately. We do
// this before parsing the request so a disabled model
// doesn't pay the regex scan cost.
if cfg, ok := c.Get(ctxKeyModelConfig).(ModelPIIConfig); ok {
if !cfg.PIIIsEnabled() {
return next(c)
}
} else {
// No ModelPIIConfig on context → fail-closed: skip
// redaction. This protects routes that wire the
// middleware before SetModelAndConfig runs (or non-chat
// routes that don't carry a model). The middleware was
// previously fail-open, applying the global redactor
// unconditionally; the new contract is per-model
// opt-in, and a missing model is treated as disabled.
// Per-model gating: redaction is opt-in per model. The policy
// resolver (when wired) layers instance-wide defaults over the
// per-model config; otherwise we read the per-model config
// directly. A missing config (non-chat routes, or middleware
// wired before SetModelAndConfig) or a not-enabled result passes
// through.
rawCfg := c.Get(ctxKeyModelConfig)
var enabled bool
var detectors []string
if o.policyResolver != nil {
enabled, detectors = o.policyResolver(rawCfg)
} else if cfg, ok := rawCfg.(ModelPIIConfig); ok {
enabled, detectors = cfg.PIIIsEnabled(), cfg.PIIDetectors()
}
if !enabled {
return next(c)
}
@@ -116,6 +161,12 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
return next(c)
}
// A PII-enabled model with no detectors (or no resolver wired)
// has nothing to scan with — pass through.
if len(detectors) == 0 || o.nerResolver == nil {
return next(c)
}
user := auth.GetUser(c)
if user == nil {
user = fallbackUser
@@ -126,24 +177,19 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
}
correlationID, _ := c.Get(ctxKeyCorrelationID).(string)
// Resolve per-model action overrides once per request. The
// raw map is YAML strings; convert to the typed Action set
// and silently drop unknown values rather than failing the
// request — model YAML typos shouldn't take chat down.
var overrides map[string]Action
if cfg, ok := c.Get(ctxKeyModelConfig).(ModelPIIConfig); ok {
if raw := cfg.PIIPatternOverrides(); len(raw) > 0 {
overrides = make(map[string]Action, len(raw))
for id, action := range raw {
switch Action(action) {
case ActionMask, ActionBlock, ActionAllow:
overrides[id] = Action(action)
default:
xlog.Warn("pii: ignoring unknown action in per-model override",
"pattern", id, "action", action)
}
}
// Resolve each named detector to its NERConfig (detector +
// the policy from that model's own pii_detection block). A
// configured detector that can't be resolved fails closed:
// serving the request without the semantic check the operator
// asked for is exactly the leak this tier exists to prevent.
cfgs := make([]NERConfig, 0, len(detectors))
for _, name := range detectors {
nc, ok := o.nerResolver(name)
if !ok {
xlog.Error("pii: configured detector model could not be resolved; blocking request (fail-closed)", "detector", name)
return blockNERUnavailable(c, store, correlationID, userID)
}
cfgs = append(cfgs, nc)
}
texts := adapter.Scan(parsed)
@@ -151,24 +197,38 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
var blocked bool
var firstEventID string
for _, st := range texts {
if st.Text == "" {
continue
}
res := redactor.RedactWithOverrides(st.Text, overrides)
// Scan the request as ONE document (messages joined) so the NER
// tier keeps conversational context — whether "4421" is a PIN is
// decided by the question in the previous message. The spans come
// back per message with local offsets for in-place rewriting.
segTexts := make([]string, len(texts))
for i, st := range texts {
segTexts[i] = st.Text
}
// Fail closed: a detector outage at request time must NOT
// silently serve the request. The NER tier was explicitly
// configured for this model, so the semantic check is part
// of the contract.
segResults, nerErr := RedactNERSegments(c.Request().Context(), segTexts, cfgs)
if nerErr != nil {
xlog.Error("pii: NER detector failed; blocking request (fail-closed)", "error", nerErr)
return blockNERUnavailable(c, store, correlationID, userID)
}
for i, res := range segResults {
st := texts[i]
if len(res.Spans) == 0 {
continue
}
// Persist one event per span so admins can see exactly
// which patterns fired in which positions. The action
// recorded is the resolved one (after override), so the
// events log reflects what actually happened to the
// request, not the global default.
// Persist one event per detected span. The action recorded
// is the one that actually fired (carried on the span after
// the overlap merge), so the events log reflects what
// happened to the request.
for _, span := range res.Spans {
action := actionForSpan(redactor.Patterns(), span.Pattern, overrides)
ev := PIIEvent{
ID: newEventID(),
Origin: OriginMiddleware,
CorrelationID: correlationID,
UserID: userID,
Direction: DirectionIn,
@@ -176,7 +236,8 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
ByteOffset: span.Start,
Length: span.End - span.Start,
HashPrefix: span.HashPrefix,
Action: action,
Action: span.Action,
Score: span.Score,
CreatedAt: time.Now().UTC(),
}
if firstEventID == "" {
@@ -223,24 +284,85 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
}
}
func actionForPattern(patterns []Pattern, id string) Action {
for _, p := range patterns {
if p.ID == id {
return p.Action
// nerUnavailablePattern is the sentinel PatternID recorded on the
// fail-closed audit event when a model's configured NER tier cannot
// run. It is not a real regex pattern — it marks a request blocked
// because the encoder/NER check was unavailable (model unresolved or
// backend error), so the events log distinguishes it from a content
// block (which carries a real pattern ID).
const nerUnavailablePattern = "__ner_unavailable__"
// blockNERUnavailable records a fail-closed audit event and returns the
// response used when a model has an NER tier configured but it could
// not run. Failing closed is deliberate for a PII filter: if the
// semantic check the operator asked for cannot execute, refusing the
// request is safer than serving it with only the cheap regex tier. The
// 503 (vs the 400 used for a content block) tells clients and operators
// this was a dependency outage, not sensitive data in the request.
func blockNERUnavailable(c echo.Context, store EventStore, correlationID, userID string) error {
ev := PIIEvent{
ID: newEventID(),
Kind: KindPII,
Origin: OriginMiddleware,
CorrelationID: correlationID,
UserID: userID,
Direction: DirectionIn,
PatternID: nerUnavailablePattern,
Action: ActionBlock,
CreatedAt: time.Now().UTC(),
}
if store != nil {
if err := store.Record(context.Background(), ev); err != nil {
xlog.Error("pii: failed to record NER-unavailable event", "error", err)
}
}
return ActionMask
c.Set(ctxKeyPIIEventID, ev.ID)
return c.JSON(http.StatusServiceUnavailable, map[string]any{
"error": map[string]string{
"message": "request blocked: PII NER check is configured but unavailable",
"type": "pii_ner_unavailable",
},
"correlation_id": correlationID,
"pii_event_id": ev.ID,
})
}
// actionForSpan returns the resolved action for a span, preferring a
// per-request override over the pattern's stored action. Used so the
// PIIEvent log reflects the action that actually fired (e.g., a model
// upgraded email from mask to block — the event row says "block").
func actionForSpan(patterns []Pattern, id string, overrides map[string]Action) Action {
if action, ok := overrides[id]; ok {
return action
// validAction converts a raw YAML action string to the typed Action,
// returning "" for anything that isn't a known action.
func validAction(raw string) Action {
switch Action(raw) {
case ActionMask, ActionBlock, ActionAllow:
return Action(raw)
default:
return ""
}
return actionForPattern(patterns, id)
}
// validActionOr is validAction with a fallback for empty/invalid input.
func validActionOr(raw string, fallback Action) Action {
if a := validAction(raw); a != "" {
return a
}
return fallback
}
// validActions converts a raw entity-group->action map to typed
// Actions, dropping (and logging) unknown actions so a model YAML typo
// is ignored rather than taking the request down — mirroring how the
// per-pattern overrides are validated above.
func validActions(raw map[string]string) map[string]Action {
if len(raw) == 0 {
return nil
}
out := make(map[string]Action, len(raw))
for group, action := range raw {
if a := validAction(action); a != "" {
out[group] = a
} else {
xlog.Warn("pii: ignoring unknown NER entity action", "group", group, "action", action)
}
}
return out
}
func newEventID() string {
@@ -248,3 +370,8 @@ func newEventID() string {
_, _ = rand.Read(b[:])
return "pii_" + hex.EncodeToString(b[:])
}
// NewEventID mints a fresh random event id in the package's standard shape.
// Exported so callers outside this package (the analyze/redact API handlers)
// record events with ids indistinguishable from the in-band middleware's.
func NewEventID() string { return newEventID() }

View File

@@ -3,12 +3,12 @@ package pii
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"strings"
"github.com/labstack/echo/v4"
"github.com/mudler/LocalAI/core/http/auth"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -56,21 +56,15 @@ func setRequestOnContext(req *fakeRequest) echo.MiddlewareFunc {
}
// fakeModelPIIConfig satisfies the duck-typed ModelPIIConfig interface
// the middleware expects on the echo context. The real implementation
// lives on *config.ModelConfig; using a fake here keeps these tests
// out of the core/config import graph.
// the middleware expects on the echo context (PIIIsEnabled + PIIDetectors).
type fakeModelPIIConfig struct {
enabled bool
overrides map[string]string
detectors []string
}
func (f fakeModelPIIConfig) PIIIsEnabled() bool { return f.enabled }
func (f fakeModelPIIConfig) PIIPatternOverrides() map[string]string { return f.overrides }
func (f fakeModelPIIConfig) PIIIsEnabled() bool { return f.enabled }
func (f fakeModelPIIConfig) PIIDetectors() []string { return f.detectors }
// withModelConfig wires a ModelPIIConfig onto the context so the
// middleware's per-model gate doesn't fail-closed during tests. Pass
// enabled=true for the default test path; explicit-false tests should
// use the gating spec further down instead.
func withModelConfig(cfg fakeModelPIIConfig) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
@@ -80,230 +74,257 @@ func withModelConfig(cfg fakeModelPIIConfig) echo.MiddlewareFunc {
}
}
func newTestRedactor(ids ...string) *Redactor {
patterns, err := Compile(pick(DefaultPatterns(), ids))
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
return NewRedactor(patterns)
// resolverFor returns a NERDetectorResolver that maps each named model to
// the supplied NERConfig. Names absent from the map resolve to (zero,
// false) so the middleware fails closed — mirroring an unresolvable model.
func resolverFor(byName map[string]NERConfig) NERDetectorResolver {
return func(name string) (NERConfig, bool) {
cfg, ok := byName[name]
return cfg, ok
}
}
var _ = Describe("RequestMiddleware", func() {
It("masks email", func() {
red := newTestRedactor("email")
store := NewMemoryEventStore(0)
defer func() { _ = store.Close() }()
user := &auth.User{ID: "user-1", Name: "alice"}
func serve(body *fakeRequest, cfg fakeModelPIIConfig, mw echo.MiddlewareFunc, withConfig bool) (*httptest.ResponseRecorder, *bool) {
called := new(bool)
e := echo.New()
chain := []echo.MiddlewareFunc{setRequestOnContext(body)}
if withConfig {
chain = append(chain, withModelConfig(cfg))
}
chain = append(chain, mw)
e.POST("/chat", func(c echo.Context) error {
*called = true
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, chain...)
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
w := httptest.NewRecorder()
e.ServeHTTP(w, req)
return w, called
}
body := &fakeRequest{Messages: []string{"contact me at alice@example.com"}}
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
func nerCfg(action Action, entities ...NEREntity) NERConfig {
return NERConfig{
Detector: &stubNERDetector{entities: entities},
DefaultAction: action,
}
}
e := echo.New()
e.POST("/chat", func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw, func(next echo.HandlerFunc) echo.HandlerFunc {
// Inject the user as if upstream auth ran.
return func(c echo.Context) error {
c.Set("auth_user", user)
return next(c)
}
})
var _ = Describe("RequestMiddleware (NER)", func() {
store := func() EventStore { return NewMemoryEventStore(0) }
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
w := httptest.NewRecorder()
e.ServeHTTP(w, req)
It("masks a detected entity end-to-end", func() {
st := store()
body := &fakeRequest{Messages: []string{"Hi I'm Alice today"}}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{
"privacy-filter": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
})))
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"privacy-filter"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String())
Expect(body.Messages[0]).NotTo(ContainSubstring("alice@example.com"), "request body should be redacted in place")
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:email]"))
events, err := store.List(context.Background(), ListQuery{Limit: 100})
Expect(err).NotTo(HaveOccurred(), "list events")
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:PER]"))
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(HaveLen(1))
Expect(events[0].PatternID).To(Equal("email"))
Expect(events[0].PatternID).To(Equal("ner:PER"))
Expect(events[0].Direction).To(Equal(DirectionIn))
})
It("blocks api key", func() {
red := newTestRedactor("api_key_prefix")
store := NewMemoryEventStore(0)
defer func() { _ = store.Close() }()
body := &fakeRequest{Messages: []string{"my key is sk-abcdefghijklmnopqrstuvwxyz0123456789"}}
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
e := echo.New()
handlerCalled := false
e.POST("/chat", func(c echo.Context) error {
handlerCalled = true
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
w := httptest.NewRecorder()
e.ServeHTTP(w, req)
Expect(w.Code).To(Equal(http.StatusBadRequest), "expected 400 on block; body=%s", w.Body.String())
Expect(handlerCalled).To(BeFalse(), "handler must not run when request is blocked")
// Ensure the matched value never appears in the response body.
Expect(w.Body.String()).NotTo(ContainSubstring("abcdefghijklmnopqrstuvwxyz0123456789"), "blocked response leaks the matched value")
It("blocks (400) when a detected entity's action is block", func() {
st := store()
body := &fakeRequest{Messages: []string{"my password is hunter2 ok"}}
cfg := NERConfig{
Detector: &stubNERDetector{entities: []NEREntity{{Group: "PASSWORD", Start: 15, End: 22, Score: 0.99}}},
EntityActions: map[string]Action{"PASSWORD": ActionBlock},
}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
w, called := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusBadRequest), "body=%s", w.Body.String())
Expect(*called).To(BeFalse(), "handler must not run when blocked")
var resp map[string]any
Expect(json.Unmarshal(w.Body.Bytes(), &resp)).To(Succeed())
errBlock, ok := resp["error"].(map[string]any)
Expect(ok).To(BeTrue())
errBlock, _ := resp["error"].(map[string]any)
Expect(errBlock["type"]).To(Equal("pii_blocked"))
})
It("allow leaves text intact but still records an event", func() {
patterns, _ := Compile([]Pattern{{
ID: "email", Description: "Email", Action: ActionAllow, MaxMatchLength: 254,
}})
red := NewRedactor(patterns)
store := NewMemoryEventStore(0)
defer func() { _ = store.Close() }()
It("allow leaves text intact but records an event", func() {
st := store()
body := &fakeRequest{Messages: []string{"hi at alice@example.com"}}
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
e := echo.New()
e.POST("/chat", func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
w := httptest.NewRecorder()
e.ServeHTTP(w, req)
cfg := NERConfig{
Detector: &stubNERDetector{entities: []NEREntity{{Group: "EMAIL", Start: 6, End: 23, Score: 0.9}}},
EntityActions: map[string]Action{"EMAIL": ActionAllow},
}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK))
// allow does NOT mutate the body — the model still sees the email.
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "allow should leave text intact")
// ...but the detection is still recorded for audit.
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(HaveLen(1), "allow should still record a PIIEvent")
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"))
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(HaveLen(1))
Expect(events[0].Action).To(Equal(ActionAllow))
})
It("no match passes through", func() {
red := newTestRedactor()
store := NewMemoryEventStore(0)
defer func() { _ = store.Close() }()
It("passes through on no match", func() {
st := store()
body := &fakeRequest{Messages: []string{"perfectly innocent text"}}
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
e := echo.New()
e.POST("/chat", func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
w := httptest.NewRecorder()
e.ServeHTTP(w, req)
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{"pf": nerCfg(ActionMask)})))
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK))
Expect(body.Messages[0]).To(Equal("perfectly innocent text"), "body should be untouched")
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(BeEmpty(), "expected 0 events on no-match input")
Expect(body.Messages[0]).To(Equal("perfectly innocent text"))
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(BeEmpty())
})
It("skips when model config disabled", func() {
// Per-model gating is the new contract: a model with PIIIsEnabled
// returning false must bypass redaction entirely, even if the
// global redactor has matching patterns.
red := newTestRedactor("email")
store := NewMemoryEventStore(0)
defer func() { _ = store.Close() }()
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
e := echo.New()
e.POST("/chat", func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: false}), mw)
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
w := httptest.NewRecorder()
e.ServeHTTP(w, req)
It("skips when the model has PII disabled", func() {
st := store()
body := &fakeRequest{Messages: []string{"Hi I'm Alice"}}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{
"pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
})))
w, _ := serve(body, fakeModelPIIConfig{enabled: false, detectors: []string{"pf"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK))
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "disabled model must not redact")
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(BeEmpty(), "disabled model must produce no events")
Expect(body.Messages[0]).To(Equal("Hi I'm Alice"), "disabled model must not redact")
})
It("fails closed without model config", func() {
// Routes that wire the middleware before SetModelAndConfig, or
// non-chat routes lacking a model, hit this path. The contract
// is fail-closed: pass through without redaction so a missing
// model can't accidentally leak through global defaults.
red := newTestRedactor("email")
store := NewMemoryEventStore(0)
defer func() { _ = store.Close() }()
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
e := echo.New()
// Note: no withModelConfig in the chain.
e.POST("/chat", func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, setRequestOnContext(body), mw)
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
w := httptest.NewRecorder()
e.ServeHTTP(w, req)
It("passes through when the model lists no detectors", func() {
st := store()
body := &fakeRequest{Messages: []string{"Hi I'm Alice"}}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{})))
w, _ := serve(body, fakeModelPIIConfig{enabled: true}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK))
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "missing ModelPIIConfig should fail-closed (no redaction)")
Expect(body.Messages[0]).To(Equal("Hi I'm Alice"))
})
It("applies per-model override", func() {
// email defaults to mask. A per-model override upgrades it to
// block. The middleware short-circuits with 400, the request
// body is never touched, and the events log records action=block.
red := newTestRedactor("email")
store := NewMemoryEventStore(0)
defer func() { _ = store.Close() }()
It("fails closed without a model config", func() {
st := store()
body := &fakeRequest{Messages: []string{"Hi I'm Alice"}}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{
"pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
})))
w, _ := serve(body, fakeModelPIIConfig{}, mw, false) // no model config on context
Expect(w.Code).To(Equal(http.StatusOK))
Expect(body.Messages[0]).To(Equal("Hi I'm Alice"), "missing ModelPIIConfig should pass through")
})
It("unions multiple detectors", func() {
st := store()
body := &fakeRequest{Messages: []string{"Alice at acme"}}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{
"names": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 0, End: 5, Score: 0.9}),
"orgs": nerCfg(ActionMask, NEREntity{Group: "ORG", Start: 9, End: 13, Score: 0.9}),
})))
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"names", "orgs"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK))
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:PER]"))
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:ORG]"))
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(HaveLen(2))
})
It("fails closed (503) when a detector errors", func() {
st := store()
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
cfg := NERConfig{Detector: &stubNERDetector{err: errors.New("backend offline")}, DefaultAction: ActionMask}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
w, called := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
e := echo.New()
handlerCalled := false
e.POST("/chat", func(c echo.Context) error {
handlerCalled = true
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, setRequestOnContext(body),
withModelConfig(fakeModelPIIConfig{
enabled: true,
overrides: map[string]string{"email": "block"},
}), mw)
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
w := httptest.NewRecorder()
e.ServeHTTP(w, req)
Expect(w.Code).To(Equal(http.StatusBadRequest), "expected 400 from override-block; body=%s", w.Body.String())
Expect(handlerCalled).To(BeFalse(), "handler must not run when override blocks")
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
Expect(w.Code).To(Equal(http.StatusServiceUnavailable), "body=%s", w.Body.String())
Expect(*called).To(BeFalse())
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "request body must be untouched on a fail-closed block")
var resp map[string]any
Expect(json.Unmarshal(w.Body.Bytes(), &resp)).To(Succeed())
errBlock, _ := resp["error"].(map[string]any)
Expect(errBlock["type"]).To(Equal("pii_ner_unavailable"))
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(HaveLen(1))
Expect(events[0].Action).To(Equal(ActionBlock), "event must record the resolved (override) action")
Expect(events[0].PatternID).To(Equal(nerUnavailablePattern))
})
It("fails closed (503) when a configured detector can't be resolved", func() {
st := store()
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{}))) // "missing" not present
w, called := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"missing"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusServiceUnavailable))
Expect(*called).To(BeFalse())
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(HaveLen(1))
Expect(events[0].PatternID).To(Equal(nerUnavailablePattern))
})
It("nil redactor is passthrough", func() {
body := &fakeRequest{Messages: []string{"alice@example.com"}}
mw := RequestMiddleware(nil, nil, fakeAdapter(), nil)
e := echo.New()
e.POST("/chat", func(c echo.Context) error {
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
w := httptest.NewRecorder()
e.ServeHTTP(w, req)
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK))
Expect(body.Messages[0]).To(Equal("alice@example.com"), "nil redactor must be a no-op")
})
It("WithPolicyResolver enables a model the per-model config left off (global default)", func() {
st := store()
body := &fakeRequest{Messages: []string{"Hi I'm Alice today"}}
// The per-model config is disabled with no detectors; the policy
// resolver (instance-wide default) turns it on and supplies one.
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{
"global-pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
})),
WithPolicyResolver(func(_ any) (bool, []string) { return true, []string{"global-pf"} }))
w, _ := serve(body, fakeModelPIIConfig{enabled: false}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String())
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:PER]"))
})
It("WithPolicyResolver returning disabled short-circuits an otherwise-enabled model", func() {
st := store()
body := &fakeRequest{Messages: []string{"Hi I'm Alice today"}}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{
"pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
})),
WithPolicyResolver(func(_ any) (bool, []string) { return false, nil }))
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK))
Expect(body.Messages[0]).To(Equal("Hi I'm Alice today"), "resolver disabled => no redaction")
})
It("scans all messages as one document so earlier-message context applies", func() {
st := store()
// The detector (pinAfterCard) only recognises "4421" when "card"
// appears earlier in the SAME text it is handed — so this only
// masks if the middleware joins the messages before scanning.
body := &fakeRequest{Messages: []string{
"What are the last four digits of your card?",
"it is 4421 ok",
}}
cfg := NERConfig{Detector: &funcNERDetector{fn: pinAfterCard}, DefaultAction: ActionMask}
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String())
Expect(body.Messages[0]).To(Equal("What are the last four digits of your card?"), "question untouched")
Expect(body.Messages[1]).To(Equal("it is [REDACTED:ner:PIN] ok"))
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
Expect(events).To(HaveLen(1))
Expect(events[0].ByteOffset).To(Equal(6), "event offsets are message-local")
})
})

View File

@@ -28,6 +28,10 @@ type NEREntity struct {
Start int
End int
Score float32
// Text is the matched substring as the detector saw it. Carried for
// debug logging only (the persisted PIIEvent never stores the raw
// value); the redactor re-slices the original text for masking.
Text string
}
// NERConfig configures the encoder tier for one redactor invocation.
@@ -56,8 +60,22 @@ type NERConfig struct {
// entities silently" — useful when the model returns a broad
// taxonomy but the admin only cares about a subset.
DefaultAction Action
// Source labels where this detector's hits come from. It becomes the
// PatternID prefix on events and the [REDACTED:<id>] mask, so neural NER
// detections (Source "ner") and deterministic pattern-matcher detections
// (Source "pattern") are told apart in the events log and to the model.
// Empty defaults to "ner" for backward compatibility.
Source string
}
// Detector source labels (the PatternID prefix). Kept short and stable —
// they appear in the events log and the [REDACTED:...] mask.
const (
SourceNER = "ner"
SourcePattern = "pattern"
)
// ResolveAction returns the action configured for a detected entity
// group, falling back to DefaultAction. Returns ("", false) when the
// entity should be ignored entirely (no override + no default).
@@ -71,13 +89,39 @@ func (c NERConfig) ResolveAction(group string) (Action, bool) {
return "", false
}
// nerPatternID returns the synthetic pattern ID that audit rows carry
// for NER hits. Prefixing with "ner:" keeps these distinguishable from
// regex pattern IDs in the events tab and in filter queries; admins
// can switch off a single entity type with the same Disabled-pattern
// machinery used for regex.
func nerPatternID(group string) string {
return "ner:" + group
// NERConfigFromRaw builds a typed NERConfig from a detector plus the raw
// policy strings carried on a detector model's pii_detection config. An
// empty or invalid default_action becomes ActionMask — the safe-by-default
// policy for a PII filter (a detected entity is masked unless an admin
// downgrades it). Unknown per-entity actions are dropped (and logged by
// validActions). This is the single conversion point the application-layer
// resolver uses, so the detector model's policy reaches the redactor in
// exactly one shape. source labels the detector kind (SourceNER /
// SourcePattern) and becomes the PatternID prefix; empty defaults to
// SourceNER.
func NERConfigFromRaw(detector NERDetector, minScore float32, defaultAction string, entityActions map[string]string, source string) NERConfig {
if source == "" {
source = SourceNER
}
return NERConfig{
Detector: detector,
MinScore: minScore,
DefaultAction: validActionOr(defaultAction, ActionMask),
EntityActions: validActions(entityActions),
Source: source,
}
}
// patternID returns the synthetic pattern ID that audit rows and masks carry
// for this detector's hits, e.g. "ner:EMAIL" or "pattern:ANTHROPIC_KEY". The
// source prefix keeps neural and deterministic detections distinguishable in
// the events tab and in pattern_id filter queries.
func (c NERConfig) patternID(group string) string {
source := c.Source
if source == "" {
source = SourceNER
}
return source + ":" + group
}
// errNERDetector is a NERDetector that always returns the wrapped

View File

@@ -9,8 +9,7 @@ import (
)
// stubNERDetector returns a fixed slice of entities and tracks call
// count so tests can assert the detector isn't called when text is
// empty / no patterns / detector disabled.
// count so tests can assert the detector isn't called when text is empty.
type stubNERDetector struct {
entities []NEREntity
err error
@@ -22,43 +21,39 @@ func (s *stubNERDetector) Detect(_ context.Context, _ string) ([]NEREntity, erro
return s.entities, s.err
}
var _ = Describe("RedactWithNER", func() {
It("nil detector is regex-only", func() {
// When the NER tier is disabled (Detector == nil) the redactor
// must behave exactly like the existing regex-only path — no
// detector call, same Result shape, no error.
r := NewRedactor([]Pattern{pickEmail()})
res, err := r.RedactWithNER(context.Background(), "ping me at alice@example.com", nil, NERConfig{})
var _ = Describe("RedactNER", func() {
It("no detectors is a no-op", func() {
res, err := RedactNER(context.Background(), "ping me at alice@example.com", nil)
Expect(err).NotTo(HaveOccurred())
Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"), "regex tier should still run when Detector is nil")
Expect(res.Redacted).To(Equal("ping me at alice@example.com"))
Expect(res.Spans).To(BeEmpty())
})
It("applies entity actions", func() {
det := &stubNERDetector{entities: []NEREntity{
{Group: "PER", Start: 6, End: 11, Score: 0.95}, // "Alice" in "Hi I'm Alice today"
}}
r := NewRedactor(nil)
res, err := r.RedactWithNER(context.Background(), "Hi I'm Alice today", nil, NERConfig{
res, err := RedactNER(context.Background(), "Hi I'm Alice today", []NERConfig{{
Detector: det,
EntityActions: map[string]Action{"PER": ActionMask},
})
}})
Expect(err).NotTo(HaveOccurred())
Expect(det.calls).To(Equal(1))
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:PER]"))
Expect(res.Spans).To(HaveLen(1))
Expect(res.Spans[0].Pattern).To(Equal("ner:PER"))
Expect(res.Spans[0].Action).To(Equal(ActionMask))
})
It("filters below MinScore", func() {
det := &stubNERDetector{entities: []NEREntity{
{Group: "PER", Start: 0, End: 5, Score: 0.20},
}}
r := NewRedactor(nil)
res, err := r.RedactWithNER(context.Background(), "Alice", nil, NERConfig{
res, err := RedactNER(context.Background(), "Alice", []NERConfig{{
Detector: det,
MinScore: 0.50,
EntityActions: map[string]Action{"PER": ActionMask},
})
}})
Expect(err).NotTo(HaveOccurred())
Expect(res.Redacted).To(Equal("Alice"), "low-confidence entity should be dropped")
})
@@ -67,108 +62,120 @@ var _ = Describe("RedactWithNER", func() {
det := &stubNERDetector{entities: []NEREntity{
{Group: "ORG", Start: 7, End: 11, Score: 0.9}, // "Acme" in "Hello, Acme!"
}}
r := NewRedactor(nil)
res, err := r.RedactWithNER(context.Background(), "Hello, Acme!", nil, NERConfig{
res, err := RedactNER(context.Background(), "Hello, Acme!", []NERConfig{{
Detector: det,
DefaultAction: ActionMask,
})
}})
Expect(err).NotTo(HaveOccurred())
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:ORG]"), "DefaultAction should apply to ORG")
})
It("drops unconfigured groups with no default", func() {
// EntityActions has no entry for ORG and DefaultAction is empty —
// the detected entity must be ignored entirely (no audit row, no
// redaction).
det := &stubNERDetector{entities: []NEREntity{
{Group: "ORG", Start: 0, End: 4, Score: 0.9},
}}
r := NewRedactor(nil)
res, err := r.RedactWithNER(context.Background(), "Acme", nil, NERConfig{
res, err := RedactNER(context.Background(), "Acme", []NERConfig{{
Detector: det,
EntityActions: map[string]Action{"PER": ActionMask}, // ORG is unconfigured
})
}})
Expect(err).NotTo(HaveOccurred())
Expect(res.Redacted).To(Equal("Acme"))
Expect(res.Spans).To(BeEmpty())
})
It("overlapping hits keep stronger action", func() {
// Regex marks 0..10 as mask; NER marks 5..15 as block. After
// merge, the union 0..15 keeps the strongest action (block).
pat := Pattern{ID: "test", Action: ActionMask, regex: rangeRegex(0, 10)}
r := NewRedactor([]Pattern{pat})
det := &stubNERDetector{entities: []NEREntity{
{Group: "PER", Start: 5, End: 15, Score: 0.9},
}}
It("unions multiple detectors and keeps the stronger action on overlap", func() {
// Detector A marks 0..10 as mask; detector B marks 5..15 as block.
// After merge, the union 0..15 keeps the strongest action (block).
detA := &stubNERDetector{entities: []NEREntity{{Group: "A", Start: 0, End: 10, Score: 0.9}}}
detB := &stubNERDetector{entities: []NEREntity{{Group: "B", Start: 5, End: 15, Score: 0.9}}}
text := "0123456789ABCDEF"
res, err := r.RedactWithNER(context.Background(), text, nil, NERConfig{
Detector: det,
EntityActions: map[string]Action{"PER": ActionBlock},
res, err := RedactNER(context.Background(), text, []NERConfig{
{Detector: detA, EntityActions: map[string]Action{"A": ActionMask}},
{Detector: detB, EntityActions: map[string]Action{"B": ActionBlock}},
})
Expect(err).NotTo(HaveOccurred())
Expect(detA.calls).To(Equal(1))
Expect(detB.calls).To(Equal(1))
Expect(res.Blocked).To(BeTrue(), "overlapping mask+block should set Blocked=true")
})
It("detector error returns regex result and error", func() {
// Fail-open: when the NER detector errors, the redactor still
// returns regex-tier hits so an offline NER backend doesn't strip
// the cheap protection. Caller can read the error and decide
// whether to surface it.
det := &stubNERDetector{err: errors.New("backend offline")}
r := NewRedactor([]Pattern{pickEmail()})
res, err := r.RedactWithNER(context.Background(), "ping alice@example.com", nil, NERConfig{
Detector: det,
DefaultAction: ActionMask,
It("returns a best-effort result and the error when a detector fails (fail-closed contract)", func() {
// One healthy detector, one failing. RedactNER returns the healthy
// detector's hits AND the error, so the caller can fail closed.
good := &stubNERDetector{entities: []NEREntity{{Group: "PER", Start: 0, End: 5, Score: 0.9}}}
bad := &stubNERDetector{err: errors.New("backend offline")}
res, err := RedactNER(context.Background(), "Alice", []NERConfig{
{Detector: good, DefaultAction: ActionMask},
{Detector: bad, DefaultAction: ActionMask},
})
Expect(err).To(HaveOccurred(), "expected detector error to surface")
Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"), "regex tier should still apply on NER failure")
Expect(err).To(HaveOccurred())
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:PER]"), "healthy detector's hits should still apply")
})
It("out-of-bounds offsets are skipped", func() {
// A misconfigured / buggy backend could return offsets past the
// end of text. The redactor must not panic on slice OOB.
It("skips out-of-bounds offsets without panicking", func() {
det := &stubNERDetector{entities: []NEREntity{
{Group: "PER", Start: 0, End: 999, Score: 0.9},
{Group: "PER", Start: -1, End: 3, Score: 0.9},
{Group: "PER", Start: 5, End: 5, Score: 0.9}, // zero-length
}}
r := NewRedactor(nil)
res, err := r.RedactWithNER(context.Background(), "Alice", nil, NERConfig{
res, err := RedactNER(context.Background(), "Alice", []NERConfig{{
Detector: det,
DefaultAction: ActionMask,
})
}})
Expect(err).NotTo(HaveOccurred())
Expect(res.Redacted).To(Equal("Alice"))
Expect(res.Spans).To(BeEmpty())
})
})
// --- test helpers ---
var _ = Describe("NERConfigFromRaw", func() {
det := &stubNERDetector{}
// rangeMatcher is a deterministic regexpMatcher stub: it claims one
// fixed range regardless of input. Lets the overlap-merge test
// produce a known regex/NER intersection without depending on a real
// compiled regex.
type rangeMatcher struct{ start, end int }
It("defaults an empty default_action to mask and an empty source to ner", func() {
cfg := NERConfigFromRaw(det, 0.4, "", nil, "")
Expect(cfg.DefaultAction).To(Equal(ActionMask))
Expect(cfg.MinScore).To(BeNumerically("~", 0.4, 1e-6))
Expect(cfg.Source).To(Equal(SourceNER))
Expect(cfg.patternID("EMAIL")).To(Equal("ner:EMAIL"))
})
func (m rangeMatcher) FindAllStringIndex(_ string, _ int) [][]int {
return [][]int{{m.start, m.end}}
}
It("passes through valid actions and drops invalid ones", func() {
cfg := NERConfigFromRaw(det, 0, "block", map[string]string{
"PASSWORD": "block",
"EMAIL": "mask",
"BOGUS": "nonsense", // dropped
}, SourceNER)
Expect(cfg.DefaultAction).To(Equal(ActionBlock))
Expect(cfg.EntityActions).To(HaveKeyWithValue("PASSWORD", ActionBlock))
Expect(cfg.EntityActions).To(HaveKeyWithValue("EMAIL", ActionMask))
Expect(cfg.EntityActions).NotTo(HaveKey("BOGUS"))
})
func rangeRegex(start, end int) regexpMatcher { return rangeMatcher{start: start, end: end} }
It("prefixes pattern-detector hits with the pattern source", func() {
cfg := NERConfigFromRaw(det, 0, "mask", nil, SourcePattern)
Expect(cfg.Source).To(Equal(SourcePattern))
Expect(cfg.patternID("ANTHROPIC_KEY")).To(Equal("pattern:ANTHROPIC_KEY"))
})
})
// pickEmail returns the compiled "email" pattern from DefaultPatterns
// — the NER tests use it as the regex tier's contribution.
func pickEmail() Pattern {
for _, p := range DefaultPatterns() {
if p.ID == "email" {
compiled, err := Compile([]Pattern{p})
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
return compiled[0]
}
}
Fail("email pattern missing from DefaultPatterns")
return Pattern{}
}
var _ = Describe("NERConfig.ResolveAction", func() {
It("prefers an explicit entity action over the default", func() {
cfg := NERConfig{EntityActions: map[string]Action{"EMAIL": ActionBlock}, DefaultAction: ActionMask}
a, ok := cfg.ResolveAction("EMAIL")
Expect(ok).To(BeTrue())
Expect(a).To(Equal(ActionBlock))
})
It("falls back to the default action", func() {
cfg := NERConfig{DefaultAction: ActionMask}
a, ok := cfg.ResolveAction("ANYTHING")
Expect(ok).To(BeTrue())
Expect(a).To(Equal(ActionMask))
})
It("ignores a group with no override and no default", func() {
cfg := NERConfig{}
_, ok := cfg.ResolveAction("ANYTHING")
Expect(ok).To(BeFalse())
})
})

View File

@@ -1,188 +0,0 @@
package pii
import (
"fmt"
"regexp"
"strings"
)
// regexpMatcher is a thin wrapper so tests can swap in a deterministic
// matcher without touching the regexp package. Real usage uses
// regexpMatcherFromPattern; tests can construct fakes.
type regexpMatcher interface {
FindAllStringIndex(s string, n int) [][]int
}
type goRegexp struct{ r *regexp.Regexp }
func (g goRegexp) FindAllStringIndex(s string, n int) [][]int {
return g.r.FindAllStringIndex(s, n)
}
// DefaultPatterns returns the built-in regex set. Each entry includes
// a conservative MaxMatchLength so the streaming filter can size its
// tail buffer without re-parsing the regex at runtime.
//
// Caveats by design:
// - The phone pattern matches international and US formats but does
// not validate area codes. False positives on numbers that look
// phone-like (e.g., timestamps in some formats) are accepted in
// return for reliable coverage.
// - The credit card pattern requires the Luhn check (verifyLuhn) to
// reduce false positives — random 16-digit strings won't match.
// - The API-key pattern targets common provider prefixes (sk-, pk-,
// xoxb-, ghp_, github_pat_) rather than guessing entropy. Adding
// new providers should append a new Pattern, not extend an
// existing alternation, so the admin UI can show one row per
// provider with its own toggle.
func DefaultPatterns() []Pattern {
return []Pattern{
{
ID: "email",
Description: "Email address",
Action: ActionMask,
MaxMatchLength: 254, // RFC 5321 max
},
{
ID: "phone",
Description: "Phone number (international or US format)",
Action: ActionMask,
MaxMatchLength: 24,
},
{
ID: "ssn",
Description: "US Social Security Number (NNN-NN-NNNN)",
Action: ActionMask,
MaxMatchLength: 11,
},
{
ID: "credit_card",
Description: "Credit card number (Luhn-verified)",
Action: ActionMask,
MaxMatchLength: 19,
},
{
ID: "ipv4",
Description: "IPv4 address",
Action: ActionMask,
MaxMatchLength: 15,
},
{
ID: "api_key_prefix",
Description: "Common API key prefixes (sk-, pk-, xoxb-, ghp_, github_pat_)",
Action: ActionBlock, // tighter default — leaked credentials are higher harm
MaxMatchLength: 200,
},
}
}
// patternRegexps maps Pattern.ID to its compiled regex. Kept separate
// from the Pattern struct so DefaultPatterns can be data-only and
// tests can swap matchers via Compile().
var patternRegexps = map[string]*regexp.Regexp{
// Pragmatic email — does not implement RFC 5322 in full (no one
// sane does in a regex). Catches the common shape; the encoder
// NER tier (future) catches edge cases.
"email": regexp.MustCompile(`(?i)[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}`),
// US: (123) 456-7890, 123-456-7890, 123.456.7890, 1234567890.
// International: +<country>-<area>-<rest> with separators.
"phone": regexp.MustCompile(`(?:\+?\d{1,3}[\s\-.]?)?(?:\(\d{3}\)|\d{3})[\s\-.]?\d{3}[\s\-.]?\d{4}`),
"ssn": regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`),
// 13-19 digit Luhn-eligible runs. The verifier in match() rejects
// non-Luhn matches.
"credit_card": regexp.MustCompile(`\b(?:\d[ \-]?){13,19}\b`),
"ipv4": regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`),
// Common provider prefixes; each alternative is a separate
// well-known marker rather than a permissive entropy match.
"api_key_prefix": regexp.MustCompile(`(?:sk-[A-Za-z0-9]{20,}|pk-[A-Za-z0-9]{20,}|xoxb-[A-Za-z0-9\-]{20,}|ghp_[A-Za-z0-9]{20,}|github_pat_[A-Za-z0-9_]{20,})`),
}
// Compile attaches matchers to each pattern. Patterns whose ID is not
// in patternRegexps are returned as a typed error so an admin who
// adds a custom pattern via config gets a clear "no regex registered"
// message instead of silent skip.
func Compile(patterns []Pattern) ([]Pattern, error) {
out := make([]Pattern, len(patterns))
for i, p := range patterns {
r, ok := patternRegexps[p.ID]
if !ok {
return nil, fmt.Errorf("pii: no regex registered for pattern id %q", p.ID)
}
p.regex = goRegexp{r: r}
out[i] = p
}
return out, nil
}
// VerifyMatch applies pattern-specific post-checks (e.g. Luhn for
// credit_card). Returns the original match or "" to discard it.
func VerifyMatch(patternID, candidate string) string {
switch patternID {
case "credit_card":
digits := stripNonDigits(candidate)
if len(digits) < 13 || len(digits) > 19 {
return ""
}
if !verifyLuhn(digits) {
return ""
}
case "ipv4":
// Each octet must be 0..255. The regex allows 0..999 since
// regex isn't great at numeric ranges; we tighten here.
for oct := range strings.SplitSeq(candidate, ".") {
n := 0
for _, c := range oct {
if c < '0' || c > '9' {
return ""
}
n = n*10 + int(c-'0')
}
if n > 255 {
return ""
}
}
}
return candidate
}
func stripNonDigits(s string) string {
var b strings.Builder
b.Grow(len(s))
for _, c := range s {
if c >= '0' && c <= '9' {
b.WriteRune(c)
}
}
return b.String()
}
// verifyLuhn implements the Luhn checksum used by credit-card numbers.
// Returns true iff the digits pass.
func verifyLuhn(digits string) bool {
sum := 0
double := false
for i := len(digits) - 1; i >= 0; i-- {
d := int(digits[i] - '0')
if double {
d *= 2
if d > 9 {
d -= 9
}
}
sum += d
double = !double
}
return sum%10 == 0
}
// MaxPatternLength returns the longest MaxMatchLength across the input
// patterns. Used by the streaming filter to size its tail buffer.
func MaxPatternLength(patterns []Pattern) int {
max := 0
for _, p := range patterns {
if p.MaxMatchLength > max {
max = p.MaxMatchLength
}
}
return max
}

View File

@@ -4,212 +4,152 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"slices"
"sort"
"strings"
"sync"
"github.com/mudler/xlog"
)
// rawHit is one detection — regex-side or NER-side — before
// overlap-merging. Lifted to file scope so the regex and NER
// collectors can both produce them and feed the same merge/emit step.
// rawHit is one detection before overlap-merging. Lifted to file scope so
// the NER collector and the merge/emit step can share it.
type rawHit struct {
patternID string
action Action
start int
end int
score float32
}
// Redactor scans text against a configured pattern set and applies the
// per-pattern action. The pattern set itself is mutable at runtime via
// SetAction (the /api/pii/patterns/:id admin endpoint mutates it
// in-place); reads are guarded by a mutex so concurrent requests stay
// race-free.
type Redactor struct {
mu sync.RWMutex
patterns []Pattern
maxLen int
}
// Redactor is a stateless handle for the PII subsystem. The regex tier
// was removed: detection is driven entirely by per-model NER detectors
// (see RedactNER), whose policy lives on each detector model's
// pii_detection config. The type is retained (zero-field) as the
// on/off sentinel the application wiring and middleware gate on, so a
// nil *Redactor still means "PII subsystem unavailable".
type Redactor struct{}
// NewRedactor constructs a redactor from a list of compiled patterns
// (use Compile() to compile config-loaded patterns first). nil
// patterns is valid and produces a no-op redactor — convenient for the
// "PII disabled" deployment.
func NewRedactor(patterns []Pattern) *Redactor {
return &Redactor{
patterns: patterns,
maxLen: MaxPatternLength(patterns),
}
}
// MaxPatternLength is exposed so the streaming wrapper can size its
// tail buffer to match.
func (r *Redactor) MaxPatternLength() int { return r.maxLen }
// Patterns returns a copy of the configured pattern set so callers can
// iterate without holding the redactor lock. The compiled regexes are
// shared — they are immutable once built.
func (r *Redactor) Patterns() []Pattern {
r.mu.RLock()
defer r.mu.RUnlock()
return slices.Clone(r.patterns)
}
// SetAction overrides the action for a single pattern. Used by the
// /api/pii/patterns/:id admin endpoint and the set_pii_pattern_action
// MCP tool — transient until process restart unless persisted via
// --pii-config.
// RedactNER runs every configured NER detector over text, unions their
// detections, and emits one redacted output. Each NERConfig carries its
// own detector and policy (min score, entity→action map, default
// action), so a consuming model that references several detector models
// gets each model's policy applied to its own hits before the overlap
// merge (block > mask > allow) resolves any span two detectors both
// claim.
//
// Publishes a new slice so concurrent Redact callers iterating an
// older snapshot don't race on the per-element Action string (Go
// strings are not atomic two-word values).
func (r *Redactor) SetAction(id string, action Action) error {
if action != ActionMask && action != ActionBlock && action != ActionAllow {
return fmt.Errorf("unknown action %q (must be mask, block, or allow)", action)
}
r.mu.Lock()
defer r.mu.Unlock()
for i := range r.patterns {
if r.patterns[i].ID == id {
next := slices.Clone(r.patterns)
next[i].Action = action
r.patterns = next
return nil
}
}
return fmt.Errorf("unknown pattern id %q", id)
}
// SetDisabled toggles a pattern's enabled state in the live redactor.
// Same COW publish as SetAction.
func (r *Redactor) SetDisabled(id string, disabled bool) error {
r.mu.Lock()
defer r.mu.Unlock()
for i := range r.patterns {
if r.patterns[i].ID == id {
next := slices.Clone(r.patterns)
next[i].Disabled = disabled
r.patterns = next
return nil
}
}
return fmt.Errorf("unknown pattern id %q", id)
}
// Redact is a thin wrapper for callers that don't need per-request
// action overrides. It applies each pattern's compiled-in default
// action.
func (r *Redactor) Redact(text string) Result {
return r.RedactWithOverrides(text, nil)
}
// RedactWithOverrides scans text and returns the result. The override
// map is keyed by pattern id; when present, the value replaces the
// pattern's compiled-in action for this call only — the redactor's
// stored action is unchanged. Pattern ids missing from the map use
// their stored action.
// Any detector error is returned alongside a best-effort Result built
// from the detectors that did succeed, so the caller can fail closed
// (refuse the request) while still seeing what the healthy detectors
// found. Configs with a nil Detector are skipped.
//
// For every match it records a Span (with HashPrefix, never the value)
// and applies the resolved Action:
// - block: sets Result.Blocked, leaves text intact (caller decides
// whether to surface the redacted form).
// - mask: replaces the span with maskFor(pattern.ID), sets Result.Masked.
// - allow: leaves text intact and sets no flag (the span is still
// recorded so the match is auditable).
//
// Spans are returned in the original input's coordinate system so the
// PIIEvent record can be written without re-running the scan.
func (r *Redactor) RedactWithOverrides(text string, overrides map[string]Action) Result {
return r.redact(context.Background(), text, overrides, NERConfig{})
}
// RedactWithNER is the encoder-tier variant: runs both the regex tier
// (with per-pattern overrides) and the NER tier, merges hits, and
// emits one redacted output. A nil NERConfig.Detector skips the NER
// pass — callers can hand the same path the same NERConfig{} whether
// or not the model has NER configured.
//
// Errors from the NER detector are returned alongside a best-effort
// regex-only Result so the caller can decide whether to fail open
// (return the regex Result, log the error) or fail closed (refuse the
// request). The regex tier never errors.
func (r *Redactor) RedactWithNER(ctx context.Context, text string, overrides map[string]Action, nerCfg NERConfig) (Result, error) {
if nerCfg.Detector == nil {
return r.redact(ctx, text, overrides, nerCfg), nil
}
hits, err := r.collectRegexHits(text, overrides)
if err != nil {
return Result{Redacted: text}, err
}
nerHits, nerErr := collectNERHits(ctx, text, nerCfg)
if nerErr != nil {
// Return the regex-only result so a NER-backend outage doesn't
// strip the cheap protection. Caller decides fail-open vs
// fail-closed via the returned error.
return mergeAndEmit(text, hits), nerErr
}
return mergeAndEmit(text, append(hits, nerHits...)), nil
}
// redact is the internal regex-only entry point. RedactWithOverrides
// is the public wrapper; RedactWithNER routes through here only when
// the NER detector is nil (so the call site doesn't need a separate
// "regex-only" code path).
func (r *Redactor) redact(_ context.Context, text string, overrides map[string]Action, _ NERConfig) Result {
hits, _ := r.collectRegexHits(text, overrides)
return mergeAndEmit(text, hits)
}
// collectRegexHits walks the configured pattern set against text and
// returns each verified match as a rawHit. The redactor lock is held
// only long enough to snapshot the pattern slice — regex evaluation
// runs lock-free against the snapshot, so SetAction/SetDisabled don't
// stall a long-running Redact.
func (r *Redactor) collectRegexHits(text string, overrides map[string]Action) ([]rawHit, error) {
r.mu.RLock()
patterns := r.patterns
r.mu.RUnlock()
if len(patterns) == 0 || text == "" {
return nil, nil
// Package-level (no Redactor state): both the in-band request middleware
// and the MITM request path call it with their own resolved []NERConfig.
func RedactNER(ctx context.Context, text string, cfgs []NERConfig) (Result, error) {
if text == "" || len(cfgs) == 0 {
return Result{Redacted: text}, nil
}
var hits []rawHit
for _, p := range patterns {
if p.regex == nil {
// Pattern declared but Compile() not called. Skip rather
// than panic; the caller already saw an error from Compile.
var firstErr error
for _, cfg := range cfgs {
if cfg.Detector == nil {
continue
}
if p.Disabled {
h, err := collectNERHits(ctx, text, cfg)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
action := p.Action
if override, ok := overrides[p.ID]; ok {
action = override
hits = append(hits, h...)
}
return mergeAndEmit(text, hits), firstErr
}
// segmentSeparator joins per-message texts into the single document
// RedactNERSegments scans. Two newlines read as a paragraph break to the
// NER encoder — neutral, in-distribution context — and never carry PII
// themselves, so a detected span landing on the separator can only be the
// fringe of an entity that started in a real segment.
const segmentSeparator = "\n\n"
// RedactNERSegments scans texts as ONE concatenated document and maps the
// detections back to one Result per input text. Scanning the segments
// together is what gives the NER tier conversational context: whether
// "jdoe_42" is a USERNAME or "4421" is a PIN is decided by the question
// asked in the *previous* message, and a bidirectional encoder only sees
// that context if both messages are in the same forward pass. (Measured on
// privacy-filter-multilingual: "4421" alone detects nothing; preceded by
// "What are the last four digits of your card?" it detects PIN at 0.726.)
//
// Span offsets in each Result are local to its text, so callers rewrite
// fields in place exactly as with per-text RedactNER. A hit that crosses a
// segment boundary is split and each fragment keeps the hit's action —
// conservative, and only possible for an entity the model stretched across
// the separator. Error semantics mirror RedactNER: best-effort results
// plus the first detector error, so callers can fail closed.
func RedactNERSegments(ctx context.Context, texts []string, cfgs []NERConfig) ([]Result, error) {
results := make([]Result, len(texts))
if len(texts) == 0 || len(cfgs) == 0 {
for i := range results {
results[i] = Result{Redacted: texts[i]}
}
idxs := p.regex.FindAllStringIndex(text, -1)
for _, idx := range idxs {
candidate := text[idx[0]:idx[1]]
if VerifyMatch(p.ID, candidate) == "" {
return results, nil
}
var joined strings.Builder
starts := make([]int, len(texts))
ends := make([]int, len(texts))
for i, t := range texts {
if i > 0 {
joined.WriteString(segmentSeparator)
}
starts[i] = joined.Len()
joined.WriteString(t)
ends[i] = joined.Len()
}
doc := joined.String()
var hits []rawHit
var firstErr error
for _, cfg := range cfgs {
if cfg.Detector == nil {
continue
}
h, err := collectNERHits(ctx, doc, cfg)
if err != nil {
if firstErr == nil {
firstErr = err
}
continue
}
hits = append(hits, h...)
}
perSegment := make([][]rawHit, len(texts))
for _, h := range hits {
for i := range texts {
s := max(h.start, starts[i])
e := min(h.end, ends[i])
if s >= e {
continue
}
hits = append(hits, rawHit{
patternID: p.ID,
action: action,
start: idx[0],
end: idx[1],
})
local := h
local.start = s - starts[i]
local.end = e - starts[i]
perSegment[i] = append(perSegment[i], local)
}
}
return hits, nil
for i := range texts {
results[i] = mergeAndEmit(texts[i], perSegment[i])
}
return results, firstErr
}
// collectNERHits invokes the configured NERDetector and converts each
// returned entity into a rawHit using the NERConfig's action map.
// Entities below MinScore or with no resolved action are dropped — the
// detector doesn't know which entity groups the admin cares about, so
// the redactor filters here.
// the policy filters here.
func collectNERHits(ctx context.Context, text string, cfg NERConfig) ([]rawHit, error) {
if cfg.Detector == nil || text == "" {
return nil, nil
@@ -220,42 +160,58 @@ func collectNERHits(ctx context.Context, text string, cfg NERConfig) ([]rawHit,
}
var hits []rawHit
for _, e := range entities {
// One DEBUG line per raw detection with the model's confidence, the
// byte range, the matched substring, and the policy decision. This is
// the lowest-level view of why a request was masked/blocked — e.g. a
// phone number scored as SSN — and answers "what was in that range and
// how sure was the model" without re-running the detector. DEBUG-gated
// because the matched value is sensitive.
if e.Score < cfg.MinScore {
xlog.Debug("pii/ner: detection dropped (below min score)",
"group", e.Group, "score", e.Score, "min_score", cfg.MinScore,
"start", e.Start, "end", e.End, "text", e.Text)
continue
}
action, ok := cfg.ResolveAction(e.Group)
if !ok {
xlog.Debug("pii/ner: detection ignored (no action for group)",
"group", e.Group, "score", e.Score,
"start", e.Start, "end", e.End, "text", e.Text)
continue
}
if e.Start < 0 || e.End <= e.Start || e.End > len(text) {
// Defensive: the backend should return byte offsets into
// the original text, but a misconfigured model could
// produce garbage. Skip rather than panic on slice OOB.
// Defensive: the backend should return byte offsets into the
// original text, but a misconfigured model could produce
// garbage. Skip rather than panic on slice OOB.
xlog.Warn("pii/ner: detection has out-of-range offsets; skipping",
"group", e.Group, "start", e.Start, "end", e.End, "text_len", len(text))
continue
}
xlog.Debug("pii/ner: detection accepted",
"group", e.Group, "score", e.Score, "action", action,
"start", e.Start, "end", e.End, "text", e.Text)
hits = append(hits, rawHit{
patternID: nerPatternID(e.Group),
patternID: cfg.patternID(e.Group),
action: action,
start: e.Start,
end: e.End,
score: e.Score,
})
}
return hits, nil
}
// mergeAndEmit handles the overlap-merge + masked-output step that
// regex-only and combined regex+NER redactions both perform. Sorts by
// mergeAndEmit handles the overlap-merge + masked-output step. Sorts by
// start (stable on equal starts by descending action strength), drops
// overlapping hits in favour of the stronger action, and walks the
// text once to emit replacement spans.
// overlapping hits in favour of the stronger action, and walks the text
// once to emit replacement spans.
func mergeAndEmit(text string, hits []rawHit) Result {
if len(hits) == 0 {
return Result{Redacted: text}
}
// Sort and deduplicate overlapping hits — when two patterns claim
// the same span (e.g., a credit-card-shaped value also scans as
// digits, or NER tags a span the regex also caught), keep the one
// with the strongest action. Order: block > mask > allow.
// Sort and deduplicate overlapping hits — when two detectors claim
// the same span, keep the one with the strongest action. Order:
// block > mask > allow.
sort.Slice(hits, func(i, j int) bool {
if hits[i].start != hits[j].start {
return hits[i].start < hits[j].start
@@ -270,6 +226,7 @@ func mergeAndEmit(text string, hits []rawHit) Result {
if actionRank(h.action) > actionRank(last.action) {
last.action = h.action
last.patternID = h.patternID
last.score = h.score
}
if h.end > last.end {
last.end = h.end
@@ -291,6 +248,8 @@ func mergeAndEmit(text string, hits []rawHit) Result {
End: h.end,
Pattern: h.patternID,
HashPrefix: hashPrefix(matched),
Action: h.action,
Score: h.score,
}
res.Spans = append(res.Spans, span)
@@ -315,17 +274,15 @@ func mergeAndEmit(text string, hits []rawHit) Result {
// maskFor returns the placeholder that replaces a matched span. The
// shape "[REDACTED:<id>]" is intentionally stable — it surfaces the
// pattern id back to the model, which is sometimes useful (e.g., the
// model can say "I see you redacted an email"). Admins who want a
// less informative replacement can build one in front of this.
// detector group back to the model (e.g. "I see you redacted an email").
func maskFor(patternID string) string {
return "[REDACTED:" + patternID + "]"
}
// hashPrefix returns the first 8 chars of sha256(value). Two calls
// with the same input produce the same prefix so an admin auditing
// the PIIEvent log can spot a recurring leak ("the same SSN appears
// 200 times this hour") without ever recovering the value.
// hashPrefix returns the first 8 chars of sha256(value). Two calls with
// the same input produce the same prefix so an admin auditing the
// PIIEvent log can spot a recurring leak without ever recovering the
// value.
func hashPrefix(value string) string {
sum := sha256.Sum256([]byte(value))
return hex.EncodeToString(sum[:])[:8]

View File

@@ -1,66 +0,0 @@
package pii
import (
"sync"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// Redactor_SetActionConcurrentRedact pins the SetAction copy-on-
// write contract: concurrent SetAction must not race with readers
// iterating an older patterns snapshot. Run with -race to surface the
// regression that motivated the COW (in-place mutation of the
// per-element Action string is not atomic).
var _ = Describe("Redactor", func() {
It("SetAction concurrent with Redact", func() {
patterns, err := Compile(DefaultPatterns())
Expect(err).NotTo(HaveOccurred(), "compile")
r := NewRedactor(patterns)
const writers = 4
const readers = 8
const iter = 100
var wg sync.WaitGroup
stop := make(chan struct{})
for w := 0; w < writers; w++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < iter; i++ {
select {
case <-stop:
return
default:
}
action := ActionMask
if i%2 == 0 {
action = ActionBlock
}
_ = r.SetAction("email", action)
}
}()
}
for rd := 0; rd < readers; rd++ {
wg.Add(1)
go func() {
defer wg.Done()
text := "contact alice@example.com please"
for i := 0; i < iter*2; i++ {
select {
case <-stop:
return
default:
}
_ = r.Redact(text)
}
}()
}
wg.Wait()
close(stop)
})
})

View File

@@ -1,186 +1,182 @@
package pii
import (
"context"
"errors"
"strings"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func mustCompile(ids ...string) []Pattern {
all := DefaultPatterns()
if len(ids) == 0 {
out, err := Compile(all)
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
return out
}
pickP := pick(all, ids)
out, err := Compile(pickP)
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
return out
// detect builds a single-detector []NERConfig that reports one entity
// over the whole input under the given group/action.
func oneShot(group string, action Action, start, end int) []NERConfig {
return []NERConfig{{
Detector: &stubNERDetector{entities: []NEREntity{{Group: group, Start: start, End: end, Score: 1}}},
EntityActions: map[string]Action{group: action},
}}
}
func pick(all []Pattern, ids []string) []Pattern {
keep := map[string]bool{}
for _, id := range ids {
keep[id] = true
}
var out []Pattern
for _, p := range all {
if keep[p.ID] {
out = append(out, p)
}
}
return out
}
var _ = Describe("RedactNER emission", func() {
ctx := context.Background()
var _ = Describe("Redactor", func() {
It("masks email", func() {
r := NewRedactor(mustCompile("email"))
res := r.Redact("Contact me at alice@example.com any time.")
Expect(res.Blocked).To(BeFalse(), "email is mask-action by default, should not block")
Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"))
It("masks with a [REDACTED:ner:GROUP] placeholder and records a hash prefix", func() {
res, err := RedactNER(ctx, "Contact me at alice@example.com any time.", oneShot("EMAIL", ActionMask, 14, 31))
Expect(err).NotTo(HaveOccurred())
Expect(res.Masked).To(BeTrue())
Expect(res.Blocked).To(BeFalse())
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:EMAIL]"))
Expect(res.Redacted).NotTo(ContainSubstring("alice@example.com"))
Expect(res.Spans).To(HaveLen(1))
Expect(res.Spans[0].HashPrefix).NotTo(BeEmpty(), "hash prefix must be set so audits can dedupe leaks")
})
It("masks SSN", func() {
r := NewRedactor(mustCompile("ssn"))
res := r.Redact("call me about SSN 123-45-6789 please")
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ssn]"))
It("labels pattern-detector hits with the pattern source, not ner", func() {
cfgs := []NERConfig{{
Detector: &stubNERDetector{entities: []NEREntity{{Group: "ANTHROPIC_KEY", Start: 4, End: 24, Score: 1}}},
EntityActions: map[string]Action{"ANTHROPIC_KEY": ActionMask},
Source: SourcePattern,
}}
res, err := RedactNER(ctx, "use sk-ant-aaaaaaaaaaaaaaaa now", cfgs)
Expect(err).NotTo(HaveOccurred())
Expect(res.Redacted).To(ContainSubstring("[REDACTED:pattern:ANTHROPIC_KEY]"))
Expect(res.Redacted).NotTo(ContainSubstring("[REDACTED:ner:"))
Expect(res.Spans).To(HaveLen(1))
Expect(res.Spans[0].Pattern).To(Equal("pattern:ANTHROPIC_KEY"))
})
It("uses Luhn for credit card", func() {
r := NewRedactor(mustCompile("credit_card"))
// 4111 1111 1111 1111 — canonical Luhn-valid Visa test number.
good := r.Redact("card: 4111 1111 1111 1111")
Expect(good.Spans).To(HaveLen(1))
Expect(good.Redacted).To(ContainSubstring("[REDACTED:credit_card]"))
// 4111 1111 1111 1112 — same shape, fails Luhn. Must NOT match.
bad := r.Redact("card: 4111 1111 1111 1112")
Expect(bad.Spans).To(BeEmpty(), "Luhn-invalid 16-digit run must not be redacted")
Expect(bad.Redacted).To(ContainSubstring("1112"), "Luhn-invalid input should pass through untouched")
It("block leaves the matched span intact and sets Blocked", func() {
res, err := RedactNER(ctx, "token sk-abcdef here", oneShot("PASSWORD", ActionBlock, 6, 15))
Expect(err).NotTo(HaveOccurred())
Expect(res.Blocked).To(BeTrue())
Expect(res.Redacted).To(ContainSubstring("sk-abcdef"), "block leaves the value intact for the caller to discard")
Expect(res.Spans[0].Action).To(Equal(ActionBlock))
})
It("validates IPv4 octets", func() {
r := NewRedactor(mustCompile("ipv4"))
good := r.Redact("server at 192.168.1.10 is up")
Expect(good.Spans).To(HaveLen(1))
// 999.999.999.999 — regex matches but octet > 255 must reject.
bad := r.Redact("not an ip: 999.999.999.999")
Expect(bad.Spans).To(BeEmpty(), "ipv4 with octet>255 must not match")
})
It("api_key defaults to block", func() {
r := NewRedactor(mustCompile("api_key_prefix"))
res := r.Redact("here's a token sk-abcdefghijklmnopqrstuvwxyz0123456789 to use")
Expect(res.Blocked).To(BeTrue(), "api_key default action is block; Result.Blocked must be true")
// The redacted output keeps the matched value when blocking — the
// caller is expected to refuse the request, not to forward a partial.
Expect(res.Redacted).To(ContainSubstring("sk-abcdefghijklmn"), "blocked actions leave the matched span intact for caller inspection")
})
It("preserves non-matching text", func() {
r := NewRedactor(mustCompile()) // all default patterns
in := "no PII here at all, just words and numbers like 42 and 1.5"
res := r.Redact(in)
Expect(res.Redacted).To(Equal(in), "non-PII input should pass through unchanged")
Expect(res.Spans).To(BeEmpty())
})
It("handles empty input", func() {
r := NewRedactor(mustCompile())
res := r.Redact("")
Expect(res.Redacted).To(BeEmpty())
Expect(res.Blocked).To(BeFalse())
It("allow leaves text intact but still records the span", func() {
res, err := RedactNER(ctx, "Hello Acme!", oneShot("ORG", ActionAllow, 6, 10))
Expect(err).NotTo(HaveOccurred())
Expect(res.Masked).To(BeFalse())
Expect(res.Blocked).To(BeFalse())
Expect(res.Redacted).To(Equal("Hello Acme!"))
Expect(res.Spans).To(HaveLen(1))
})
It("passes non-matching text through unchanged", func() {
det := &stubNERDetector{} // no entities
res, err := RedactNER(ctx, "no PII here, just words", []NERConfig{{Detector: det, DefaultAction: ActionMask}})
Expect(err).NotTo(HaveOccurred())
Expect(res.Redacted).To(Equal("no PII here, just words"))
Expect(res.Spans).To(BeEmpty())
})
It("nil patterns is a no-op", func() {
// Disabled-PII deployment: pii.NewRedactor(nil) is a no-op.
r := NewRedactor(nil)
res := r.Redact("alice@example.com sent it")
Expect(res.Redacted).To(Equal("alice@example.com sent it"))
It("handles empty input without calling the detector", func() {
det := &stubNERDetector{entities: []NEREntity{{Group: "X", Start: 0, End: 1, Score: 1}}}
res, err := RedactNER(ctx, "", []NERConfig{{Detector: det, DefaultAction: ActionMask}})
Expect(err).NotTo(HaveOccurred())
Expect(res.Redacted).To(BeEmpty())
Expect(res.Spans).To(BeEmpty())
Expect(det.calls).To(Equal(0))
})
It("hash prefix is stable", func() {
r := NewRedactor(mustCompile("email"))
a := r.Redact("a@b.com")
b := r.Redact("hi a@b.com again")
It("produces a stable hash prefix for the same matched value", func() {
a, _ := RedactNER(ctx, "a@b.com", oneShot("EMAIL", ActionMask, 0, 7))
b, _ := RedactNER(ctx, "hi a@b.com", oneShot("EMAIL", ActionMask, 3, 10))
Expect(a.Spans).To(HaveLen(1))
Expect(b.Spans).To(HaveLen(1))
Expect(a.Spans[0].HashPrefix).To(Equal(b.Spans[0].HashPrefix), "same matched value must produce same hash prefix")
})
})
var _ = Describe("Compile", func() {
It("rejects unknown pattern id", func() {
_, err := Compile([]Pattern{{ID: "nonexistent", Action: ActionMask}})
Expect(err).To(HaveOccurred(), "Compile must error on unknown pattern id")
// funcNERDetector computes entities from the text it is handed — used to
// prove the segment scan gives the detector the JOINED document, the way a
// context-sensitive encoder behaves.
type funcNERDetector struct {
fn func(text string) ([]NEREntity, error)
}
func (f *funcNERDetector) Detect(_ context.Context, text string) ([]NEREntity, error) {
return f.fn(text)
}
// pinAfterCard mimics the real encoder's context sensitivity: "4421" is a
// PIN only when "card" appears earlier in the same document (measured on
// privacy-filter-multilingual: alone it detects nothing, with the eliciting
// question it detects PIN).
func pinAfterCard(text string) ([]NEREntity, error) {
i := strings.Index(text, "4421")
if i < 0 || !strings.Contains(text[:i], "card") {
return nil, nil
}
return []NEREntity{{Group: "PIN", Start: i, End: i + 4, Score: 0.9}}, nil
}
var _ = Describe("RedactNERSegments", func() {
ctx := context.Background()
maskCfg := func(d NERDetector) []NERConfig {
return []NERConfig{{Detector: d, DefaultAction: ActionMask}}
}
It("scans segments as one document so context crosses messages", func() {
det := &funcNERDetector{fn: pinAfterCard}
// Scanned alone the digits are invisible...
alone, err := RedactNER(ctx, "it is 4421 ok", maskCfg(det))
Expect(err).NotTo(HaveOccurred())
Expect(alone.Spans).To(BeEmpty())
// ...as a segment after the eliciting question they are detected,
// and the span maps back to the second segment with local offsets.
res, err := RedactNERSegments(ctx,
[]string{"What are the last four digits of your card?", "it is 4421 ok"},
maskCfg(det))
Expect(err).NotTo(HaveOccurred())
Expect(res).To(HaveLen(2))
Expect(res[0].Spans).To(BeEmpty())
Expect(res[0].Redacted).To(Equal("What are the last four digits of your card?"))
Expect(res[1].Spans).To(HaveLen(1))
Expect(res[1].Spans[0].Start).To(Equal(6))
Expect(res[1].Spans[0].End).To(Equal(10))
Expect(res[1].Masked).To(BeTrue())
Expect(res[1].Redacted).To(Equal("it is [REDACTED:ner:PIN] ok"))
})
It("splits a hit crossing a segment boundary, masking both fragments", func() {
det := &funcNERDetector{fn: func(text string) ([]NEREntity, error) {
i := strings.Index(text, "22 Baker")
j := strings.Index(text, "Street")
if i < 0 || j < 0 {
return nil, nil
}
return []NEREntity{{Group: "STREET", Start: i, End: j + len("Street"), Score: 0.9}}, nil
}}
res, err := RedactNERSegments(ctx, []string{"22 Baker", "Street"}, maskCfg(det))
Expect(err).NotTo(HaveOccurred())
Expect(res[0].Redacted).To(Equal("[REDACTED:ner:STREET]"))
Expect(res[1].Redacted).To(Equal("[REDACTED:ner:STREET]"))
})
It("returns best-effort results with the first detector error", func() {
bad := NERConfig{Detector: &stubNERDetector{err: errors.New("backend down")}, DefaultAction: ActionMask}
good := NERConfig{
Detector: &stubNERDetector{entities: []NEREntity{{Group: "PER", Start: 0, End: 5, Score: 0.9}}},
DefaultAction: ActionMask,
}
res, err := RedactNERSegments(ctx, []string{"Alice", "rest"}, []NERConfig{bad, good})
Expect(err).To(HaveOccurred())
Expect(res[0].Spans).To(HaveLen(1), "healthy detector's hits still apply")
})
It("is a per-text no-op without detectors or texts", func() {
res, err := RedactNERSegments(ctx, []string{"a", ""}, nil)
Expect(err).NotTo(HaveOccurred())
Expect(res).To(HaveLen(2))
Expect(res[0].Redacted).To(Equal("a"))
Expect(res[1].Redacted).To(Equal(""))
res, err = RedactNERSegments(ctx, nil, maskCfg(&stubNERDetector{}))
Expect(err).NotTo(HaveOccurred())
Expect(res).To(BeEmpty())
})
})
var _ = Describe("MaxPatternLength", func() {
It("returns the longest pattern's max length", func() {
patterns := mustCompile("email", "ssn")
got := MaxPatternLength(patterns)
// email is the longer of the two (254). The streaming filter
// will use this to size its tail buffer.
Expect(got).To(Equal(254))
})
})
var _ = Describe("RedactWithOverrides", func() {
It("upgrades action", func() {
// email is mask by default; the per-model override turns it into a
// hard block for one request without mutating the redactor.
r := NewRedactor(mustCompile("email"))
res := r.RedactWithOverrides("contact alice@example.com",
map[string]Action{"email": ActionBlock})
Expect(res.Blocked).To(BeTrue(), "override should have set Blocked")
// Block leaves the value intact (the caller short-circuits the
// request) — the redactor never echoes the matched text.
Expect(res.Redacted).To(ContainSubstring("alice@example.com"), "block leaves text intact for the caller to discard")
// Stored action is unchanged so a subsequent default Redact still
// masks rather than blocks.
res2 := r.Redact("contact alice@example.com")
Expect(res2.Blocked).To(BeFalse(), "override must not mutate stored action")
})
It("ignores unknown IDs", func() {
// An override for a pattern this redactor doesn't know about is a
// no-op rather than an error — per-model configs may reference
// patterns from a wider catalogue than the active redactor holds.
r := NewRedactor(mustCompile("email"))
res := r.RedactWithOverrides("contact alice@example.com",
map[string]Action{"ssn": ActionBlock})
Expect(res.Blocked).To(BeFalse(), "ssn override against email-only redactor must be no-op")
})
})
var _ = Describe("SetAction", func() {
It("swaps in place", func() {
r := NewRedactor(mustCompile("email"))
Expect(r.SetAction("email", ActionAllow)).To(Succeed())
res := r.Redact("contact alice@example.com")
Expect(res.Masked).To(BeFalse(), "allow leaves text intact, so nothing is masked")
Expect(res.Redacted).To(ContainSubstring("alice@example.com"), "allow should leave the match in place")
Expect(res.Spans).To(HaveLen(1), "allow still records the match")
Expect(res.Blocked).To(BeFalse(), "SetAction(allow) should not block")
})
It("rejects unknown id", func() {
r := NewRedactor(mustCompile("email"))
Expect(r.SetAction("nonexistent", ActionMask)).NotTo(Succeed(), "expected error for unknown pattern id")
})
It("rejects unknown action", func() {
r := NewRedactor(mustCompile("email"))
Expect(r.SetAction("email", Action("frobnicate"))).NotTo(Succeed(), "expected error for unknown action")
})
})

View File

@@ -27,7 +27,10 @@ type ListQuery struct {
UserID string
PatternID string
Kind EventKind
Limit int
// Origin scopes the search to redaction events from one surface
// (middleware | proxy | pii_analyze | pii_redact); empty matches any.
Origin Origin
Limit int
}
// NewMemoryEventStore returns an in-memory ring-buffer event store.
@@ -91,6 +94,9 @@ func (s *memoryEventStore) List(_ context.Context, q ListQuery) ([]PIIEvent, err
if q.Kind != "" && e.ResolvedKind() != q.Kind {
return false
}
if q.Origin != "" && e.Origin != q.Origin {
return false
}
out = append(out, e)
return len(out) >= limit
}

View File

@@ -0,0 +1,48 @@
package pii
import (
"context"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("EventStore Origin filter", func() {
var store EventStore
ctx := context.Background()
BeforeEach(func() {
store = NewMemoryEventStore(0)
// Three redaction events from three different surfaces.
for _, o := range []Origin{OriginMiddleware, OriginRedactAPI, OriginAnalyzeAPI} {
Expect(store.Record(ctx, PIIEvent{
ID: NewEventID(),
Kind: KindPII,
Origin: o,
PatternID: "ner:EMAIL",
})).To(Succeed())
}
// An older row with no Origin (pre-field) must not match any origin filter.
Expect(store.Record(ctx, PIIEvent{ID: NewEventID(), Kind: KindPII, PatternID: "ner:EMAIL"})).To(Succeed())
})
It("returns only events from the requested origin", func() {
got, err := store.List(ctx, ListQuery{Origin: OriginRedactAPI})
Expect(err).ToNot(HaveOccurred())
Expect(got).To(HaveLen(1))
Expect(got[0].Origin).To(Equal(OriginRedactAPI))
})
It("an empty origin matches every event (including pre-field rows)", func() {
got, err := store.List(ctx, ListQuery{})
Expect(err).ToNot(HaveOccurred())
Expect(got).To(HaveLen(4))
})
It("does not match a pre-field (empty-origin) row against a concrete origin", func() {
got, err := store.List(ctx, ListQuery{Origin: OriginMiddleware})
Expect(err).ToNot(HaveOccurred())
Expect(got).To(HaveLen(1))
Expect(got[0].Origin).To(Equal(OriginMiddleware))
})
})

View File

@@ -1,198 +0,0 @@
package pii
import (
"context"
"crypto/rand"
"encoding/hex"
"strings"
"time"
"unicode/utf8"
)
// StreamFilter applies the regex PII tier to a streaming response,
// chunk by chunk, with a buffered-emit invariant: for any active
// pattern with bounded max-length L, the filter never emits the
// trailing L-1 characters of the cumulative input until either
//
// (a) more text arrives that disambiguates the boundary, or
// (b) the stream closes (Drain).
//
// That keeps the redactor honest across chunk splits — an email
// arriving as "alice@" + "example.com" still masks the same way as
// "alice@example.com" arriving in one piece.
//
// Action handling in stream mode differs from the request-side
// middleware. Earlier chunks of the response are already on the wire
// by the time later chunks are scanned, so a "block" can't actually
// reject the request. We remap block → mask for redaction purposes
// while still recording PIIEvent rows with action="block" so audits
// surface the original intent ("the model would have leaked X here,
// suppressed in flight"). allow on the output side is a no-op — the
// text is left intact, matching its request-side detect-and-log
// behaviour.
//
// StreamFilter is NOT safe for concurrent use across goroutines; one
// instance per response stream.
type StreamFilter struct {
redactor *Redactor
maskOverrides map[string]Action // block → mask map used for redaction
auditActions map[string]Action // original action per pattern, used for events
store EventStore
correlationID string
userID string
holdLen int
buffer strings.Builder
emittedBytes int
}
// NewStreamFilter constructs a per-response filter. modelOverrides is
// the per-model action override map (same shape the request-side
// middleware uses); it can be nil when the model only accepts global
// defaults.
//
// store may be nil — events are then computed but not persisted, which
// is what the chat handler does when --disable-stats is set.
func NewStreamFilter(redactor *Redactor, modelOverrides map[string]Action, store EventStore, correlationID, userID string) *StreamFilter {
if redactor == nil {
return &StreamFilter{}
}
patterns := redactor.Patterns()
// auditActions: the action we *would* have applied if this match
// occurred on the request side. Honours the per-model override.
auditActions := make(map[string]Action, len(patterns))
for _, p := range patterns {
auditActions[p.ID] = p.Action
}
for id, action := range modelOverrides {
auditActions[id] = action
}
// maskOverrides: the action we actually apply to the stream. Same
// as auditActions, but with every block remapped to mask.
maskOverrides := make(map[string]Action, len(auditActions))
for id, action := range auditActions {
if action == ActionBlock {
maskOverrides[id] = ActionMask
} else {
maskOverrides[id] = action
}
}
return &StreamFilter{
redactor: redactor,
maskOverrides: maskOverrides,
auditActions: auditActions,
store: store,
correlationID: correlationID,
userID: userID,
holdLen: redactor.MaxPatternLength() - 1,
}
}
// Push appends new text to the filter's buffer and returns the prefix
// safe to emit downstream — the cumulative input minus a tail of
// holdLen characters that might still be the start of a longer match.
// Returned text has masks already applied.
//
// Returns an empty string when not enough text has arrived to clear
// the hold window.
func (sf *StreamFilter) Push(text string) string {
if sf.redactor == nil || sf.holdLen <= 0 {
return text
}
sf.buffer.WriteString(text)
bufStr := sf.buffer.String()
n := len(bufStr)
if n <= sf.holdLen {
return ""
}
emitBoundary := n - sf.holdLen
// Scan the entire buffer. A match whose start is before the
// boundary but whose end runs past it crosses the window — pull
// the boundary back to match.start so the pattern stays whole in
// the buffer for the next Push to scan again.
full := sf.redactor.RedactWithOverrides(bufStr, sf.maskOverrides)
for _, span := range full.Spans {
if span.Start < emitBoundary && span.End > emitBoundary {
emitBoundary = span.Start
}
}
// holdLen is byte-sized but a chunk boundary may land mid-codepoint.
// Snap back to the nearest rune start so neither the emitted prefix
// nor the retained tail contains a split codepoint — otherwise the
// next regex scan over an invalid-UTF-8 prefix could mis-match.
for emitBoundary > 0 && emitBoundary < n && !utf8.RuneStart(bufStr[emitBoundary]) {
emitBoundary--
}
if emitBoundary <= 0 {
return ""
}
emitted := sf.applyAndEmit(bufStr[:emitBoundary])
sf.buffer.Reset()
sf.buffer.WriteString(bufStr[emitBoundary:])
return emitted
}
// Drain emits whatever's left in the buffer with all matches applied.
// Call exactly once when the stream closes — repeat calls return the
// empty string.
func (sf *StreamFilter) Drain() string {
if sf.redactor == nil {
return sf.buffer.String()
}
bufStr := sf.buffer.String()
if bufStr == "" {
return ""
}
emitted := sf.applyAndEmit(bufStr)
sf.buffer.Reset()
return emitted
}
// applyAndEmit runs the redactor over a committed-for-emit fragment,
// substitutes mask/block placeholders inline, and records one
// PIIEvent per matched span (with the audit action, not the masked
// one). ByteOffset is referenced to the cumulative emitted output so
// admins can correlate event positions against the streamed body.
func (sf *StreamFilter) applyAndEmit(fragment string) string {
res := sf.redactor.RedactWithOverrides(fragment, sf.maskOverrides)
output := res.Redacted
if len(res.Spans) > 0 {
now := time.Now().UTC()
for _, span := range res.Spans {
ev := PIIEvent{
ID: newStreamEventID(),
CorrelationID: sf.correlationID,
UserID: sf.userID,
Direction: DirectionOut,
PatternID: span.Pattern,
ByteOffset: sf.emittedBytes + span.Start,
Length: span.End - span.Start,
HashPrefix: span.HashPrefix,
Action: sf.auditActions[span.Pattern],
CreatedAt: now,
}
if sf.store != nil {
_ = sf.store.Record(context.Background(), ev)
}
}
}
sf.emittedBytes += len(fragment)
return output
}
func newStreamEventID() string {
var b [12]byte
_, _ = rand.Read(b[:])
return "pii_" + hex.EncodeToString(b[:])
}

View File

@@ -1,184 +0,0 @@
package pii
import (
"context"
"fmt"
"math/rand"
"strings"
"unicode/utf8"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func newStreamRedactor(ids ...string) *Redactor {
all := DefaultPatterns()
chosen := all
if len(ids) > 0 {
chosen = pick(all, ids)
}
patterns, err := Compile(chosen)
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
return NewRedactor(patterns)
}
var _ = Describe("StreamFilter", func() {
It("masks across chunks", func() {
// The most important streaming test: an email split arbitrarily
// across chunk boundaries must mask exactly the same way as one
// arriving in a single Push.
red := newStreamRedactor("email")
sf := NewStreamFilter(red, nil, nil, "", "")
// "alice@example.com" (17 bytes) split between '@' and 'e'.
out := ""
out += sf.Push("hi alice@")
out += sf.Push("example.com! end")
out += sf.Drain()
Expect(out).NotTo(ContainSubstring("alice@example.com"), "stream leaked email across chunk boundary")
Expect(out).To(ContainSubstring("[REDACTED:email]"))
})
It("block becomes mask", func() {
// api_key_prefix is block by default. In stream mode the earlier
// chunks are already on the wire so block is impossible — the
// filter remaps to mask while still recording action="block" so
// the audit log keeps the original intent.
red := newStreamRedactor("api_key_prefix")
store := NewMemoryEventStore(0)
defer func() { _ = store.Close() }()
sf := NewStreamFilter(red, nil, store, "corr-1", "user-1")
out := sf.Push("here is your token: sk-abcdefghijklmnopqrstuvwxyz0123456789 done")
out += sf.Drain()
Expect(out).NotTo(ContainSubstring("abcdefghijklmnopqrstuvwxyz0123456789"), "block-in-stream must mask, leaked the value")
Expect(out).To(ContainSubstring("[REDACTED:api_key_prefix]"))
events, _ := store.List(context.Background(), ListQuery{Limit: 10})
Expect(events).To(HaveLen(1))
Expect(events[0].Action).To(Equal(ActionBlock), "audit must record original block action")
Expect(events[0].Direction).To(Equal(DirectionOut), "stream events must be DirectionOut")
})
It("no match passthrough", func() {
red := newStreamRedactor("email")
sf := NewStreamFilter(red, nil, nil, "", "")
out := sf.Push("perfectly clean text that should") + sf.Push(" pass through unchanged.") + sf.Drain()
Expect(out).To(Equal("perfectly clean text that should pass through unchanged."))
})
It("nil redactor passthrough", func() {
// --disable-pii path: NewStreamFilter(nil, ...) returns a filter
// that just forwards Push input verbatim.
sf := NewStreamFilter(nil, nil, nil, "", "")
out := sf.Push("any old text including alice@example.com") + sf.Drain()
Expect(out).To(Equal("any old text including alice@example.com"))
})
It("per-model overrides", func() {
// email defaults to mask; per-model override upgrades to block.
// In stream mode the override still maps to mask placeholder, but
// the audit event records action="block".
red := newStreamRedactor("email")
store := NewMemoryEventStore(0)
defer func() { _ = store.Close() }()
sf := NewStreamFilter(red, map[string]Action{"email": ActionBlock}, store, "corr-2", "user-2")
out := sf.Push("contact alice@example.com please") + sf.Drain()
Expect(out).NotTo(ContainSubstring("alice@example.com"), "override block-in-stream must mask")
events, _ := store.List(context.Background(), ListQuery{Limit: 10})
Expect(events).To(HaveLen(1))
Expect(events[0].Action).To(Equal(ActionBlock))
})
// StreamFilter_BufferedEmitInvariant feeds the redactor a corpus
// one rune at a time, randomly chunked, and asserts:
//
// 1. Across all (input, splitting) pairs, the cumulative emitted
// output never contains any of the secret values that were
// embedded in the input.
// 2. The output, fully drained, equals what Redact would have
// produced on the unsplit input.
//
// This is the load-bearing property of streaming PII: regardless of
// where chunks split, the emitted bytes cannot contain a value that a
// single-shot redactor would have masked.
It("buffered emit invariant", func() {
corpus := []struct {
text string
secrets []string
}{
{"contact alice@example.com or bob@example.org", []string{"alice@example.com", "bob@example.org"}},
{"my SSN is 123-45-6789 and his is 987-65-4321", []string{"123-45-6789", "987-65-4321"}},
{"sk-abcdefghijklmnopqrstuvwxyz0123456789 leaked", []string{"sk-abcdefghijklmnopqrstuvwxyz0123456789"}},
{"repeats: alice@example.com / alice@example.com / alice@example.com", []string{"alice@example.com"}},
// Multibyte UTF-8 corpora pin the rune-boundary snap in
// StreamFilter.Push: holdLen is byte-sized, so a chunk boundary
// may land mid-codepoint. Without the snap, the retained tail
// has a partial codepoint and the next regex scan can mis-align.
// Each entry mixes ASCII secrets with surrounding multibyte text
// so a byte-aligned cut would land inside a CJK or accented
// character on at least some splits.
{"こんにちは alice@example.com さようなら", []string{"alice@example.com"}},
{"クレジットカード: 4111-1111-1111-1111 終わり", []string{"4111-1111-1111-1111"}},
{"naïve résumé: alice@example.com, façade", []string{"alice@example.com"}},
}
red := newStreamRedactor() // all default patterns
rng := rand.New(rand.NewSource(1)) // seeded for reproducibility
for _, tc := range corpus {
for trial := 0; trial < 10; trial++ {
sf := NewStreamFilter(red, nil, nil, "", "")
var out strings.Builder
for i := 0; i < utf8.RuneCountInString(tc.text); {
// Random chunk size 1-8 runes, never crossing the end.
chunk := 1 + rng.Intn(8)
if i+chunk > utf8.RuneCountInString(tc.text) {
chunk = utf8.RuneCountInString(tc.text) - i
}
out.WriteString(sf.Push(stringSlice(tc.text, i, i+chunk)))
i += chunk
}
out.WriteString(sf.Drain())
result := out.String()
// Property 1: no secret value appears anywhere in the
// output.
for _, secret := range tc.secrets {
Expect(result).NotTo(ContainSubstring(secret),
fmt.Sprintf("trial %d: secret %q leaked through streaming\n input: %q\n output: %q", trial, secret, tc.text, result))
}
// Property 2: the streamed output equals what a single-shot
// Redact would have produced on the same input. (Block
// patterns get masked in stream mode, so we compare against
// a remapped redaction.)
expected := singleShotMaskAll(red, tc.text)
Expect(result).To(Equal(expected),
fmt.Sprintf("trial %d: stream != single-shot\n input: %q", trial, tc.text))
}
}
})
})
// singleShotMaskAll runs the redactor in one pass with all blocks
// remapped to mask — the same view the StreamFilter produces.
func singleShotMaskAll(red *Redactor, text string) string {
patterns := red.Patterns()
overrides := make(map[string]Action, len(patterns))
for _, p := range patterns {
if p.Action == ActionBlock {
overrides[p.ID] = ActionMask
}
}
res := red.RedactWithOverrides(text, overrides)
return res.Redacted
}
func stringSlice(s string, fromRune, toRune int) string {
runes := []rune(s)
return string(runes[fromRune:toRune])
}

View File

@@ -62,10 +62,12 @@ const (
// substring slicing; call sites that need to log it strip it via
// HashPrefix.
type Span struct {
Start int
End int
Pattern string // matches Pattern.ID
HashPrefix string // first 8 chars of sha256(matched value); audit-safe
Start int
End int
Pattern string // synthetic detector id, "<source>:<GROUP>" (e.g. "ner:EMAIL", "pattern:ANTHROPIC_KEY")
HashPrefix string // first 8 chars of sha256(matched value); audit-safe
Action Action // the action that fired for this span (after merge)
Score float32 // detector confidence for the (winning) hit, 0..1
}
// Result is what Redact returns. Redacted is the input string after
@@ -88,30 +90,6 @@ type Result struct {
Masked bool
}
// Pattern is one configurable rule. Description is shown in the admin
// UI alongside the pattern; the regex itself stays an implementation
// detail (a leak-prone admin showing an SSN regex with a sample value
// in the field is a risk we deliberately design around).
type Pattern struct {
ID string
Description string
Action Action
// Disabled skips the pattern entirely when true — useful for
// admins who want to keep a regex around (visible in the UI) but
// turn it off without removing the YAML entry. Default-false so
// every existing pattern stays active without touching its config.
Disabled bool
// MaxMatchLength is the longest possible match in characters. The
// streaming filter (subsystem 3, follow-up commit) uses this to
// size its tail buffer. For regex patterns we compute it at
// compile time from the pattern's structure when possible, or set
// a conservative upper bound otherwise.
MaxMatchLength int
// internal — populated by Compile().
regex regexpMatcher
}
// EventKind classifies a stored audit event. The store is shared by the
// PII filter (its original use), the MITM proxy (connect decisions and
// per-request traffic counters), and — when subsystem 2 lands — the
@@ -135,6 +113,20 @@ const (
KindAdmission EventKind = "admission"
)
// Origin labels which surface produced a redaction event, so the events
// log distinguishes an inline chat redaction from a MITM-proxy one and
// from an explicit /api/pii/{analyze,redact} call. It is set on PII
// redaction events only (Kind KindPII); connection/admission events leave
// it empty. An empty Origin on an older row reads as "unknown".
type Origin = string
const (
OriginMiddleware Origin = "middleware" // in-band chat/completions PII middleware
OriginProxy Origin = "proxy" // cloud-proxy MITM input path
OriginAnalyzeAPI Origin = "pii_analyze" // POST /api/pii/analyze
OriginRedactAPI Origin = "pii_redact" // POST /api/pii/redact
)
// PIIEvent is the persisted record. The Hash field is the first 8 chars
// of sha256(matched value) — enough to deduplicate "is this the same
// thing as last time" without ever storing the value itself.
@@ -146,6 +138,7 @@ const (
type PIIEvent struct {
ID string `json:"id"`
Kind EventKind `json:"kind,omitempty"`
Origin Origin `json:"origin,omitempty"`
CorrelationID string `json:"correlation_id,omitempty"`
UserID string `json:"user_id,omitempty"`
Direction Direction `json:"direction,omitempty"`
@@ -154,7 +147,11 @@ type PIIEvent struct {
Length int `json:"length,omitempty"`
HashPrefix string `json:"hash_prefix,omitempty"`
Action Action `json:"action,omitempty"`
CreatedAt time.Time `json:"created_at"`
// Score is the detector confidence (0..1) for an NER PII hit. Metadata
// only — never the matched value. Lets admins see how sure the model was
// about a (possibly false-positive) detection without re-running it.
Score float32 `json:"score,omitempty"`
CreatedAt time.Time `json:"created_at"`
Host string `json:"host,omitempty"`
Intercepted *bool `json:"intercepted,omitempty"`

View File

@@ -0,0 +1,119 @@
package piiadapter
import (
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services/routing/pii"
)
// OllamaChat returns a pii.Adapter for *schema.OllamaChatRequest (POST
// /api/chat). It scans each message's text content (Ollama messages carry a
// plain string, no multimodal block form) and writes redacted text back.
func OllamaChat() pii.Adapter {
return pii.Adapter{
Scan: func(parsed any) []pii.ScannedText {
req, ok := parsed.(*schema.OllamaChatRequest)
if !ok || req == nil {
return nil
}
var out []pii.ScannedText
for i := range req.Messages {
if req.Messages[i].Content != "" {
out = append(out, pii.ScannedText{Index: i, Text: req.Messages[i].Content})
}
}
return out
},
Apply: func(parsed any, updates []pii.ScannedText) {
req, ok := parsed.(*schema.OllamaChatRequest)
if !ok || req == nil {
return
}
for _, u := range updates {
if u.Index >= 0 && u.Index < len(req.Messages) {
req.Messages[u.Index].Content = u.Text
}
}
},
}
}
// Field selectors for OllamaGenerate (Prompt + System).
const (
ollamaGenPrompt = iota
ollamaGenSystem
)
// OllamaGenerate returns a pii.Adapter for *schema.OllamaGenerateRequest (POST
// /api/generate). It scans the Prompt and System strings.
func OllamaGenerate() pii.Adapter {
return pii.Adapter{
Scan: func(parsed any) []pii.ScannedText {
req, ok := parsed.(*schema.OllamaGenerateRequest)
if !ok || req == nil {
return nil
}
var out []pii.ScannedText
if req.Prompt != "" {
out = append(out, pii.ScannedText{Index: ollamaGenPrompt, Text: req.Prompt})
}
if req.System != "" {
out = append(out, pii.ScannedText{Index: ollamaGenSystem, Text: req.System})
}
return out
},
Apply: func(parsed any, updates []pii.ScannedText) {
req, ok := parsed.(*schema.OllamaGenerateRequest)
if !ok || req == nil {
return
}
for _, u := range updates {
switch u.Index {
case ollamaGenPrompt:
req.Prompt = u.Text
case ollamaGenSystem:
req.System = u.Text
}
}
},
}
}
// Field selectors for OllamaEmbed (Input + its Prompt alias). Reuses the
// shared encField/decField packing.
const (
ollamaEmbInput = iota
ollamaEmbPrompt
)
// OllamaEmbed returns a pii.Adapter for *schema.OllamaEmbedRequest (POST
// /api/embed, /api/embeddings). Input and its Prompt alias may be a string or
// a []any of strings; non-string elements are skipped.
func OllamaEmbed() pii.Adapter {
return pii.Adapter{
Scan: func(parsed any) []pii.ScannedText {
req, ok := parsed.(*schema.OllamaEmbedRequest)
if !ok || req == nil {
return nil
}
var out []pii.ScannedText
scanAnyText(ollamaEmbInput, req.Input, &out)
scanAnyText(ollamaEmbPrompt, req.Prompt, &out)
return out
},
Apply: func(parsed any, updates []pii.ScannedText) {
req, ok := parsed.(*schema.OllamaEmbedRequest)
if !ok || req == nil {
return
}
for _, u := range updates {
field, elem := decField(u.Index)
switch field {
case ollamaEmbInput:
req.Input = applyAnyText(req.Input, elem, u.Text)
case ollamaEmbPrompt:
req.Prompt = applyAnyText(req.Prompt, elem, u.Text)
}
}
},
}
}

View File

@@ -0,0 +1,46 @@
package piiadapter
import (
"github.com/mudler/LocalAI/core/schema"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
var _ = Describe("Ollama adapters", func() {
It("OllamaChat scans and rewrites message content", func() {
req := &schema.OllamaChatRequest{Messages: []schema.OllamaMessage{
{Role: "user", Content: "I'm alice@example.com"},
{Role: "assistant", Content: ""},
}}
a := OllamaChat()
Expect(a.Scan(req)).To(HaveLen(1))
applyAll(a, req, func(string) string { return "X" })
Expect(req.Messages[0].Content).To(Equal("X"))
Expect(req.Messages[1].Content).To(Equal(""))
})
It("OllamaGenerate scans Prompt and System", func() {
req := &schema.OllamaGenerateRequest{Prompt: "ssn 123", System: "be terse"}
a := OllamaGenerate()
Expect(a.Scan(req)).To(HaveLen(2))
applyAll(a, req, func(string) string { return "Y" })
Expect(req.Prompt).To(Equal("Y"))
Expect(req.System).To(Equal("Y"))
})
It("OllamaEmbed scans string and array Input, skipping non-strings", func() {
a := OllamaEmbed()
s := &schema.OllamaEmbedRequest{Input: "secret email"}
Expect(a.Scan(s)).To(HaveLen(1))
applyAll(a, s, func(string) string { return "Z" })
Expect(s.Input).To(Equal("Z"))
arr := &schema.OllamaEmbedRequest{Input: []any{"a secret", float64(1)}}
Expect(a.Scan(arr)).To(HaveLen(1))
applyAll(a, arr, func(string) string { return "Z" })
got, _ := arr.Input.([]any)
Expect(got).To(Equal([]any{"Z", float64(1)}))
})
})

View File

@@ -6,6 +6,8 @@
package piiadapter
import (
"strings"
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services/routing/pii"
)
@@ -74,17 +76,35 @@ func OpenAI() pii.Adapter {
}
msg := &req.Messages[msgIdx]
if blockIdx < 0 {
// Whole-string content.
// Whole-string content. Write BOTH the serializable
// Content and the StringContent staging buffer: the
// rendered-template path (evaluator.TemplateMessages,
// taken whenever use_tokenizer_template is off — e.g.
// cloud-proxy translate and Go-templated chat models)
// reads StringContent, not Content. Masking only Content
// would leave the original in StringContent and leak it
// to the backend/upstream.
msg.Content = u.Text
msg.StringContent = u.Text
continue
}
blocks, ok := msg.Content.([]any)
if !ok || blockIdx >= len(blocks) {
continue
}
if blockMap, ok := blocks[blockIdx].(map[string]any); ok {
blockMap["text"] = u.Text
blockMap, ok := blocks[blockIdx].(map[string]any)
if !ok {
continue
}
// Keep the StringContent projection in sync. For multimodal
// messages StringContent is the text blocks flattened with
// media markers injected (see middleware/request.go), so we
// can't just overwrite it — replace this block's original text
// run in place, preserving the markers around it.
if orig, ok := blockMap["text"].(string); ok && orig != "" && msg.StringContent != "" {
msg.StringContent = strings.Replace(msg.StringContent, orig, u.Text, 1)
}
blockMap["text"] = u.Text
}
},
}

View File

@@ -0,0 +1,91 @@
package piiadapter
import (
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services/routing/pii"
)
// Field selectors for the prompt-style OpenAI requests (/v1/completions,
// /v1/embeddings, /v1/edits), which carry user text in Prompt / Input /
// Instruction rather than Messages.
const (
fldPrompt = iota
fldInput
fldInstruction
)
// encField packs (field, element) into one ScannedText.Index. element=-1
// means the field is a whole string; element>=0 indexes into a []any value.
// Stored as element+1 so -1 maps to 0, with the field in the high bits.
func encField(field, elem int) int { return (field << 24) | (elem + 1) }
func decField(p int) (field, elem int) { return p >> 24, (p & 0xFFFFFF) - 1 }
// scanAnyText appends scannable strings from a string-or-[]any field. Non-string
// array elements (token-id arrays, numbers) are skipped — only human text is
// redacted.
func scanAnyText(field int, v any, out *[]pii.ScannedText) {
switch t := v.(type) {
case string:
if t != "" {
*out = append(*out, pii.ScannedText{Index: encField(field, -1), Text: t})
}
case []any:
for k, e := range t {
if s, ok := e.(string); ok && s != "" {
*out = append(*out, pii.ScannedText{Index: encField(field, k), Text: s})
}
}
}
}
// applyAnyText writes redacted text back to a string-or-[]any field, returning
// the (possibly replaced) value to assign back to the struct field.
func applyAnyText(v any, elem int, text string) any {
if elem < 0 {
return text
}
if arr, ok := v.([]any); ok && elem >= 0 && elem < len(arr) {
arr[elem] = text
}
return v
}
// OpenAICompletion returns a pii.Adapter for the prompt-style OpenAI requests
// (completions, embeddings, edits) on *schema.OpenAIRequest. It scans Prompt,
// Input and Instruction — the string form and the string elements of an array
// form — and writes redacted text back. Chat uses the separate OpenAI()
// adapter (Messages); these endpoints leave Messages empty and vice versa.
func OpenAICompletion() pii.Adapter {
return pii.Adapter{
Scan: func(parsed any) []pii.ScannedText {
req, ok := parsed.(*schema.OpenAIRequest)
if !ok || req == nil {
return nil
}
var out []pii.ScannedText
scanAnyText(fldPrompt, req.Prompt, &out)
scanAnyText(fldInput, req.Input, &out)
if req.Instruction != "" {
out = append(out, pii.ScannedText{Index: encField(fldInstruction, -1), Text: req.Instruction})
}
return out
},
Apply: func(parsed any, updates []pii.ScannedText) {
req, ok := parsed.(*schema.OpenAIRequest)
if !ok || req == nil {
return
}
for _, u := range updates {
field, elem := decField(u.Index)
switch field {
case fldPrompt:
req.Prompt = applyAnyText(req.Prompt, elem, u.Text)
case fldInput:
req.Input = applyAnyText(req.Input, elem, u.Text)
case fldInstruction:
req.Instruction = u.Text
}
}
},
}
}

View File

@@ -0,0 +1,59 @@
package piiadapter
import (
"github.com/mudler/LocalAI/core/schema"
"github.com/mudler/LocalAI/core/services/routing/pii"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
// applyAll feeds every scanned span back through Apply with the text
// transformed by fn — the shape the middleware uses (scan, redact, apply).
func applyAll(a pii.Adapter, parsed any, fn func(string) string) {
scanned := a.Scan(parsed)
updates := make([]pii.ScannedText, 0, len(scanned))
for _, s := range scanned {
updates = append(updates, pii.ScannedText{Index: s.Index, Text: fn(s.Text)})
}
a.Apply(parsed, updates)
}
var _ = Describe("OpenAICompletion adapter", func() {
a := OpenAICompletion()
It("scans and rewrites a string prompt", func() {
req := &schema.OpenAIRequest{}
req.Prompt = "contact alice@example.com"
got := a.Scan(req)
Expect(got).To(HaveLen(1))
Expect(got[0].Text).To(Equal("contact alice@example.com"))
applyAll(a, req, func(string) string { return "REDACTED" })
Expect(req.Prompt).To(Equal("REDACTED"))
})
It("scans array prompt elements and skips non-strings (token ids)", func() {
req := &schema.OpenAIRequest{}
req.Prompt = []any{"first secret", float64(42), "second secret"}
got := a.Scan(req)
Expect(got).To(HaveLen(2))
applyAll(a, req, func(s string) string { return "[X]" })
arr, _ := req.Prompt.([]any)
Expect(arr).To(Equal([]any{"[X]", float64(42), "[X]"}))
})
It("scans Input and Instruction (the edit/embeddings shape)", func() {
req := &schema.OpenAIRequest{Instruction: "fix the SSN 123-45-6789"}
req.Input = "my email is bob@example.com"
got := a.Scan(req)
Expect(got).To(HaveLen(2))
applyAll(a, req, func(string) string { return "*" })
Expect(req.Input).To(Equal("*"))
Expect(req.Instruction).To(Equal("*"))
})
It("returns nothing for an empty / non-matching request", func() {
Expect(a.Scan(&schema.OpenAIRequest{})).To(BeEmpty())
Expect(a.Scan(nil)).To(BeNil())
})
})

View File

@@ -54,6 +54,55 @@ var _ = Describe("OpenAI adapter", func() {
Expect(req.Messages[1].Content.(string)).To(Equal("REDACTED-1"))
})
It("Apply keeps StringContent in sync for string content", func() {
// Regression: the request middleware fills StringContent from Content
// at parse time, and the rendered-template path (TemplateMessages)
// reads StringContent, not Content. Apply must redact both or the
// original leaks to the backend/upstream (e.g. cloud-proxy translate).
req := &schema.OpenAIRequest{
Messages: []schema.Message{
{Role: "user", Content: "my key is sk-secret", StringContent: "my key is sk-secret"},
},
}
adapter := OpenAI()
scans := adapter.Scan(req)
Expect(scans).To(HaveLen(1))
scans[0].Text = "my key is [REDACTED]"
adapter.Apply(req, scans)
Expect(req.Messages[0].Content.(string)).To(Equal("my key is [REDACTED]"))
Expect(req.Messages[0].StringContent).To(Equal("my key is [REDACTED]"),
"StringContent (what TemplateMessages renders) must be redacted too")
})
It("Apply keeps StringContent in sync for content blocks, preserving media markers", func() {
// For multimodal content StringContent is the flattened text with
// media markers injected (request.go), so Apply must redact the text
// run in place rather than clobber the whole buffer.
req := &schema.OpenAIRequest{
Messages: []schema.Message{
{
Role: "user",
Content: []any{
map[string]any{"type": "text", "text": "leak sk-secret here"},
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,xyz"}},
},
StringContent: "leak sk-secret here<__media__>",
},
},
}
adapter := OpenAI()
scans := adapter.Scan(req)
Expect(scans).To(HaveLen(1))
scans[0].Text = "leak [REDACTED] here"
adapter.Apply(req, scans)
blocks := req.Messages[0].Content.([]any)
Expect(blocks[0].(map[string]any)["text"]).To(Equal("leak [REDACTED] here"))
Expect(req.Messages[0].StringContent).To(Equal("leak [REDACTED] here<__media__>"),
"StringContent must be redacted in place, keeping the media marker")
})
It("Apply mutates content block selectively", func() {
req := &schema.OpenAIRequest{
Messages: []schema.Message{

View File

@@ -0,0 +1,86 @@
// Package piidetector adapts the core/backend token-classification
// wrapper to the PII redactor's pii.NERDetector seam. It lives outside
// the pii package so pii stays free of core/backend imports (the
// redactor is unit-tested with stub detectors). The dependency runs one
// way: piidetector -> {core/backend, pii}.
package piidetector
import (
"context"
"unicode/utf8"
"github.com/mudler/xlog"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/routing/pii"
model "github.com/mudler/LocalAI/pkg/model"
)
// New builds a pii.NERDetector backed by the token-classification model
// in modelConfig. Phase 0: the Python `transformers` backend loaded with
// Type=TokenClassification; Phase 2: the GGML privacy-filter backend —
// both speak the same gRPC TokenClassify contract, so this adapter is
// unchanged across the swap. The model is resolved lazily on first
// Detect, so building a detector for a not-yet-loaded model is cheap and
// never blocks startup.
func New(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) pii.NERDetector {
return &nerDetector{
classifier: backend.NewTokenClassifier(loader, modelConfig, appConfig, backend.TokenClassifyOptions{}),
modelName: modelConfig.Name,
}
}
type nerDetector struct {
classifier backend.TokenClassifier
modelName string
}
// Detect runs the model and maps its spans onto pii.NEREntity. Offsets
// pass through as BYTE offsets per the TokenClassify proto contract.
// Spans whose offsets fall outside the text or land off a UTF-8 rune
// boundary are dropped: a bad offset must never reach the redactor,
// which splices text[Start:End] and would otherwise corrupt output or
// panic. The redactor applies NERConfig.MinScore and the entity->action
// map itself, so we deliberately return every (validated) span here.
//
// CONTRACT NOTE: the proto defines start/end as UTF-8 byte offsets. The
// Python transformers backend converts HuggingFace's codepoint offsets to
// bytes before responding (see TokenClassify in backend.py), and the GGML
// privacy-filter backend will emit bytes natively. The boundary check
// below is defense-in-depth against a backend that regresses to codepoint
// offsets: it downgrades the bug from "corrupted redaction / panic" to
// "dropped span + warning" rather than trusting the wire blindly.
func (d *nerDetector) Detect(ctx context.Context, text string) ([]pii.NEREntity, error) {
ents, err := d.classifier.TokenClassify(ctx, text)
if err != nil {
return nil, err
}
n := len(text)
out := make([]pii.NEREntity, 0, len(ents))
for _, e := range ents {
if e.Group == "" || e.Start < 0 || e.Start >= e.End || e.End > n {
xlog.Warn("pii NER: dropping span with invalid byte range",
"model", d.modelName, "group", e.Group, "start", e.Start, "end", e.End, "len", n)
continue
}
// text[e.Start] is safe (Start < End <= n => Start < n). End is
// exclusive: when End < n, text[End] is the first byte past the
// span and must itself start a rune. Off-boundary offsets are the
// signature of codepoint-vs-byte offset confusion.
if !utf8.RuneStart(text[e.Start]) || (e.End < n && !utf8.RuneStart(text[e.End])) {
xlog.Warn("pii NER: dropping span off UTF-8 boundary (offset units mismatch?)",
"model", d.modelName, "group", e.Group, "start", e.Start, "end", e.End)
continue
}
out = append(out, pii.NEREntity{
Group: e.Group,
Start: e.Start,
End: e.End,
Score: e.Score,
Text: e.Text,
})
}
return out, nil
}

View File

@@ -0,0 +1,80 @@
package piidetector
import (
"context"
"time"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/services/routing/piipattern"
"github.com/mudler/LocalAI/core/trace"
)
// NewPattern builds a pii.NERDetector that matches secrets with the restricted
// regex tier (built-ins + operator-defined patterns) instead of a neural model.
// It runs entirely in-process — no backend, GGUF, or VRAM — and the patterns
// compile once here, so an invalid pattern is reported now (the resolver fails
// closed) rather than per request. Matches are reported under their group with
// a deterministic Score of 1.0.
func NewPattern(modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (pii.NERDetector, error) {
custom := make([]piipattern.Pattern, 0, len(modelConfig.PIIDetection.Patterns))
for _, p := range modelConfig.PIIDetection.Patterns {
custom = append(custom, piipattern.Pattern{Group: p.Name, Pattern: p.Match, MinLen: p.MinLen})
}
m, err := piipattern.NewMatcher(modelConfig.PIIDetection.Builtins, custom)
if err != nil {
return nil, err
}
return &patternDetector{matcher: m, modelName: modelConfig.Name, appConfig: appConfig}, nil
}
type patternDetector struct {
matcher *piipattern.Matcher
modelName string
appConfig *config.ApplicationConfig
}
// Detect runs the compiled patterns and maps each match onto a pii.NEREntity.
// When tracing is enabled it records a pattern_pii BackendTrace so the matches
// (group, byte range, text) show in the Traces UI alongside NER detections.
func (d *patternDetector) Detect(_ context.Context, text string) ([]pii.NEREntity, error) {
var start time.Time
if d.appConfig != nil && d.appConfig.EnableTracing {
trace.InitBackendTracingIfEnabled(d.appConfig.TracingMaxItems, d.appConfig.TracingMaxBodyBytes)
start = time.Now()
}
matches := d.matcher.Find(text)
out := make([]pii.NEREntity, 0, len(matches))
var traceEnts []backend.TokenEntity
for _, mt := range matches {
out = append(out, pii.NEREntity{Group: mt.Group, Start: mt.Start, End: mt.End, Score: 1.0, Text: mt.Text})
if d.appConfig != nil && d.appConfig.EnableTracing {
traceEnts = append(traceEnts, backend.TokenEntity{Group: mt.Group, Start: mt.Start, End: mt.End, Score: 1.0, Text: mt.Text})
}
}
if d.appConfig != nil && d.appConfig.EnableTracing {
trace.RecordBackendTrace(patternPIITrace(d.modelName, text, traceEnts, start))
}
return out, nil
}
// patternPIITrace assembles the Traces-UI row for one pattern-detector run.
// Split out so the Data assembly is unit-testable without a request.
func patternPIITrace(modelName, text string, entities []backend.TokenEntity, start time.Time) trace.BackendTrace {
return trace.BackendTrace{
Timestamp: start,
Duration: time.Since(start),
Type: trace.BackendTracePatternPII,
ModelName: modelName,
Backend: "pattern",
Summary: trace.TruncateString(text, 200),
Data: map[string]any{
"input_chars": len(text),
"matches": len(entities),
"entities": entities,
},
}
}

View File

@@ -0,0 +1,61 @@
package piidetector_test
import (
"context"
"testing"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/services/routing/pii"
"github.com/mudler/LocalAI/core/services/routing/piidetector"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestPiidetector(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "piidetector suite")
}
func patternModel() config.ModelConfig {
c := config.ModelConfig{Name: "secret-filter", Backend: "pattern"}
c.PIIDetection.Builtins = []string{"anthropic_api_key"}
c.PIIDetection.Patterns = []config.PIIPattern{{Name: "INTERNAL_TOKEN", Match: `tok-[A-Za-z0-9]{8,}`}}
return c
}
var _ = Describe("pattern detector", func() {
It("matches built-in and custom secrets as whole-span deterministic hits", func() {
det, err := piidetector.NewPattern(patternModel(), &config.ApplicationConfig{})
Expect(err).NotTo(HaveOccurred())
ents, err := det.Detect(context.Background(), "use sk-ant-api03-AAAABBBBCCCCDDDDEEEE and tok-ABCD1234 ok")
Expect(err).NotTo(HaveOccurred())
byGroup := map[string]pii.NEREntity{}
for _, e := range ents {
byGroup[e.Group] = e
Expect(e.Score).To(BeEquivalentTo(float32(1.0)), "pattern matches are deterministic")
}
Expect(byGroup).To(HaveKey("ANTHROPIC_KEY"))
Expect(byGroup["INTERNAL_TOKEN"].Text).To(Equal("tok-ABCD1234"))
})
It("still detects (and exercises the trace path) with tracing enabled", func() {
det, err := piidetector.NewPattern(patternModel(), &config.ApplicationConfig{
EnableTracing: true, TracingMaxItems: 8,
})
Expect(err).NotTo(HaveOccurred())
ents, err := det.Detect(context.Background(), "sk-ant-api03-AAAABBBBCCCCDDDDEEEE")
Expect(err).NotTo(HaveOccurred())
Expect(ents).To(HaveLen(1))
Expect(ents[0].Group).To(Equal("ANTHROPIC_KEY"))
})
It("fails to build on an invalid (unanchored) custom pattern", func() {
c := config.ModelConfig{Name: "bad", Backend: "pattern"}
c.PIIDetection.Patterns = []config.PIIPattern{{Name: "X", Match: `.*`}}
_, err := piidetector.NewPattern(c, &config.ApplicationConfig{})
Expect(err).To(HaveOccurred())
})
})

View File

@@ -0,0 +1,61 @@
package piipattern
import "sort"
// Builtin is a named, ready-made secret pattern. Group is the uppercase entity
// label a match is reported under (so it keys into a detector model's
// pii_detection.entity_actions, exactly like an NER group). Every Builtin
// pattern is written in the restricted subset and is verified at test time to
// pass ValidatePattern and compile.
type Builtin struct {
Name string
Group string
Pattern string
Description string
}
// builtins is the curated catalogue. Patterns intentionally anchor on each
// provider's fixed prefix and require a long high-entropy tail, so they fire on
// real credentials and not on ordinary prose. Names are stable identifiers
// referenced from a model config's pii_detection.builtins list.
var builtins = []Builtin{
{"anthropic_api_key", "ANTHROPIC_KEY", `sk-ant-[A-Za-z0-9_-]{20,}`, "Anthropic API key (sk-ant-…)"},
{"openai_api_key", "OPENAI_KEY", `sk-(?:proj-)?[A-Za-z0-9_-]{20,}`, "OpenAI API key (sk-… / sk-proj-…)"},
{"github_token", "GITHUB_TOKEN", `(?:ghp|gho|ghs|ghr|ghu)_[A-Za-z0-9]{36,}`, "GitHub access token (ghp_/gho_/ghs_/ghr_/ghu_)"},
{"github_pat", "GITHUB_TOKEN", `github_pat_[A-Za-z0-9_]{20,}`, "GitHub fine-grained personal access token"},
{"aws_access_key", "AWS_ACCESS_KEY", `AKIA[0-9A-Z]{16}`, "AWS access key ID (AKIA…)"},
{"google_api_key", "GOOGLE_API_KEY", `AIza[0-9A-Za-z_-]{35}`, "Google API key (AIza…)"},
{"slack_token", "SLACK_TOKEN", `xox[baprs]-[0-9A-Za-z-]{10,}`, "Slack token (xoxb-/xoxa-/xoxp-/xoxr-/xoxs-)"},
{"stripe_key", "STRIPE_KEY", `(?:sk|rk)_live_[0-9A-Za-z]{16,}`, "Stripe live secret/restricted key"},
{"jwt", "JWT", `eyJ[A-Za-z0-9_-]{10,}\.eyJ[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}`, "JSON Web Token (eyJ….eyJ….…)"},
{"private_key_block", "PRIVATE_KEY", `-----BEGIN [A-Z ]*PRIVATE KEY-----`, "PEM private-key header"},
}
// BuiltinCatalogue returns the built-in patterns sorted by name. Used by the
// config-metadata registry to populate the editor's builtins checklist.
func BuiltinCatalogue() []Builtin {
out := make([]Builtin, len(builtins))
copy(out, builtins)
sort.Slice(out, func(i, j int) bool { return out[i].Name < out[j].Name })
return out
}
// BuiltinNames returns the built-in pattern names, sorted.
func BuiltinNames() []string {
out := make([]string, 0, len(builtins))
for _, b := range builtins {
out = append(out, b.Name)
}
sort.Strings(out)
return out
}
// LookupBuiltin finds a built-in by name.
func LookupBuiltin(name string) (Builtin, bool) {
for _, b := range builtins {
if b.Name == name {
return b, true
}
}
return Builtin{}, false
}

View File

@@ -0,0 +1,20 @@
package piipattern
import "regexp"
// Compile validates src against the restricted grammar and, if it passes,
// compiles it to an RE2 program set to leftmost-longest matching so a hit grabs
// the whole secret (the entire key) rather than the shortest prefix.
func Compile(src string) (*regexp.Regexp, error) {
if err := ValidatePattern(src); err != nil {
return nil, err
}
re, err := regexp.Compile(src)
if err != nil {
// ValidatePattern already parsed with the same flags, so this is
// effectively unreachable, but surface it rather than panic.
return nil, err
}
re.Longest()
return re, nil
}

View File

@@ -0,0 +1,163 @@
// Package piipattern is a bounded, restricted-regex matcher for high-entropy,
// highly-regular secrets (API keys, tokens, private-key blocks) that the NER
// PII tier cannot catch — it has no credential class, so it fragments a key
// into the nearest-looking trained categories and may leave the secret part
// exposed.
//
// The language is a deliberately restricted subset of regular expressions
// compiled to Go's RE2 engine (regexp), which is linear-time with no
// backtracking — there is no ReDoS class of failure. On top of RE2 we cap the
// pattern source length, the {n,m} expansion bound, the pattern count, and the
// scanned input, and we require every pattern to carry a fixed literal
// "anchor". The anchor rule is what admits `sk-ant-…` / `ghp_…` style keys
// while rejecting open-ended shapes like an email address or a bare `\w+`
// (which would match almost anything) — those stay with the NER tier.
//
// This package is a leaf: it imports only the standard library, so both
// core/config (validation at load) and core/application (the resolver) can use
// it without an import cycle.
package piipattern
import (
"fmt"
"regexp/syntax"
)
const (
// MaxPatternLen caps the source length of a single pattern. Generous for a
// credential shape, small enough that the compiled program stays tiny.
MaxPatternLen = 256
// MaxQuantifier caps an explicit {n,m} upper bound. RE2 expands a bounded
// repeat into that many copies, so an uncapped {0,1000000} would blow up
// the compiled program's memory. Unbounded {n,} (no upper) is a loop, not
// an expansion, and is allowed.
MaxQuantifier = 4096
// MaxAlternation caps the arms of a single `a|b|c` alternation.
MaxAlternation = 64
// MaxAST bounds recursion depth so a pathologically nested pattern can't
// blow the stack during validation.
MaxAST = 64
// MinAnchorLen is the shortest fixed literal run a pattern must contain to
// be considered "anchored" to a recognisable secret prefix/shape.
MinAnchorLen = 3
)
// parseFlags enables Perl character classes (\w \d \s) and word boundaries,
// matching what regexp.Compile uses, so validation and compilation agree.
const parseFlags = syntax.Perl
// ValidatePattern reports whether src is an acceptable restricted-subset
// pattern. It returns a descriptive error naming the offending construct so an
// operator editing a model config gets actionable feedback (the error is
// surfaced by config Validate at load and by the resolver, which fails closed).
func ValidatePattern(src string) error {
if src == "" {
return fmt.Errorf("pattern is empty")
}
if len(src) > MaxPatternLen {
return fmt.Errorf("pattern is too long (%d chars; max %d)", len(src), MaxPatternLen)
}
re, err := syntax.Parse(src, parseFlags)
if err != nil {
return fmt.Errorf("invalid pattern: %w", err)
}
if err := walk(re, 0); err != nil {
return err
}
if anchorLen(re) < MinAnchorLen {
return fmt.Errorf("pattern must contain a fixed literal run of at least %d characters "+
"(e.g. \"sk-ant-\", \"ghp_\", \"AKIA\") so it is anchored to a recognisable secret; "+
"open-ended shapes like emails or bare \\w+ belong to the NER tier", MinAnchorLen)
}
return nil
}
// walk enforces the allow-list of regex constructs.
func walk(re *syntax.Regexp, depth int) error {
if depth > MaxAST {
return fmt.Errorf("pattern is too deeply nested")
}
switch re.Op {
case syntax.OpAnyChar, syntax.OpAnyCharNotNL:
return fmt.Errorf("'.' (any character) is not allowed; use an explicit class like [A-Za-z0-9]")
case syntax.OpCapture:
return fmt.Errorf("capturing groups are not allowed; use a non-capturing group (?:…) if you need grouping")
case syntax.OpRepeat:
if re.Min > MaxQuantifier || (re.Max >= 0 && re.Max > MaxQuantifier) {
return fmt.Errorf("{n,m} bound is too large (max %d)", MaxQuantifier)
}
case syntax.OpAlternate:
if len(re.Sub) > MaxAlternation {
return fmt.Errorf("too many alternation arms (%d; max %d)", len(re.Sub), MaxAlternation)
}
case syntax.OpLiteral, syntax.OpCharClass, syntax.OpConcat,
syntax.OpStar, syntax.OpPlus, syntax.OpQuest,
syntax.OpEmptyMatch,
syntax.OpBeginLine, syntax.OpEndLine, syntax.OpBeginText, syntax.OpEndText,
syntax.OpWordBoundary, syntax.OpNoWordBoundary:
// allowed
default:
return fmt.Errorf("unsupported construct in pattern")
}
for _, sub := range re.Sub {
if err := walk(sub, depth+1); err != nil {
return err
}
}
return nil
}
// anchorLen returns the number of fixed (non-space) literal characters every
// match of re is guaranteed to contain — the pattern's "anchor strength".
// Concatenation sums its parts; alternation takes the min (every arm must
// carry the anchor); a `+`/{n,} with n>=1 contributes its body's literal once;
// `*`, `?`, {0,m} and char classes/anchors contribute 0 (they may be absent).
//
// We sum rather than measure the longest contiguous run because RE2 factors
// common prefixes — `(?:ghp|gho|ghs)_…` parses to `gh[ops]_…`, whose longest
// contiguous literal is only "gh" (2) but whose guaranteed literals are
// "gh"+"_" (3). Summing keeps such real key prefixes admissible while still
// rejecting open-ended shapes: an email `[\w.]+@[\w.]+\.\w+` guarantees only
// `@` and `.` (2 < MinAnchorLen).
func anchorLen(re *syntax.Regexp) int {
switch re.Op {
case syntax.OpLiteral:
n := 0
for _, r := range re.Rune {
if r != ' ' && r != '\t' && r != '\n' && r != '\r' {
n++
}
}
return n
case syntax.OpConcat:
sum := 0
for _, sub := range re.Sub {
sum += anchorLen(sub)
}
return sum
case syntax.OpAlternate:
if len(re.Sub) == 0 {
return 0
}
min := -1
for _, sub := range re.Sub {
if a := anchorLen(sub); min < 0 || a < min {
min = a
}
}
return min
case syntax.OpPlus:
if len(re.Sub) == 1 {
return anchorLen(re.Sub[0])
}
return 0
case syntax.OpRepeat:
if re.Min >= 1 && len(re.Sub) == 1 {
return anchorLen(re.Sub[0])
}
return 0
default:
// char classes, anchors, OpStar, OpQuest carry no guaranteed literal.
return 0
}
}

View File

@@ -0,0 +1,100 @@
package piipattern
import (
"fmt"
"regexp"
)
const (
// MaxPatternsPerMatcher bounds how many patterns one detector may hold.
MaxPatternsPerMatcher = 128
// MaxMatchesPerPattern bounds matches emitted per pattern per call, so a
// pathological input can't produce an unbounded result set.
MaxMatchesPerPattern = 1000
)
// Pattern is one compiled-ready rule: matches are reported under Group, and a
// match shorter than MinLen bytes is dropped (0 = no floor).
type Pattern struct {
Group string
Pattern string
MinLen int
}
// Match is one detected span: a half-open byte range [Start,End) into the
// scanned text, the matched text, and the reporting Group.
type Match struct {
Group string
Start int
End int
Text string
}
type compiled struct {
group string
re *regexp.Regexp
minLen int
}
// Matcher holds a set of compiled patterns and scans text for all of them.
type Matcher struct {
pats []compiled
}
// NewMatcher compiles the named built-ins plus the custom patterns into a
// Matcher. Unknown built-in names and patterns that fail the restricted grammar
// are reported as errors (the caller fails closed). Built-in and custom counts
// together may not exceed MaxPatternsPerMatcher.
func NewMatcher(builtinNames []string, custom []Pattern) (*Matcher, error) {
if len(builtinNames)+len(custom) > MaxPatternsPerMatcher {
return nil, fmt.Errorf("too many patterns (%d; max %d)", len(builtinNames)+len(custom), MaxPatternsPerMatcher)
}
m := &Matcher{}
for _, name := range builtinNames {
b, ok := LookupBuiltin(name)
if !ok {
return nil, fmt.Errorf("unknown built-in pattern %q", name)
}
re, err := Compile(b.Pattern)
if err != nil {
return nil, fmt.Errorf("built-in %q: %w", name, err)
}
m.pats = append(m.pats, compiled{group: b.Group, re: re})
}
for _, p := range custom {
if p.Group == "" {
return nil, fmt.Errorf("custom pattern is missing a name/group")
}
re, err := Compile(p.Pattern)
if err != nil {
return nil, fmt.Errorf("pattern %q: %w", p.Group, err)
}
m.pats = append(m.pats, compiled{group: p.Group, re: re, minLen: p.MinLen})
}
return m, nil
}
// Find returns every match of every pattern over text. Spans from different
// patterns may overlap; the caller (the redactor) unions and resolves them.
func (m *Matcher) Find(text string) []Match {
if m == nil || text == "" {
return nil
}
var out []Match
for _, p := range m.pats {
locs := p.re.FindAllStringIndex(text, MaxMatchesPerPattern)
for _, loc := range locs {
start, end := loc[0], loc[1]
if end-start < p.minLen {
continue
}
out = append(out, Match{
Group: p.group,
Start: start,
End: end,
Text: text[start:end],
})
}
}
return out
}

View File

@@ -0,0 +1,105 @@
package piipattern
import (
"strings"
"testing"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
func TestPiipattern(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "piipattern suite")
}
var _ = Describe("ValidatePattern", func() {
DescribeTable("accepts anchored, bounded patterns",
func(src string) { Expect(ValidatePattern(src)).To(Succeed()) },
Entry("anthropic", `sk-ant-[A-Za-z0-9_-]{20,200}`),
Entry("github via alternation", `(?:ghp|gho|ghs)_[A-Za-z0-9]{36,}`),
Entry("custom token", `tok-\w{32,64}`),
Entry("aws", `AKIA[0-9A-Z]{16}`),
Entry("anchored by mid-literal", `(?:sk|rk)_live_[0-9A-Za-z]{16,}`),
)
DescribeTable("rejects unanchored or unsafe patterns",
func(src string) { Expect(ValidatePattern(src)).NotTo(Succeed()) },
Entry("email (no fixed anchor)", `[\w.]+@[\w.]+\.\w+`),
Entry("bare word run", `\w+`),
Entry("any-char greedy", `sk-.*`),
Entry("capturing group", `(sk-ant-[A-Za-z0-9]+)`),
Entry("two fixed chars only", `ab[0-9]{8,}`),
Entry("over-long source", "sk-ant-"+strings.Repeat("a", MaxPatternLen)),
Entry("huge bounded repeat", `sk-ant-[A-Za-z0-9]{5000}`),
Entry("empty", ``),
)
})
var _ = Describe("Compile", func() {
It("compiles a valid pattern with leftmost-longest semantics", func() {
re, err := Compile(`sk-ant-[A-Za-z0-9_-]{4,}`)
Expect(err).NotTo(HaveOccurred())
// Longest() makes the match span the whole key, not a shorter prefix.
loc := re.FindString("key sk-ant-AAAA1111bbbb end")
Expect(loc).To(Equal("sk-ant-AAAA1111bbbb"))
})
It("refuses an invalid pattern", func() {
_, err := Compile(`.*`)
Expect(err).To(HaveOccurred())
})
})
var _ = Describe("builtins", func() {
It("every built-in validates, compiles, and is uniquely named", func() {
seen := map[string]bool{}
for _, b := range BuiltinCatalogue() {
Expect(seen[b.Name]).To(BeFalse(), "duplicate builtin %s", b.Name)
seen[b.Name] = true
Expect(ValidatePattern(b.Pattern)).To(Succeed(), "builtin %s pattern %q", b.Name, b.Pattern)
}
})
DescribeTable("matches a real sample and not a decoy",
func(name, sample, decoy string) {
b, ok := LookupBuiltin(name)
Expect(ok).To(BeTrue())
re, err := Compile(b.Pattern)
Expect(err).NotTo(HaveOccurred())
Expect(re.MatchString(sample)).To(BeTrue(), "should match %q", sample)
Expect(re.MatchString(decoy)).To(BeFalse(), "should not match %q", decoy)
},
Entry("anthropic", "anthropic_api_key", "sk-ant-api03-AbCdEf012345_-AbCdEf012345", "sk-ant-short"),
Entry("aws", "aws_access_key", "AKIAIOSFODNN7EXAMPLE", "AKIAshort"),
Entry("github", "github_token", "ghp_"+strings.Repeat("a", 36), "ghp_short"),
)
})
var _ = Describe("Matcher", func() {
It("reports the whole key as one span under its group", func() {
m, err := NewMatcher([]string{"anthropic_api_key"}, nil)
Expect(err).NotTo(HaveOccurred())
got := m.Find("my key is sk-ant-api03-AbCdEf012345AbCdEf012345 thanks")
Expect(got).To(HaveLen(1))
Expect(got[0].Group).To(Equal("ANTHROPIC_KEY"))
Expect(got[0].Text).To(Equal("sk-ant-api03-AbCdEf012345AbCdEf012345"))
})
It("compiles custom patterns and honours MinLen", func() {
m, err := NewMatcher(nil, []Pattern{{Group: "INTERNAL", Pattern: `tok-[A-Za-z0-9]{4,}`, MinLen: 12}})
Expect(err).NotTo(HaveOccurred())
// "tok-AAAA" (8 bytes) is below MinLen 12 and is dropped.
Expect(m.Find("tok-AAAA")).To(BeEmpty())
Expect(m.Find("tok-AAAABBBBCCCC")).To(HaveLen(1))
})
It("fails closed on an unknown built-in", func() {
_, err := NewMatcher([]string{"nope"}, nil)
Expect(err).To(HaveOccurred())
})
It("rejects an invalid custom pattern", func() {
_, err := NewMatcher(nil, []Pattern{{Group: "X", Pattern: `.*`}})
Expect(err).To(HaveOccurred())
})
})