mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-19 06:09:07 -04:00
feat(pii): NER tier engine — privacy-filter.cpp backend + NER-centric PII filter (#10360)
Squashed feat/pii-ner-tier-engine rebased onto master (was 45 commits; see backup/pii-ner-tier-engine-prerebase). Net change: - privacy-filter.cpp: standalone GGML engine for the openai-privacy-filter PII/NER token classifier, wired as a LocalAI gRPC backend (CPU/CUDA/Vulkan). TokenClassify moves off the patched llama.cpp path onto this backend. - PII filter reworked to be NER-centric (encoder/NER detection tier scanning whole conversations as one document), with a recreated bounded restricted-regex secret-matching pattern detector tier alongside it (per-model pii_detection.builtins / .patterns + core/services/routing/piipattern). - Detection labelled by source (ner vs pattern); backend trace / confidence / debug observability; analyze/redact exposed as a synchronous API. - Instance-wide default detector policy + per-usecase default-on; request filtering extended to completions, embeddings, edits & Ollama. - React UI: NER-centric PII editor, detector-models table, pattern/builtins editor, middleware default-policy UI. - Gallery: privacy-filter-multilingual token-classify model + NER install filter; token_classify known_usecase; batch sized to context for NER models. privacy-filter backend registered in the backend gallery (cpu/vulkan/cuda-13 meta + image entries with a capabilities map) matching its CI matrix jobs, and an /import-model auto-detect importer (PrivacyFilterImporter, narrow privacy-filter GGUF detection) replacing the prior pref-only registration. Reconciled against master's independent evolution: - Dropped master's PIIPatternOverrides feature (global-pattern runtime overrides + /api/pii/patterns API + runtime_settings.json persistence). The per-model NER + pattern-detector design supersedes it; it was built on the global redactor pattern set this branch replaced. 
- Reverted the llama.cpp Score carry-patch (0006-server-task-type-score): removed the patch and restored master's grpc-server.cpp Score RPC (direct llama_decode, slot-loop bypass) and LLAMA_VERSION pin, plus master's model_config validation forbidding score + chat/completion/embeddings on llama-cpp. token_classify is unaffected (it runs on the privacy-filter backend, not llama-cpp). Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
committed by
GitHub
parent
c133ca39dc
commit
3fa7b2955c
@@ -1,71 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// FileConfig is the on-disk schema for pii.yaml. Each Pattern entry
|
||||
// overrides the matching default by ID; missing fields fall back to
|
||||
// the default. Unknown IDs are rejected at load time so an admin who
|
||||
// fat-fingers a pattern name gets a clear error rather than a silent
|
||||
// no-op.
|
||||
type FileConfig struct {
|
||||
Patterns []FilePattern `yaml:"patterns"`
|
||||
}
|
||||
|
||||
type FilePattern struct {
|
||||
ID string `yaml:"id"`
|
||||
Action Action `yaml:"action"`
|
||||
}
|
||||
|
||||
// LoadConfig reads pii.yaml from path and merges it on top of
|
||||
// DefaultPatterns(). path == "" returns the defaults compiled and
|
||||
// ready. The returned slice is already Compile()'d, so callers can
|
||||
// pass it straight to NewRedactor.
|
||||
func LoadConfig(path string) ([]Pattern, error) {
|
||||
defaults := DefaultPatterns()
|
||||
if path == "" {
|
||||
return Compile(defaults)
|
||||
}
|
||||
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pii: read config %q: %w", path, err)
|
||||
}
|
||||
var cfg FileConfig
|
||||
if err := yaml.Unmarshal(raw, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("pii: parse config %q: %w", path, err)
|
||||
}
|
||||
|
||||
overrides := make(map[string]Action, len(cfg.Patterns))
|
||||
known := make(map[string]bool, len(defaults))
|
||||
for _, d := range defaults {
|
||||
known[d.ID] = true
|
||||
}
|
||||
for _, p := range cfg.Patterns {
|
||||
if !known[p.ID] {
|
||||
return nil, fmt.Errorf("pii: unknown pattern id %q in %q", p.ID, path)
|
||||
}
|
||||
if p.Action == "" {
|
||||
continue
|
||||
}
|
||||
switch p.Action {
|
||||
case ActionMask, ActionBlock, ActionAllow:
|
||||
overrides[p.ID] = p.Action
|
||||
default:
|
||||
return nil, fmt.Errorf("pii: invalid action %q for pattern %q", p.Action, p.ID)
|
||||
}
|
||||
}
|
||||
|
||||
merged := make([]Pattern, len(defaults))
|
||||
for i, d := range defaults {
|
||||
if a, ok := overrides[d.ID]; ok {
|
||||
d.Action = a
|
||||
}
|
||||
merged[i] = d
|
||||
}
|
||||
return Compile(merged)
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Specs for LoadConfig: the no-path default case, a successful action
// override, and the two validation failures (unknown id, invalid action).
var _ = Describe("LoadConfig", func() {
|
||||
// An empty path yields exactly as many patterns as DefaultPatterns().
It("returns defaults when no path given", func() {
|
||||
patterns, err := LoadConfig("")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(patterns).To(HaveLen(len(DefaultPatterns())))
|
||||
})
|
||||
|
||||
// Writes a pii.yaml overriding two patterns and checks that those two
// change while an unmentioned pattern keeps its default action.
It("overrides action", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
path := filepath.Join(dir, "pii.yaml")
|
||||
body := []byte(`patterns:
|
||||
- id: email
|
||||
action: block
|
||||
- id: ssn
|
||||
action: allow
|
||||
`)
|
||||
Expect(os.WriteFile(path, body, 0o600)).To(Succeed())
|
||||
patterns, err := LoadConfig(path)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// Index the returned patterns by ID so each action can be asserted.
got := map[string]Action{}
|
||||
for _, p := range patterns {
|
||||
got[p.ID] = p.Action
|
||||
}
|
||||
Expect(got["email"]).To(Equal(ActionBlock))
|
||||
Expect(got["ssn"]).To(Equal(ActionAllow))
|
||||
// Unmentioned patterns keep their default action.
|
||||
Expect(got["credit_card"]).To(Equal(ActionMask), "credit_card default action lost")
|
||||
})
|
||||
|
||||
// An id that is not among the defaults must make LoadConfig fail.
It("rejects unknown id", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
path := filepath.Join(dir, "pii.yaml")
|
||||
Expect(os.WriteFile(path, []byte("patterns:\n - id: nonsense\n action: mask\n"), 0o600)).To(Succeed())
|
||||
_, err := LoadConfig(path)
|
||||
Expect(err).To(HaveOccurred(), "expected error on unknown pattern id")
|
||||
})
|
||||
|
||||
// A known id with an action outside mask/block/allow must also fail.
It("rejects invalid action", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
path := filepath.Join(dir, "pii.yaml")
|
||||
Expect(os.WriteFile(path, []byte("patterns:\n - id: email\n action: lolwhat\n"), 0o600)).To(Succeed())
|
||||
_, err := LoadConfig(path)
|
||||
Expect(err).To(HaveOccurred(), "expected error on invalid action")
|
||||
})
|
||||
})
|
||||
@@ -19,28 +19,71 @@ import (
|
||||
// drag the http/middleware package into pii's import graph and create
|
||||
// a cycle (http/middleware will import this one).
|
||||
const (
|
||||
ctxKeyCorrelationID = "routing.correlation_id"
|
||||
ctxKeyPIIEventID = "routing.pii_event_id"
|
||||
ctxKeyCorrelationID = "routing.correlation_id"
|
||||
ctxKeyPIIEventID = "routing.pii_event_id"
|
||||
// Must match the constants in core/http/middleware/request.go.
|
||||
// Echoing them across packages would create an import cycle
|
||||
// (http/middleware imports this package). Drift is caught by
|
||||
// integration tests against the chat route.
|
||||
ctxKeyParsedRequest = "LOCALAI_REQUEST"
|
||||
ctxKeyModelConfig = "MODEL_CONFIG"
|
||||
ctxKeyParsedRequest = "LOCALAI_REQUEST"
|
||||
ctxKeyModelConfig = "MODEL_CONFIG"
|
||||
)
|
||||
|
||||
// ModelPIIConfig is the duck-typed view this middleware needs of the
|
||||
// per-model PII configuration carried on the echo context. *config.ModelConfig
|
||||
// satisfies it via PIIIsEnabled / PIIPatternOverrides; the indirection
|
||||
// keeps the pii package from importing core/config.
|
||||
// per-model PII configuration carried on the echo context.
|
||||
// *config.ModelConfig satisfies it via PIIIsEnabled / PIIDetectors; the
|
||||
// indirection keeps the pii package from importing core/config.
|
||||
//
|
||||
// Consumers of the override map: the action returned from PIIPatternOverrides
|
||||
// is the raw YAML string (e.g. "block"). Validation against the canonical
|
||||
// ActionMask/Block/Allow constants happens here, so a typo in a model
|
||||
// YAML logs and is ignored rather than panicking.
|
||||
// PIIDetectors lists the token-classification models whose detections
|
||||
// drive redaction for this (consuming) model. The detection policy lives
|
||||
// on each named detector model — resolved via NERDetectorResolver — so
|
||||
// this consuming view carries no per-entity actions of its own.
|
||||
type ModelPIIConfig interface {
|
||||
PIIIsEnabled() bool
|
||||
PIIPatternOverrides() map[string]string
|
||||
PIIDetectors() []string
|
||||
}
|
||||
|
||||
// NERDetectorResolver resolves a detector model name to a ready-to-use
|
||||
// NERConfig — the detector plus the policy (min score, entity→action
|
||||
// map, default action) read from that model's own pii_detection block.
|
||||
// ok is false when the name can't supply a detector (unknown model, not
|
||||
// a token_classify model, or load failure); the middleware fails closed
|
||||
// in that case. Supplied by the application layer, which owns the model
|
||||
// loader and the core/backend dependency, keeping the pii package free of
|
||||
// both. A nil resolver (or the option being unset) disables the NER tier.
|
||||
type NERDetectorResolver func(modelName string) (NERConfig, bool)
|
||||
|
||||
// Option configures optional RequestMiddleware behaviour. Threaded as
|
||||
// variadic options so adding the NER tier doesn't break the existing
|
||||
// four-argument call sites (routes and tests).
|
||||
type Option func(*mwOptions)
|
||||
|
||||
type mwOptions struct {
|
||||
nerResolver NERDetectorResolver
|
||||
policyResolver PolicyResolver
|
||||
}
|
||||
|
||||
// PolicyResolver computes the effective PII policy for the model attached
// to the request: whether filtering is enabled and which detector models
// to run. It exists so instance-wide PII defaults can be layered over the
// per-model config.
//
// The middleware hands it the raw model-config value from the request
// context as `any`, so this package never needs to import core/config;
// the application layer, which owns that package, supplies the function.
// When no PolicyResolver is configured the middleware instead reads the
// duck-typed ModelPIIConfig from the context, which reflects explicit
// per-model settings only (no global default).
type PolicyResolver func(modelCfg any) (enabled bool, detectors []string)
|
||||
|
||||
// WithPolicyResolver overrides how the middleware decides enablement and the
|
||||
// detector list, so the instance-wide default detector / default-on usecases
|
||||
// apply. Without it the middleware reads ModelPIIConfig off the context.
|
||||
func WithPolicyResolver(r PolicyResolver) Option {
|
||||
return func(o *mwOptions) { o.policyResolver = r }
|
||||
}
|
||||
|
||||
// WithNERResolver enables the NER tier. When a request's model lists
|
||||
// pii.detectors, the middleware resolves each to a NERConfig and runs
|
||||
// RedactNER (the union of all detectors' hits, merged). Without this
|
||||
// option, or when a model lists no detectors, redaction is a no-op.
|
||||
func WithNERResolver(r NERDetectorResolver) Option {
|
||||
return func(o *mwOptions) { o.nerResolver = r }
|
||||
}
|
||||
|
||||
// ScannedText is one piece of user text from the request. Index is
|
||||
@@ -84,30 +127,32 @@ type Adapter struct {
|
||||
// no-auth identity. The middleware writes ctxKeyPIIEventID on the echo
|
||||
// context so the usage middleware can later cross-reference the event
|
||||
// with the UsageRecord.
|
||||
func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fallbackUser *auth.User) echo.MiddlewareFunc {
|
||||
func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fallbackUser *auth.User, opts ...Option) echo.MiddlewareFunc {
|
||||
var o mwOptions
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if redactor == nil || len(redactor.Patterns()) == 0 || adapter.Scan == nil {
|
||||
if redactor == nil || adapter.Scan == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Per-model gating: redaction is opt-in per model. If the
|
||||
// resolved config disables PII for this model (the default
|
||||
// for non-proxy backends), pass through immediately. We do
|
||||
// this before parsing the request so a disabled model
|
||||
// doesn't pay the regex scan cost.
|
||||
if cfg, ok := c.Get(ctxKeyModelConfig).(ModelPIIConfig); ok {
|
||||
if !cfg.PIIIsEnabled() {
|
||||
return next(c)
|
||||
}
|
||||
} else {
|
||||
// No ModelPIIConfig on context → fail-closed: skip
|
||||
// redaction. This protects routes that wire the
|
||||
// middleware before SetModelAndConfig runs (or non-chat
|
||||
// routes that don't carry a model). The middleware was
|
||||
// previously fail-open, applying the global redactor
|
||||
// unconditionally; the new contract is per-model
|
||||
// opt-in, and a missing model is treated as disabled.
|
||||
// Per-model gating: redaction is opt-in per model. The policy
|
||||
// resolver (when wired) layers instance-wide defaults over the
|
||||
// per-model config; otherwise we read the per-model config
|
||||
// directly. A missing config (non-chat routes, or middleware
|
||||
// wired before SetModelAndConfig) or a not-enabled result passes
|
||||
// through.
|
||||
rawCfg := c.Get(ctxKeyModelConfig)
|
||||
var enabled bool
|
||||
var detectors []string
|
||||
if o.policyResolver != nil {
|
||||
enabled, detectors = o.policyResolver(rawCfg)
|
||||
} else if cfg, ok := rawCfg.(ModelPIIConfig); ok {
|
||||
enabled, detectors = cfg.PIIIsEnabled(), cfg.PIIDetectors()
|
||||
}
|
||||
if !enabled {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
@@ -116,6 +161,12 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// A PII-enabled model with no detectors (or no resolver wired)
|
||||
// has nothing to scan with — pass through.
|
||||
if len(detectors) == 0 || o.nerResolver == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
user = fallbackUser
|
||||
@@ -126,24 +177,19 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
}
|
||||
correlationID, _ := c.Get(ctxKeyCorrelationID).(string)
|
||||
|
||||
// Resolve per-model action overrides once per request. The
|
||||
// raw map is YAML strings; convert to the typed Action set
|
||||
// and silently drop unknown values rather than failing the
|
||||
// request — model YAML typos shouldn't take chat down.
|
||||
var overrides map[string]Action
|
||||
if cfg, ok := c.Get(ctxKeyModelConfig).(ModelPIIConfig); ok {
|
||||
if raw := cfg.PIIPatternOverrides(); len(raw) > 0 {
|
||||
overrides = make(map[string]Action, len(raw))
|
||||
for id, action := range raw {
|
||||
switch Action(action) {
|
||||
case ActionMask, ActionBlock, ActionAllow:
|
||||
overrides[id] = Action(action)
|
||||
default:
|
||||
xlog.Warn("pii: ignoring unknown action in per-model override",
|
||||
"pattern", id, "action", action)
|
||||
}
|
||||
}
|
||||
// Resolve each named detector to its NERConfig (detector +
|
||||
// the policy from that model's own pii_detection block). A
|
||||
// configured detector that can't be resolved fails closed:
|
||||
// serving the request without the semantic check the operator
|
||||
// asked for is exactly the leak this tier exists to prevent.
|
||||
cfgs := make([]NERConfig, 0, len(detectors))
|
||||
for _, name := range detectors {
|
||||
nc, ok := o.nerResolver(name)
|
||||
if !ok {
|
||||
xlog.Error("pii: configured detector model could not be resolved; blocking request (fail-closed)", "detector", name)
|
||||
return blockNERUnavailable(c, store, correlationID, userID)
|
||||
}
|
||||
cfgs = append(cfgs, nc)
|
||||
}
|
||||
|
||||
texts := adapter.Scan(parsed)
|
||||
@@ -151,24 +197,38 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
var blocked bool
|
||||
var firstEventID string
|
||||
|
||||
for _, st := range texts {
|
||||
if st.Text == "" {
|
||||
continue
|
||||
}
|
||||
res := redactor.RedactWithOverrides(st.Text, overrides)
|
||||
// Scan the request as ONE document (messages joined) so the NER
|
||||
// tier keeps conversational context — whether "4421" is a PIN is
|
||||
// decided by the question in the previous message. The spans come
|
||||
// back per message with local offsets for in-place rewriting.
|
||||
segTexts := make([]string, len(texts))
|
||||
for i, st := range texts {
|
||||
segTexts[i] = st.Text
|
||||
}
|
||||
// Fail closed: a detector outage at request time must NOT
|
||||
// silently serve the request. The NER tier was explicitly
|
||||
// configured for this model, so the semantic check is part
|
||||
// of the contract.
|
||||
segResults, nerErr := RedactNERSegments(c.Request().Context(), segTexts, cfgs)
|
||||
if nerErr != nil {
|
||||
xlog.Error("pii: NER detector failed; blocking request (fail-closed)", "error", nerErr)
|
||||
return blockNERUnavailable(c, store, correlationID, userID)
|
||||
}
|
||||
|
||||
for i, res := range segResults {
|
||||
st := texts[i]
|
||||
if len(res.Spans) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Persist one event per span so admins can see exactly
|
||||
// which patterns fired in which positions. The action
|
||||
// recorded is the resolved one (after override), so the
|
||||
// events log reflects what actually happened to the
|
||||
// request, not the global default.
|
||||
// Persist one event per detected span. The action recorded
|
||||
// is the one that actually fired (carried on the span after
|
||||
// the overlap merge), so the events log reflects what
|
||||
// happened to the request.
|
||||
for _, span := range res.Spans {
|
||||
action := actionForSpan(redactor.Patterns(), span.Pattern, overrides)
|
||||
ev := PIIEvent{
|
||||
ID: newEventID(),
|
||||
Origin: OriginMiddleware,
|
||||
CorrelationID: correlationID,
|
||||
UserID: userID,
|
||||
Direction: DirectionIn,
|
||||
@@ -176,7 +236,8 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
ByteOffset: span.Start,
|
||||
Length: span.End - span.Start,
|
||||
HashPrefix: span.HashPrefix,
|
||||
Action: action,
|
||||
Action: span.Action,
|
||||
Score: span.Score,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
if firstEventID == "" {
|
||||
@@ -223,24 +284,85 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
}
|
||||
}
|
||||
|
||||
func actionForPattern(patterns []Pattern, id string) Action {
|
||||
for _, p := range patterns {
|
||||
if p.ID == id {
|
||||
return p.Action
|
||||
// nerUnavailablePattern is the sentinel PatternID recorded on the
|
||||
// fail-closed audit event when a model's configured NER tier cannot
|
||||
// run. It is not a real regex pattern — it marks a request blocked
|
||||
// because the encoder/NER check was unavailable (model unresolved or
|
||||
// backend error), so the events log distinguishes it from a content
|
||||
// block (which carries a real pattern ID).
|
||||
const nerUnavailablePattern = "__ner_unavailable__"
|
||||
|
||||
// blockNERUnavailable records a fail-closed audit event and returns the
|
||||
// response used when a model has an NER tier configured but it could
|
||||
// not run. Failing closed is deliberate for a PII filter: if the
|
||||
// semantic check the operator asked for cannot execute, refusing the
|
||||
// request is safer than serving it with only the cheap regex tier. The
|
||||
// 503 (vs the 400 used for a content block) tells clients and operators
|
||||
// this was a dependency outage, not sensitive data in the request.
|
||||
func blockNERUnavailable(c echo.Context, store EventStore, correlationID, userID string) error {
|
||||
ev := PIIEvent{
|
||||
ID: newEventID(),
|
||||
Kind: KindPII,
|
||||
Origin: OriginMiddleware,
|
||||
CorrelationID: correlationID,
|
||||
UserID: userID,
|
||||
Direction: DirectionIn,
|
||||
PatternID: nerUnavailablePattern,
|
||||
Action: ActionBlock,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
if store != nil {
|
||||
if err := store.Record(context.Background(), ev); err != nil {
|
||||
xlog.Error("pii: failed to record NER-unavailable event", "error", err)
|
||||
}
|
||||
}
|
||||
return ActionMask
|
||||
c.Set(ctxKeyPIIEventID, ev.ID)
|
||||
return c.JSON(http.StatusServiceUnavailable, map[string]any{
|
||||
"error": map[string]string{
|
||||
"message": "request blocked: PII NER check is configured but unavailable",
|
||||
"type": "pii_ner_unavailable",
|
||||
},
|
||||
"correlation_id": correlationID,
|
||||
"pii_event_id": ev.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// actionForSpan returns the resolved action for a span, preferring a
|
||||
// per-request override over the pattern's stored action. Used so the
|
||||
// PIIEvent log reflects the action that actually fired (e.g., a model
|
||||
// upgraded email from mask to block — the event row says "block").
|
||||
func actionForSpan(patterns []Pattern, id string, overrides map[string]Action) Action {
|
||||
if action, ok := overrides[id]; ok {
|
||||
return action
|
||||
// validAction converts a raw YAML action string to the typed Action,
|
||||
// returning "" for anything that isn't a known action.
|
||||
func validAction(raw string) Action {
|
||||
switch Action(raw) {
|
||||
case ActionMask, ActionBlock, ActionAllow:
|
||||
return Action(raw)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
return actionForPattern(patterns, id)
|
||||
}
|
||||
|
||||
// validActionOr is validAction with a fallback for empty/invalid input.
|
||||
func validActionOr(raw string, fallback Action) Action {
|
||||
if a := validAction(raw); a != "" {
|
||||
return a
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// validActions converts a raw entity-group->action map to typed
|
||||
// Actions, dropping (and logging) unknown actions so a model YAML typo
|
||||
// is ignored rather than taking the request down — mirroring how the
|
||||
// per-pattern overrides are validated above.
|
||||
func validActions(raw map[string]string) map[string]Action {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]Action, len(raw))
|
||||
for group, action := range raw {
|
||||
if a := validAction(action); a != "" {
|
||||
out[group] = a
|
||||
} else {
|
||||
xlog.Warn("pii: ignoring unknown NER entity action", "group", group, "action", action)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func newEventID() string {
|
||||
@@ -248,3 +370,8 @@ func newEventID() string {
|
||||
_, _ = rand.Read(b[:])
|
||||
return "pii_" + hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
// NewEventID mints a fresh random event id in the package's standard shape.
|
||||
// Exported so callers outside this package (the analyze/redact API handlers)
|
||||
// record events with ids indistinguishable from the in-band middleware's.
|
||||
func NewEventID() string { return newEventID() }
|
||||
|
||||
@@ -3,12 +3,12 @@ package pii
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -56,21 +56,15 @@ func setRequestOnContext(req *fakeRequest) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
// fakeModelPIIConfig satisfies the duck-typed ModelPIIConfig interface
|
||||
// the middleware expects on the echo context. The real implementation
|
||||
// lives on *config.ModelConfig; using a fake here keeps these tests
|
||||
// out of the core/config import graph.
|
||||
// the middleware expects on the echo context (PIIIsEnabled + PIIDetectors).
|
||||
type fakeModelPIIConfig struct {
|
||||
enabled bool
|
||||
overrides map[string]string
|
||||
detectors []string
|
||||
}
|
||||
|
||||
func (f fakeModelPIIConfig) PIIIsEnabled() bool { return f.enabled }
|
||||
func (f fakeModelPIIConfig) PIIPatternOverrides() map[string]string { return f.overrides }
|
||||
func (f fakeModelPIIConfig) PIIIsEnabled() bool { return f.enabled }
|
||||
func (f fakeModelPIIConfig) PIIDetectors() []string { return f.detectors }
|
||||
|
||||
// withModelConfig wires a ModelPIIConfig onto the context so the
|
||||
// middleware's per-model gate doesn't fail-closed during tests. Pass
|
||||
// enabled=true for the default test path; explicit-false tests should
|
||||
// use the gating spec further down instead.
|
||||
func withModelConfig(cfg fakeModelPIIConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
@@ -80,230 +74,257 @@ func withModelConfig(cfg fakeModelPIIConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRedactor(ids ...string) *Redactor {
|
||||
patterns, err := Compile(pick(DefaultPatterns(), ids))
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return NewRedactor(patterns)
|
||||
// resolverFor returns a NERDetectorResolver that maps each named model to
|
||||
// the supplied NERConfig. Names absent from the map resolve to (zero,
|
||||
// false) so the middleware fails closed — mirroring an unresolvable model.
|
||||
func resolverFor(byName map[string]NERConfig) NERDetectorResolver {
|
||||
return func(name string) (NERConfig, bool) {
|
||||
cfg, ok := byName[name]
|
||||
return cfg, ok
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("RequestMiddleware", func() {
|
||||
It("masks email", func() {
|
||||
red := newTestRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
user := &auth.User{ID: "user-1", Name: "alice"}
|
||||
func serve(body *fakeRequest, cfg fakeModelPIIConfig, mw echo.MiddlewareFunc, withConfig bool) (*httptest.ResponseRecorder, *bool) {
|
||||
called := new(bool)
|
||||
e := echo.New()
|
||||
chain := []echo.MiddlewareFunc{setRequestOnContext(body)}
|
||||
if withConfig {
|
||||
chain = append(chain, withModelConfig(cfg))
|
||||
}
|
||||
chain = append(chain, mw)
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
*called = true
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, chain...)
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
return w, called
|
||||
}
|
||||
|
||||
body := &fakeRequest{Messages: []string{"contact me at alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
func nerCfg(action Action, entities ...NEREntity) NERConfig {
|
||||
return NERConfig{
|
||||
Detector: &stubNERDetector{entities: entities},
|
||||
DefaultAction: action,
|
||||
}
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
// Inject the user as if upstream auth ran.
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", user)
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
var _ = Describe("RequestMiddleware (NER)", func() {
|
||||
store := func() EventStore { return NewMemoryEventStore(0) }
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
It("masks a detected entity end-to-end", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice today"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"privacy-filter": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"privacy-filter"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String())
|
||||
Expect(body.Messages[0]).NotTo(ContainSubstring("alice@example.com"), "request body should be redacted in place")
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:email]"))
|
||||
|
||||
events, err := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(err).NotTo(HaveOccurred(), "list events")
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:PER]"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].PatternID).To(Equal("email"))
|
||||
Expect(events[0].PatternID).To(Equal("ner:PER"))
|
||||
Expect(events[0].Direction).To(Equal(DirectionIn))
|
||||
})
|
||||
|
||||
It("blocks api key", func() {
|
||||
red := newTestRedactor("api_key_prefix")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
body := &fakeRequest{Messages: []string{"my key is sk-abcdefghijklmnopqrstuvwxyz0123456789"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
handlerCalled := false
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
handlerCalled = true
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusBadRequest), "expected 400 on block; body=%s", w.Body.String())
|
||||
Expect(handlerCalled).To(BeFalse(), "handler must not run when request is blocked")
|
||||
// Ensure the matched value never appears in the response body.
|
||||
Expect(w.Body.String()).NotTo(ContainSubstring("abcdefghijklmnopqrstuvwxyz0123456789"), "blocked response leaks the matched value")
|
||||
It("blocks (400) when a detected entity's action is block", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"my password is hunter2 ok"}}
|
||||
cfg := NERConfig{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: "PASSWORD", Start: 15, End: 22, Score: 0.99}}},
|
||||
EntityActions: map[string]Action{"PASSWORD": ActionBlock},
|
||||
}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
|
||||
w, called := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusBadRequest), "body=%s", w.Body.String())
|
||||
Expect(*called).To(BeFalse(), "handler must not run when blocked")
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(w.Body.Bytes(), &resp)).To(Succeed())
|
||||
errBlock, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
errBlock, _ := resp["error"].(map[string]any)
|
||||
Expect(errBlock["type"]).To(Equal("pii_blocked"))
|
||||
})
|
||||
|
||||
It("allow leaves text intact but still records an event", func() {
|
||||
patterns, _ := Compile([]Pattern{{
|
||||
ID: "email", Description: "Email", Action: ActionAllow, MaxMatchLength: 254,
|
||||
}})
|
||||
red := NewRedactor(patterns)
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
It("allow leaves text intact but records an event", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"hi at alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
cfg := NERConfig{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: "EMAIL", Start: 6, End: 23, Score: 0.9}}},
|
||||
EntityActions: map[string]Action{"EMAIL": ActionAllow},
|
||||
}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
// allow does NOT mutate the body — the model still sees the email.
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "allow should leave text intact")
|
||||
// ...but the detection is still recorded for audit.
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1), "allow should still record a PIIEvent")
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].Action).To(Equal(ActionAllow))
|
||||
})
|
||||
|
||||
It("no match passes through", func() {
|
||||
red := newTestRedactor()
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
It("passes through on no match", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"perfectly innocent text"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": nerCfg(ActionMask)})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(Equal("perfectly innocent text"), "body should be untouched")
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(BeEmpty(), "expected 0 events on no-match input")
|
||||
Expect(body.Messages[0]).To(Equal("perfectly innocent text"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("skips when model config disabled", func() {
|
||||
// Per-model gating is the new contract: a model with PIIIsEnabled
|
||||
// returning false must bypass redaction entirely, even if the
|
||||
// global redactor has matching patterns.
|
||||
red := newTestRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: false}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
It("skips when the model has PII disabled", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: false, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "disabled model must not redact")
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(BeEmpty(), "disabled model must produce no events")
|
||||
Expect(body.Messages[0]).To(Equal("Hi I'm Alice"), "disabled model must not redact")
|
||||
})
|
||||
|
||||
It("fails closed without model config", func() {
|
||||
// Routes that wire the middleware before SetModelAndConfig, or
|
||||
// non-chat routes lacking a model, hit this path. The contract
|
||||
// is fail-closed: pass through without redaction so a missing
|
||||
// model can't accidentally leak through global defaults.
|
||||
red := newTestRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
// Note: no withModelConfig in the chain.
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
It("passes through when the model lists no detectors", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "missing ModelPIIConfig should fail-closed (no redaction)")
|
||||
Expect(body.Messages[0]).To(Equal("Hi I'm Alice"))
|
||||
})
|
||||
|
||||
It("applies per-model override", func() {
|
||||
// email defaults to mask. A per-model override upgrades it to
|
||||
// block. The middleware short-circuits with 400, the request
|
||||
// body is never touched, and the events log records action=block.
|
||||
red := newTestRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
It("fails closed without a model config", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{}, mw, false) // no model config on context
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(Equal("Hi I'm Alice"), "missing ModelPIIConfig should pass through")
|
||||
})
|
||||
|
||||
It("unions multiple detectors", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Alice at acme"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"names": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 0, End: 5, Score: 0.9}),
|
||||
"orgs": nerCfg(ActionMask, NEREntity{Group: "ORG", Start: 9, End: 13, Score: 0.9}),
|
||||
})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"names", "orgs"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:PER]"))
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:ORG]"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("fails closed (503) when a detector errors", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
cfg := NERConfig{Detector: &stubNERDetector{err: errors.New("backend offline")}, DefaultAction: ActionMask}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
|
||||
w, called := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
e := echo.New()
|
||||
handlerCalled := false
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
handlerCalled = true
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body),
|
||||
withModelConfig(fakeModelPIIConfig{
|
||||
enabled: true,
|
||||
overrides: map[string]string{"email": "block"},
|
||||
}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusBadRequest), "expected 400 from override-block; body=%s", w.Body.String())
|
||||
Expect(handlerCalled).To(BeFalse(), "handler must not run when override blocks")
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(w.Code).To(Equal(http.StatusServiceUnavailable), "body=%s", w.Body.String())
|
||||
Expect(*called).To(BeFalse())
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "request body must be untouched on a fail-closed block")
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(w.Body.Bytes(), &resp)).To(Succeed())
|
||||
errBlock, _ := resp["error"].(map[string]any)
|
||||
Expect(errBlock["type"]).To(Equal("pii_ner_unavailable"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].Action).To(Equal(ActionBlock), "event must record the resolved (override) action")
|
||||
Expect(events[0].PatternID).To(Equal(nerUnavailablePattern))
|
||||
})
|
||||
|
||||
It("fails closed (503) when a configured detector can't be resolved", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{}))) // "missing" not present
|
||||
w, called := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"missing"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusServiceUnavailable))
|
||||
Expect(*called).To(BeFalse())
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].PatternID).To(Equal(nerUnavailablePattern))
|
||||
})
|
||||
|
||||
It("nil redactor is passthrough", func() {
|
||||
body := &fakeRequest{Messages: []string{"alice@example.com"}}
|
||||
mw := RequestMiddleware(nil, nil, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(Equal("alice@example.com"), "nil redactor must be a no-op")
|
||||
})
|
||||
|
||||
It("WithPolicyResolver enables a model the per-model config left off (global default)", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice today"}}
|
||||
// The per-model config is disabled with no detectors; the policy
|
||||
// resolver (instance-wide default) turns it on and supplies one.
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"global-pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})),
|
||||
WithPolicyResolver(func(_ any) (bool, []string) { return true, []string{"global-pf"} }))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: false}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String())
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:PER]"))
|
||||
})
|
||||
|
||||
It("WithPolicyResolver returning disabled short-circuits an otherwise-enabled model", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice today"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})),
|
||||
WithPolicyResolver(func(_ any) (bool, []string) { return false, nil }))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(Equal("Hi I'm Alice today"), "resolver disabled => no redaction")
|
||||
})
|
||||
|
||||
It("scans all messages as one document so earlier-message context applies", func() {
|
||||
st := store()
|
||||
// The detector (pinAfterCard) only recognises "4421" when "card"
|
||||
// appears earlier in the SAME text it is handed — so this only
|
||||
// masks if the middleware joins the messages before scanning.
|
||||
body := &fakeRequest{Messages: []string{
|
||||
"What are the last four digits of your card?",
|
||||
"it is 4421 ok",
|
||||
}}
|
||||
cfg := NERConfig{Detector: &funcNERDetector{fn: pinAfterCard}, DefaultAction: ActionMask}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String())
|
||||
Expect(body.Messages[0]).To(Equal("What are the last four digits of your card?"), "question untouched")
|
||||
Expect(body.Messages[1]).To(Equal("it is [REDACTED:ner:PIN] ok"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].ByteOffset).To(Equal(6), "event offsets are message-local")
|
||||
})
|
||||
})
|
||||
|
||||
@@ -28,6 +28,10 @@ type NEREntity struct {
|
||||
Start int
|
||||
End int
|
||||
Score float32
|
||||
// Text is the matched substring as the detector saw it. Carried for
|
||||
// debug logging only (the persisted PIIEvent never stores the raw
|
||||
// value); the redactor re-slices the original text for masking.
|
||||
Text string
|
||||
}
|
||||
|
||||
// NERConfig configures the encoder tier for one redactor invocation.
|
||||
@@ -56,8 +60,22 @@ type NERConfig struct {
|
||||
// entities silently" — useful when the model returns a broad
|
||||
// taxonomy but the admin only cares about a subset.
|
||||
DefaultAction Action
|
||||
|
||||
// Source labels where this detector's hits come from. It becomes the
|
||||
// PatternID prefix on events and the [REDACTED:<id>] mask, so neural NER
|
||||
// detections (Source "ner") and deterministic pattern-matcher detections
|
||||
// (Source "pattern") are told apart in the events log and to the model.
|
||||
// Empty defaults to "ner" for backward compatibility.
|
||||
Source string
|
||||
}
|
||||
|
||||
// Source labels identify which kind of detector produced a hit. The label
// is used as the prefix of the synthetic pattern ID (see patternID), so it
// shows up in the events log and inside the [REDACTED:...] mask. Keep the
// values short and stable.
const (
	// SourceNER marks hits coming from a neural NER detector.
	SourceNER = "ner"
	// SourcePattern marks hits coming from a deterministic pattern matcher.
	SourcePattern = "pattern"
)
|
||||
|
||||
// ResolveAction returns the action configured for a detected entity
|
||||
// group, falling back to DefaultAction. Returns ("", false) when the
|
||||
// entity should be ignored entirely (no override + no default).
|
||||
@@ -71,13 +89,39 @@ func (c NERConfig) ResolveAction(group string) (Action, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// nerPatternID returns the synthetic pattern ID that audit rows carry
|
||||
// for NER hits. Prefixing with "ner:" keeps these distinguishable from
|
||||
// regex pattern IDs in the events tab and in filter queries; admins
|
||||
// can switch off a single entity type with the same Disabled-pattern
|
||||
// machinery used for regex.
|
||||
func nerPatternID(group string) string {
|
||||
return "ner:" + group
|
||||
// NERConfigFromRaw builds a typed NERConfig from a detector plus the raw
|
||||
// policy strings carried on a detector model's pii_detection config. An
|
||||
// empty or invalid default_action becomes ActionMask — the safe-by-default
|
||||
// policy for a PII filter (a detected entity is masked unless an admin
|
||||
// downgrades it). Unknown per-entity actions are dropped (and logged by
|
||||
// validActions). This is the single conversion point the application-layer
|
||||
// resolver uses, so the detector model's policy reaches the redactor in
|
||||
// exactly one shape. source labels the detector kind (SourceNER /
|
||||
// SourcePattern) and becomes the PatternID prefix; empty defaults to
|
||||
// SourceNER.
|
||||
func NERConfigFromRaw(detector NERDetector, minScore float32, defaultAction string, entityActions map[string]string, source string) NERConfig {
|
||||
if source == "" {
|
||||
source = SourceNER
|
||||
}
|
||||
return NERConfig{
|
||||
Detector: detector,
|
||||
MinScore: minScore,
|
||||
DefaultAction: validActionOr(defaultAction, ActionMask),
|
||||
EntityActions: validActions(entityActions),
|
||||
Source: source,
|
||||
}
|
||||
}
|
||||
|
||||
// patternID returns the synthetic pattern ID that audit rows and masks carry
|
||||
// for this detector's hits, e.g. "ner:EMAIL" or "pattern:ANTHROPIC_KEY". The
|
||||
// source prefix keeps neural and deterministic detections distinguishable in
|
||||
// the events tab and in pattern_id filter queries.
|
||||
func (c NERConfig) patternID(group string) string {
|
||||
source := c.Source
|
||||
if source == "" {
|
||||
source = SourceNER
|
||||
}
|
||||
return source + ":" + group
|
||||
}
|
||||
|
||||
// errNERDetector is a NERDetector that always returns the wrapped
|
||||
|
||||
@@ -9,8 +9,7 @@ import (
|
||||
)
|
||||
|
||||
// stubNERDetector returns a fixed slice of entities and tracks call
|
||||
// count so tests can assert the detector isn't called when text is
|
||||
// empty / no patterns / detector disabled.
|
||||
// count so tests can assert the detector isn't called when text is empty.
|
||||
type stubNERDetector struct {
|
||||
entities []NEREntity
|
||||
err error
|
||||
@@ -22,43 +21,39 @@ func (s *stubNERDetector) Detect(_ context.Context, _ string) ([]NEREntity, erro
|
||||
return s.entities, s.err
|
||||
}
|
||||
|
||||
var _ = Describe("RedactWithNER", func() {
|
||||
It("nil detector is regex-only", func() {
|
||||
// When the NER tier is disabled (Detector == nil) the redactor
|
||||
// must behave exactly like the existing regex-only path — no
|
||||
// detector call, same Result shape, no error.
|
||||
r := NewRedactor([]Pattern{pickEmail()})
|
||||
res, err := r.RedactWithNER(context.Background(), "ping me at alice@example.com", nil, NERConfig{})
|
||||
var _ = Describe("RedactNER", func() {
|
||||
It("no detectors is a no-op", func() {
|
||||
res, err := RedactNER(context.Background(), "ping me at alice@example.com", nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"), "regex tier should still run when Detector is nil")
|
||||
Expect(res.Redacted).To(Equal("ping me at alice@example.com"))
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("applies entity actions", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "PER", Start: 6, End: 11, Score: 0.95}, // "Alice" in "Hi I'm Alice today"
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Hi I'm Alice today", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Hi I'm Alice today", []NERConfig{{
|
||||
Detector: det,
|
||||
EntityActions: map[string]Action{"PER": ActionMask},
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(det.calls).To(Equal(1))
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:PER]"))
|
||||
Expect(res.Spans).To(HaveLen(1))
|
||||
Expect(res.Spans[0].Pattern).To(Equal("ner:PER"))
|
||||
Expect(res.Spans[0].Action).To(Equal(ActionMask))
|
||||
})
|
||||
|
||||
It("filters below MinScore", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "PER", Start: 0, End: 5, Score: 0.20},
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Alice", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Alice", []NERConfig{{
|
||||
Detector: det,
|
||||
MinScore: 0.50,
|
||||
EntityActions: map[string]Action{"PER": ActionMask},
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(Equal("Alice"), "low-confidence entity should be dropped")
|
||||
})
|
||||
@@ -67,108 +62,120 @@ var _ = Describe("RedactWithNER", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "ORG", Start: 7, End: 11, Score: 0.9}, // "Acme" in "Hello, Acme!"
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Hello, Acme!", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Hello, Acme!", []NERConfig{{
|
||||
Detector: det,
|
||||
DefaultAction: ActionMask,
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:ORG]"), "DefaultAction should apply to ORG")
|
||||
})
|
||||
|
||||
It("drops unconfigured groups with no default", func() {
|
||||
// EntityActions has no entry for ORG and DefaultAction is empty —
|
||||
// the detected entity must be ignored entirely (no audit row, no
|
||||
// redaction).
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "ORG", Start: 0, End: 4, Score: 0.9},
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Acme", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Acme", []NERConfig{{
|
||||
Detector: det,
|
||||
EntityActions: map[string]Action{"PER": ActionMask}, // ORG is unconfigured
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(Equal("Acme"))
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("overlapping hits keep stronger action", func() {
|
||||
// Regex marks 0..10 as mask; NER marks 5..15 as block. After
|
||||
// merge, the union 0..15 keeps the strongest action (block).
|
||||
pat := Pattern{ID: "test", Action: ActionMask, regex: rangeRegex(0, 10)}
|
||||
r := NewRedactor([]Pattern{pat})
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "PER", Start: 5, End: 15, Score: 0.9},
|
||||
}}
|
||||
It("unions multiple detectors and keeps the stronger action on overlap", func() {
|
||||
// Detector A marks 0..10 as mask; detector B marks 5..15 as block.
|
||||
// After merge, the union 0..15 keeps the strongest action (block).
|
||||
detA := &stubNERDetector{entities: []NEREntity{{Group: "A", Start: 0, End: 10, Score: 0.9}}}
|
||||
detB := &stubNERDetector{entities: []NEREntity{{Group: "B", Start: 5, End: 15, Score: 0.9}}}
|
||||
text := "0123456789ABCDEF"
|
||||
res, err := r.RedactWithNER(context.Background(), text, nil, NERConfig{
|
||||
Detector: det,
|
||||
EntityActions: map[string]Action{"PER": ActionBlock},
|
||||
res, err := RedactNER(context.Background(), text, []NERConfig{
|
||||
{Detector: detA, EntityActions: map[string]Action{"A": ActionMask}},
|
||||
{Detector: detB, EntityActions: map[string]Action{"B": ActionBlock}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(detA.calls).To(Equal(1))
|
||||
Expect(detB.calls).To(Equal(1))
|
||||
Expect(res.Blocked).To(BeTrue(), "overlapping mask+block should set Blocked=true")
|
||||
})
|
||||
|
||||
It("detector error returns regex result and error", func() {
|
||||
// Fail-open: when the NER detector errors, the redactor still
|
||||
// returns regex-tier hits so an offline NER backend doesn't strip
|
||||
// the cheap protection. Caller can read the error and decide
|
||||
// whether to surface it.
|
||||
det := &stubNERDetector{err: errors.New("backend offline")}
|
||||
r := NewRedactor([]Pattern{pickEmail()})
|
||||
res, err := r.RedactWithNER(context.Background(), "ping alice@example.com", nil, NERConfig{
|
||||
Detector: det,
|
||||
DefaultAction: ActionMask,
|
||||
It("returns a best-effort result and the error when a detector fails (fail-closed contract)", func() {
|
||||
// One healthy detector, one failing. RedactNER returns the healthy
|
||||
// detector's hits AND the error, so the caller can fail closed.
|
||||
good := &stubNERDetector{entities: []NEREntity{{Group: "PER", Start: 0, End: 5, Score: 0.9}}}
|
||||
bad := &stubNERDetector{err: errors.New("backend offline")}
|
||||
res, err := RedactNER(context.Background(), "Alice", []NERConfig{
|
||||
{Detector: good, DefaultAction: ActionMask},
|
||||
{Detector: bad, DefaultAction: ActionMask},
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "expected detector error to surface")
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"), "regex tier should still apply on NER failure")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:PER]"), "healthy detector's hits should still apply")
|
||||
})
|
||||
|
||||
It("out-of-bounds offsets are skipped", func() {
|
||||
// A misconfigured / buggy backend could return offsets past the
|
||||
// end of text. The redactor must not panic on slice OOB.
|
||||
It("skips out-of-bounds offsets without panicking", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "PER", Start: 0, End: 999, Score: 0.9},
|
||||
{Group: "PER", Start: -1, End: 3, Score: 0.9},
|
||||
{Group: "PER", Start: 5, End: 5, Score: 0.9}, // zero-length
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Alice", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Alice", []NERConfig{{
|
||||
Detector: det,
|
||||
DefaultAction: ActionMask,
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(Equal("Alice"))
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
// --- test helpers ---
|
||||
var _ = Describe("NERConfigFromRaw", func() {
|
||||
det := &stubNERDetector{}
|
||||
|
||||
// rangeMatcher is a deterministic regexpMatcher stub: it claims one
|
||||
// fixed range regardless of input. Lets the overlap-merge test
|
||||
// produce a known regex/NER intersection without depending on a real
|
||||
// compiled regex.
|
||||
type rangeMatcher struct{ start, end int }
|
||||
It("defaults an empty default_action to mask and an empty source to ner", func() {
|
||||
cfg := NERConfigFromRaw(det, 0.4, "", nil, "")
|
||||
Expect(cfg.DefaultAction).To(Equal(ActionMask))
|
||||
Expect(cfg.MinScore).To(BeNumerically("~", 0.4, 1e-6))
|
||||
Expect(cfg.Source).To(Equal(SourceNER))
|
||||
Expect(cfg.patternID("EMAIL")).To(Equal("ner:EMAIL"))
|
||||
})
|
||||
|
||||
func (m rangeMatcher) FindAllStringIndex(_ string, _ int) [][]int {
|
||||
return [][]int{{m.start, m.end}}
|
||||
}
|
||||
It("passes through valid actions and drops invalid ones", func() {
|
||||
cfg := NERConfigFromRaw(det, 0, "block", map[string]string{
|
||||
"PASSWORD": "block",
|
||||
"EMAIL": "mask",
|
||||
"BOGUS": "nonsense", // dropped
|
||||
}, SourceNER)
|
||||
Expect(cfg.DefaultAction).To(Equal(ActionBlock))
|
||||
Expect(cfg.EntityActions).To(HaveKeyWithValue("PASSWORD", ActionBlock))
|
||||
Expect(cfg.EntityActions).To(HaveKeyWithValue("EMAIL", ActionMask))
|
||||
Expect(cfg.EntityActions).NotTo(HaveKey("BOGUS"))
|
||||
})
|
||||
|
||||
func rangeRegex(start, end int) regexpMatcher { return rangeMatcher{start: start, end: end} }
|
||||
It("prefixes pattern-detector hits with the pattern source", func() {
|
||||
cfg := NERConfigFromRaw(det, 0, "mask", nil, SourcePattern)
|
||||
Expect(cfg.Source).To(Equal(SourcePattern))
|
||||
Expect(cfg.patternID("ANTHROPIC_KEY")).To(Equal("pattern:ANTHROPIC_KEY"))
|
||||
})
|
||||
})
|
||||
|
||||
// pickEmail returns the compiled "email" pattern from DefaultPatterns
|
||||
// — the NER tests use it as the regex tier's contribution.
|
||||
func pickEmail() Pattern {
|
||||
for _, p := range DefaultPatterns() {
|
||||
if p.ID == "email" {
|
||||
compiled, err := Compile([]Pattern{p})
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return compiled[0]
|
||||
}
|
||||
}
|
||||
Fail("email pattern missing from DefaultPatterns")
|
||||
return Pattern{}
|
||||
}
|
||||
var _ = Describe("NERConfig.ResolveAction", func() {
|
||||
It("prefers an explicit entity action over the default", func() {
|
||||
cfg := NERConfig{EntityActions: map[string]Action{"EMAIL": ActionBlock}, DefaultAction: ActionMask}
|
||||
a, ok := cfg.ResolveAction("EMAIL")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(a).To(Equal(ActionBlock))
|
||||
})
|
||||
|
||||
It("falls back to the default action", func() {
|
||||
cfg := NERConfig{DefaultAction: ActionMask}
|
||||
a, ok := cfg.ResolveAction("ANYTHING")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(a).To(Equal(ActionMask))
|
||||
})
|
||||
|
||||
It("ignores a group with no override and no default", func() {
|
||||
cfg := NERConfig{}
|
||||
_, ok := cfg.ResolveAction("ANYTHING")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// regexpMatcher is the minimal matching seam: anything that can report the
// byte index pairs of its matches in a string. It exists so a deterministic
// fake can stand in for a compiled regular expression.
type regexpMatcher interface {
	FindAllStringIndex(s string, n int) [][]int
}

// goRegexp adapts a compiled *regexp.Regexp to regexpMatcher.
type goRegexp struct{ r *regexp.Regexp }

// Compile-time check that goRegexp satisfies regexpMatcher.
var _ regexpMatcher = goRegexp{}

// FindAllStringIndex delegates to the wrapped regexp and returns its result
// unchanged; s and n have the same meaning as for
// (*regexp.Regexp).FindAllStringIndex.
func (g goRegexp) FindAllStringIndex(s string, n int) [][]int {
	locs := g.r.FindAllStringIndex(s, n)
	return locs
}
|
||||
|
||||
// DefaultPatterns returns the built-in regex set. Each entry includes
|
||||
// a conservative MaxMatchLength so the streaming filter can size its
|
||||
// tail buffer without re-parsing the regex at runtime.
|
||||
//
|
||||
// Caveats by design:
|
||||
// - The phone pattern matches international and US formats but does
|
||||
// not validate area codes. False positives on numbers that look
|
||||
// phone-like (e.g., timestamps in some formats) are accepted in
|
||||
// return for reliable coverage.
|
||||
// - The credit card pattern requires the Luhn check (verifyLuhn) to
|
||||
// reduce false positives — random 16-digit strings won't match.
|
||||
// - The API-key pattern targets common provider prefixes (sk-, pk-,
|
||||
// xoxb-, ghp_, github_pat_) rather than guessing entropy. Adding
|
||||
// new providers should append a new Pattern, not extend an
|
||||
// existing alternation, so the admin UI can show one row per
|
||||
// provider with its own toggle.
|
||||
func DefaultPatterns() []Pattern {
|
||||
return []Pattern{
|
||||
{
|
||||
ID: "email",
|
||||
Description: "Email address",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 254, // RFC 5321 max
|
||||
},
|
||||
{
|
||||
ID: "phone",
|
||||
Description: "Phone number (international or US format)",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 24,
|
||||
},
|
||||
{
|
||||
ID: "ssn",
|
||||
Description: "US Social Security Number (NNN-NN-NNNN)",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 11,
|
||||
},
|
||||
{
|
||||
ID: "credit_card",
|
||||
Description: "Credit card number (Luhn-verified)",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 19,
|
||||
},
|
||||
{
|
||||
ID: "ipv4",
|
||||
Description: "IPv4 address",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 15,
|
||||
},
|
||||
{
|
||||
ID: "api_key_prefix",
|
||||
Description: "Common API key prefixes (sk-, pk-, xoxb-, ghp_, github_pat_)",
|
||||
Action: ActionBlock, // tighter default — leaked credentials are higher harm
|
||||
MaxMatchLength: 200,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// patternRegexps maps a Pattern.ID to its compiled regex. The regexes live
// outside the Pattern struct so DefaultPatterns stays data-only; Compile
// looks them up here. All expressions are compiled once at package
// initialisation, and an invalid expression panics there (MustCompile).
var patternRegexps = func() map[string]*regexp.Regexp {
	sources := []struct{ id, expr string }{
		// Pragmatic email shape, not a full RFC 5322 implementation.
		{"email", `(?i)[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}`},
		// US: (123) 456-7890, 123-456-7890, 123.456.7890, 1234567890.
		// International: +<country>-<area>-<rest> with separators.
		{"phone", `(?:\+?\d{1,3}[\s\-.]?)?(?:\(\d{3}\)|\d{3})[\s\-.]?\d{3}[\s\-.]?\d{4}`},
		{"ssn", `\b\d{3}-\d{2}-\d{4}\b`},
		// 13-19 digits with optional space/dash separators; VerifyMatch
		// rejects candidates that fail the Luhn check.
		{"credit_card", `\b(?:\d[ \-]?){13,19}\b`},
		// Octets are 1-3 digits here; VerifyMatch enforces the 0..255 range.
		{"ipv4", `\b(?:\d{1,3}\.){3}\d{1,3}\b`},
		// One alternative per well-known provider prefix rather than a
		// permissive entropy match.
		{"api_key_prefix", `(?:sk-[A-Za-z0-9]{20,}|pk-[A-Za-z0-9]{20,}|xoxb-[A-Za-z0-9\-]{20,}|ghp_[A-Za-z0-9]{20,}|github_pat_[A-Za-z0-9_]{20,})`},
	}
	compiled := make(map[string]*regexp.Regexp, len(sources))
	for _, src := range sources {
		compiled[src.id] = regexp.MustCompile(src.expr)
	}
	return compiled
}()
|
||||
|
||||
// Compile attaches matchers to each pattern. Patterns whose ID is not
|
||||
// in patternRegexps are returned as a typed error so an admin who
|
||||
// adds a custom pattern via config gets a clear "no regex registered"
|
||||
// message instead of silent skip.
|
||||
func Compile(patterns []Pattern) ([]Pattern, error) {
|
||||
out := make([]Pattern, len(patterns))
|
||||
for i, p := range patterns {
|
||||
r, ok := patternRegexps[p.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("pii: no regex registered for pattern id %q", p.ID)
|
||||
}
|
||||
p.regex = goRegexp{r: r}
|
||||
out[i] = p
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// VerifyMatch applies pattern-specific post-checks (e.g. Luhn for
|
||||
// credit_card). Returns the original match or "" to discard it.
|
||||
func VerifyMatch(patternID, candidate string) string {
|
||||
switch patternID {
|
||||
case "credit_card":
|
||||
digits := stripNonDigits(candidate)
|
||||
if len(digits) < 13 || len(digits) > 19 {
|
||||
return ""
|
||||
}
|
||||
if !verifyLuhn(digits) {
|
||||
return ""
|
||||
}
|
||||
case "ipv4":
|
||||
// Each octet must be 0..255. The regex allows 0..999 since
|
||||
// regex isn't great at numeric ranges; we tighten here.
|
||||
for oct := range strings.SplitSeq(candidate, ".") {
|
||||
n := 0
|
||||
for _, c := range oct {
|
||||
if c < '0' || c > '9' {
|
||||
return ""
|
||||
}
|
||||
n = n*10 + int(c-'0')
|
||||
}
|
||||
if n > 255 {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
|
||||
// stripNonDigits returns s with every rune other than ASCII '0'..'9'
// removed, preserving the order of the digits that remain.
func stripNonDigits(s string) string {
	var digits strings.Builder
	digits.Grow(len(s)) // upper bound: the output is never longer than s
	for _, r := range s {
		if r < '0' || r > '9' {
			continue
		}
		digits.WriteRune(r)
	}
	return digits.String()
}
|
||||
|
||||
// verifyLuhn reports whether digits passes the Luhn checksum used by
// credit-card numbers. Starting from the rightmost digit, every second
// digit is doubled (subtracting 9 when the result exceeds 9) and the total
// must be a multiple of 10.
//
// digits must be a non-empty string of ASCII '0'..'9'. An empty string or
// any other byte returns false, where previously an empty string passed
// (sum 0) and a stray byte was folded into the sum.
func verifyLuhn(digits string) bool {
	if digits == "" {
		return false
	}
	sum := 0
	double := false
	for i := len(digits) - 1; i >= 0; i-- {
		c := digits[i]
		if c < '0' || c > '9' {
			return false
		}
		d := int(c - '0')
		if double {
			d *= 2
			if d > 9 {
				d -= 9
			}
		}
		sum += d
		double = !double
	}
	return sum%10 == 0
}
|
||||
|
||||
// MaxPatternLength returns the longest MaxMatchLength across the input
|
||||
// patterns. Used by the streaming filter to size its tail buffer.
|
||||
func MaxPatternLength(patterns []Pattern) int {
|
||||
max := 0
|
||||
for _, p := range patterns {
|
||||
if p.MaxMatchLength > max {
|
||||
max = p.MaxMatchLength
|
||||
}
|
||||
}
|
||||
return max
|
||||
}
|
||||
@@ -4,212 +4,152 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// rawHit is one detection — regex-side or NER-side — before
|
||||
// overlap-merging. Lifted to file scope so the regex and NER
|
||||
// collectors can both produce them and feed the same merge/emit step.
|
||||
// rawHit is one detection before overlap-merging. Lifted to file scope so
|
||||
// the NER collector and the merge/emit step can share it.
|
||||
type rawHit struct {
|
||||
patternID string
|
||||
action Action
|
||||
start int
|
||||
end int
|
||||
score float32
|
||||
}
|
||||
|
||||
// Redactor scans text against a configured pattern set and applies the
|
||||
// per-pattern action. The pattern set itself is mutable at runtime via
|
||||
// SetAction (the /api/pii/patterns/:id admin endpoint mutates it
|
||||
// in-place); reads are guarded by a mutex so concurrent requests stay
|
||||
// race-free.
|
||||
type Redactor struct {
|
||||
mu sync.RWMutex
|
||||
patterns []Pattern
|
||||
maxLen int
|
||||
}
|
||||
// Redactor is a stateless handle for the PII subsystem. The regex tier
|
||||
// was removed: detection is driven entirely by per-model NER detectors
|
||||
// (see RedactNER), whose policy lives on each detector model's
|
||||
// pii_detection config. The type is retained (zero-field) as the
|
||||
// on/off sentinel the application wiring and middleware gate on, so a
|
||||
// nil *Redactor still means "PII subsystem unavailable".
|
||||
type Redactor struct{}
|
||||
|
||||
// NewRedactor constructs a redactor from a list of compiled patterns
|
||||
// (use Compile() to compile config-loaded patterns first). nil
|
||||
// patterns is valid and produces a no-op redactor — convenient for the
|
||||
// "PII disabled" deployment.
|
||||
func NewRedactor(patterns []Pattern) *Redactor {
|
||||
return &Redactor{
|
||||
patterns: patterns,
|
||||
maxLen: MaxPatternLength(patterns),
|
||||
}
|
||||
}
|
||||
|
||||
// MaxPatternLength is exposed so the streaming wrapper can size its
|
||||
// tail buffer to match.
|
||||
func (r *Redactor) MaxPatternLength() int { return r.maxLen }
|
||||
|
||||
// Patterns returns a copy of the configured pattern set so callers can
|
||||
// iterate without holding the redactor lock. The compiled regexes are
|
||||
// shared — they are immutable once built.
|
||||
func (r *Redactor) Patterns() []Pattern {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return slices.Clone(r.patterns)
|
||||
}
|
||||
|
||||
// SetAction overrides the action for a single pattern. Used by the
|
||||
// /api/pii/patterns/:id admin endpoint and the set_pii_pattern_action
|
||||
// MCP tool — transient until process restart unless persisted via
|
||||
// --pii-config.
|
||||
// RedactNER runs every configured NER detector over text, unions their
|
||||
// detections, and emits one redacted output. Each NERConfig carries its
|
||||
// own detector and policy (min score, entity→action map, default
|
||||
// action), so a consuming model that references several detector models
|
||||
// gets each model's policy applied to its own hits before the overlap
|
||||
// merge (block > mask > allow) resolves any span two detectors both
|
||||
// claim.
|
||||
//
|
||||
// Publishes a new slice so concurrent Redact callers iterating an
|
||||
// older snapshot don't race on the per-element Action string (Go
|
||||
// strings are not atomic two-word values).
|
||||
func (r *Redactor) SetAction(id string, action Action) error {
|
||||
if action != ActionMask && action != ActionBlock && action != ActionAllow {
|
||||
return fmt.Errorf("unknown action %q (must be mask, block, or allow)", action)
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for i := range r.patterns {
|
||||
if r.patterns[i].ID == id {
|
||||
next := slices.Clone(r.patterns)
|
||||
next[i].Action = action
|
||||
r.patterns = next
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unknown pattern id %q", id)
|
||||
}
|
||||
|
||||
// SetDisabled toggles a pattern's enabled state in the live redactor.
|
||||
// Same COW publish as SetAction.
|
||||
func (r *Redactor) SetDisabled(id string, disabled bool) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for i := range r.patterns {
|
||||
if r.patterns[i].ID == id {
|
||||
next := slices.Clone(r.patterns)
|
||||
next[i].Disabled = disabled
|
||||
r.patterns = next
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unknown pattern id %q", id)
|
||||
}
|
||||
|
||||
// Redact is a thin wrapper for callers that don't need per-request
|
||||
// action overrides. It applies each pattern's compiled-in default
|
||||
// action.
|
||||
func (r *Redactor) Redact(text string) Result {
|
||||
return r.RedactWithOverrides(text, nil)
|
||||
}
|
||||
|
||||
// RedactWithOverrides scans text and returns the result. The override
|
||||
// map is keyed by pattern id; when present, the value replaces the
|
||||
// pattern's compiled-in action for this call only — the redactor's
|
||||
// stored action is unchanged. Pattern ids missing from the map use
|
||||
// their stored action.
|
||||
// Any detector error is returned alongside a best-effort Result built
|
||||
// from the detectors that did succeed, so the caller can fail closed
|
||||
// (refuse the request) while still seeing what the healthy detectors
|
||||
// found. Configs with a nil Detector are skipped.
|
||||
//
|
||||
// For every match it records a Span (with HashPrefix, never the value)
|
||||
// and applies the resolved Action:
|
||||
// - block: sets Result.Blocked, leaves text intact (caller decides
|
||||
// whether to surface the redacted form).
|
||||
// - mask: replaces the span with maskFor(pattern.ID), sets Result.Masked.
|
||||
// - allow: leaves text intact and sets no flag (the span is still
|
||||
// recorded so the match is auditable).
|
||||
//
|
||||
// Spans are returned in the original input's coordinate system so the
|
||||
// PIIEvent record can be written without re-running the scan.
|
||||
func (r *Redactor) RedactWithOverrides(text string, overrides map[string]Action) Result {
|
||||
return r.redact(context.Background(), text, overrides, NERConfig{})
|
||||
}
|
||||
|
||||
// RedactWithNER is the encoder-tier variant: runs both the regex tier
|
||||
// (with per-pattern overrides) and the NER tier, merges hits, and
|
||||
// emits one redacted output. A nil NERConfig.Detector skips the NER
|
||||
// pass — callers can hand the same path the same NERConfig{} whether
|
||||
// or not the model has NER configured.
|
||||
//
|
||||
// Errors from the NER detector are returned alongside a best-effort
|
||||
// regex-only Result so the caller can decide whether to fail open
|
||||
// (return the regex Result, log the error) or fail closed (refuse the
|
||||
// request). The regex tier never errors.
|
||||
func (r *Redactor) RedactWithNER(ctx context.Context, text string, overrides map[string]Action, nerCfg NERConfig) (Result, error) {
|
||||
if nerCfg.Detector == nil {
|
||||
return r.redact(ctx, text, overrides, nerCfg), nil
|
||||
}
|
||||
hits, err := r.collectRegexHits(text, overrides)
|
||||
if err != nil {
|
||||
return Result{Redacted: text}, err
|
||||
}
|
||||
nerHits, nerErr := collectNERHits(ctx, text, nerCfg)
|
||||
if nerErr != nil {
|
||||
// Return the regex-only result so a NER-backend outage doesn't
|
||||
// strip the cheap protection. Caller decides fail-open vs
|
||||
// fail-closed via the returned error.
|
||||
return mergeAndEmit(text, hits), nerErr
|
||||
}
|
||||
return mergeAndEmit(text, append(hits, nerHits...)), nil
|
||||
}
|
||||
|
||||
// redact is the internal regex-only entry point. RedactWithOverrides
|
||||
// is the public wrapper; RedactWithNER routes through here only when
|
||||
// the NER detector is nil (so the call site doesn't need a separate
|
||||
// "regex-only" code path).
|
||||
func (r *Redactor) redact(_ context.Context, text string, overrides map[string]Action, _ NERConfig) Result {
|
||||
hits, _ := r.collectRegexHits(text, overrides)
|
||||
return mergeAndEmit(text, hits)
|
||||
}
|
||||
|
||||
// collectRegexHits walks the configured pattern set against text and
|
||||
// returns each verified match as a rawHit. The redactor lock is held
|
||||
// only long enough to snapshot the pattern slice — regex evaluation
|
||||
// runs lock-free against the snapshot, so SetAction/SetDisabled don't
|
||||
// stall a long-running Redact.
|
||||
func (r *Redactor) collectRegexHits(text string, overrides map[string]Action) ([]rawHit, error) {
|
||||
r.mu.RLock()
|
||||
patterns := r.patterns
|
||||
r.mu.RUnlock()
|
||||
|
||||
if len(patterns) == 0 || text == "" {
|
||||
return nil, nil
|
||||
// Package-level (no Redactor state): both the in-band request middleware
|
||||
// and the MITM request path call it with their own resolved []NERConfig.
|
||||
func RedactNER(ctx context.Context, text string, cfgs []NERConfig) (Result, error) {
|
||||
if text == "" || len(cfgs) == 0 {
|
||||
return Result{Redacted: text}, nil
|
||||
}
|
||||
var hits []rawHit
|
||||
for _, p := range patterns {
|
||||
if p.regex == nil {
|
||||
// Pattern declared but Compile() not called. Skip rather
|
||||
// than panic; the caller already saw an error from Compile.
|
||||
var firstErr error
|
||||
for _, cfg := range cfgs {
|
||||
if cfg.Detector == nil {
|
||||
continue
|
||||
}
|
||||
if p.Disabled {
|
||||
h, err := collectNERHits(ctx, text, cfg)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
action := p.Action
|
||||
if override, ok := overrides[p.ID]; ok {
|
||||
action = override
|
||||
hits = append(hits, h...)
|
||||
}
|
||||
return mergeAndEmit(text, hits), firstErr
|
||||
}
|
||||
|
||||
// segmentSeparator joins per-message texts into the single document
|
||||
// RedactNERSegments scans. Two newlines read as a paragraph break to the
|
||||
// NER encoder — neutral, in-distribution context — and never carry PII
|
||||
// themselves, so a detected span landing on the separator can only be the
|
||||
// fringe of an entity that started in a real segment.
|
||||
const segmentSeparator = "\n\n"
|
||||
|
||||
// RedactNERSegments scans texts as ONE concatenated document and maps the
|
||||
// detections back to one Result per input text. Scanning the segments
|
||||
// together is what gives the NER tier conversational context: whether
|
||||
// "jdoe_42" is a USERNAME or "4421" is a PIN is decided by the question
|
||||
// asked in the *previous* message, and a bidirectional encoder only sees
|
||||
// that context if both messages are in the same forward pass. (Measured on
|
||||
// privacy-filter-multilingual: "4421" alone detects nothing; preceded by
|
||||
// "What are the last four digits of your card?" it detects PIN at 0.726.)
|
||||
//
|
||||
// Span offsets in each Result are local to its text, so callers rewrite
|
||||
// fields in place exactly as with per-text RedactNER. A hit that crosses a
|
||||
// segment boundary is split and each fragment keeps the hit's action —
|
||||
// conservative, and only possible for an entity the model stretched across
|
||||
// the separator. Error semantics mirror RedactNER: best-effort results
|
||||
// plus the first detector error, so callers can fail closed.
|
||||
func RedactNERSegments(ctx context.Context, texts []string, cfgs []NERConfig) ([]Result, error) {
|
||||
results := make([]Result, len(texts))
|
||||
if len(texts) == 0 || len(cfgs) == 0 {
|
||||
for i := range results {
|
||||
results[i] = Result{Redacted: texts[i]}
|
||||
}
|
||||
idxs := p.regex.FindAllStringIndex(text, -1)
|
||||
for _, idx := range idxs {
|
||||
candidate := text[idx[0]:idx[1]]
|
||||
if VerifyMatch(p.ID, candidate) == "" {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
var joined strings.Builder
|
||||
starts := make([]int, len(texts))
|
||||
ends := make([]int, len(texts))
|
||||
for i, t := range texts {
|
||||
if i > 0 {
|
||||
joined.WriteString(segmentSeparator)
|
||||
}
|
||||
starts[i] = joined.Len()
|
||||
joined.WriteString(t)
|
||||
ends[i] = joined.Len()
|
||||
}
|
||||
doc := joined.String()
|
||||
|
||||
var hits []rawHit
|
||||
var firstErr error
|
||||
for _, cfg := range cfgs {
|
||||
if cfg.Detector == nil {
|
||||
continue
|
||||
}
|
||||
h, err := collectNERHits(ctx, doc, cfg)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
hits = append(hits, h...)
|
||||
}
|
||||
|
||||
perSegment := make([][]rawHit, len(texts))
|
||||
for _, h := range hits {
|
||||
for i := range texts {
|
||||
s := max(h.start, starts[i])
|
||||
e := min(h.end, ends[i])
|
||||
if s >= e {
|
||||
continue
|
||||
}
|
||||
hits = append(hits, rawHit{
|
||||
patternID: p.ID,
|
||||
action: action,
|
||||
start: idx[0],
|
||||
end: idx[1],
|
||||
})
|
||||
local := h
|
||||
local.start = s - starts[i]
|
||||
local.end = e - starts[i]
|
||||
perSegment[i] = append(perSegment[i], local)
|
||||
}
|
||||
}
|
||||
return hits, nil
|
||||
for i := range texts {
|
||||
results[i] = mergeAndEmit(texts[i], perSegment[i])
|
||||
}
|
||||
return results, firstErr
|
||||
}
|
||||
|
||||
// collectNERHits invokes the configured NERDetector and converts each
|
||||
// returned entity into a rawHit using the NERConfig's action map.
|
||||
// Entities below MinScore or with no resolved action are dropped — the
|
||||
// detector doesn't know which entity groups the admin cares about, so
|
||||
// the redactor filters here.
|
||||
// the policy filters here.
|
||||
func collectNERHits(ctx context.Context, text string, cfg NERConfig) ([]rawHit, error) {
|
||||
if cfg.Detector == nil || text == "" {
|
||||
return nil, nil
|
||||
@@ -220,42 +160,58 @@ func collectNERHits(ctx context.Context, text string, cfg NERConfig) ([]rawHit,
|
||||
}
|
||||
var hits []rawHit
|
||||
for _, e := range entities {
|
||||
// One DEBUG line per raw detection with the model's confidence, the
|
||||
// byte range, the matched substring, and the policy decision. This is
|
||||
// the lowest-level view of why a request was masked/blocked — e.g. a
|
||||
// phone number scored as SSN — and answers "what was in that range and
|
||||
// how sure was the model" without re-running the detector. DEBUG-gated
|
||||
// because the matched value is sensitive.
|
||||
if e.Score < cfg.MinScore {
|
||||
xlog.Debug("pii/ner: detection dropped (below min score)",
|
||||
"group", e.Group, "score", e.Score, "min_score", cfg.MinScore,
|
||||
"start", e.Start, "end", e.End, "text", e.Text)
|
||||
continue
|
||||
}
|
||||
action, ok := cfg.ResolveAction(e.Group)
|
||||
if !ok {
|
||||
xlog.Debug("pii/ner: detection ignored (no action for group)",
|
||||
"group", e.Group, "score", e.Score,
|
||||
"start", e.Start, "end", e.End, "text", e.Text)
|
||||
continue
|
||||
}
|
||||
if e.Start < 0 || e.End <= e.Start || e.End > len(text) {
|
||||
// Defensive: the backend should return byte offsets into
|
||||
// the original text, but a misconfigured model could
|
||||
// produce garbage. Skip rather than panic on slice OOB.
|
||||
// Defensive: the backend should return byte offsets into the
|
||||
// original text, but a misconfigured model could produce
|
||||
// garbage. Skip rather than panic on slice OOB.
|
||||
xlog.Warn("pii/ner: detection has out-of-range offsets; skipping",
|
||||
"group", e.Group, "start", e.Start, "end", e.End, "text_len", len(text))
|
||||
continue
|
||||
}
|
||||
xlog.Debug("pii/ner: detection accepted",
|
||||
"group", e.Group, "score", e.Score, "action", action,
|
||||
"start", e.Start, "end", e.End, "text", e.Text)
|
||||
hits = append(hits, rawHit{
|
||||
patternID: nerPatternID(e.Group),
|
||||
patternID: cfg.patternID(e.Group),
|
||||
action: action,
|
||||
start: e.Start,
|
||||
end: e.End,
|
||||
score: e.Score,
|
||||
})
|
||||
}
|
||||
return hits, nil
|
||||
}
|
||||
|
||||
// mergeAndEmit handles the overlap-merge + masked-output step that
|
||||
// regex-only and combined regex+NER redactions both perform. Sorts by
|
||||
// mergeAndEmit handles the overlap-merge + masked-output step. Sorts by
|
||||
// start (stable on equal starts by descending action strength), drops
|
||||
// overlapping hits in favour of the stronger action, and walks the
|
||||
// text once to emit replacement spans.
|
||||
// overlapping hits in favour of the stronger action, and walks the text
|
||||
// once to emit replacement spans.
|
||||
func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
if len(hits) == 0 {
|
||||
return Result{Redacted: text}
|
||||
}
|
||||
// Sort and deduplicate overlapping hits — when two patterns claim
|
||||
// the same span (e.g., a credit-card-shaped value also scans as
|
||||
// digits, or NER tags a span the regex also caught), keep the one
|
||||
// with the strongest action. Order: block > mask > allow.
|
||||
// Sort and deduplicate overlapping hits — when two detectors claim
|
||||
// the same span, keep the one with the strongest action. Order:
|
||||
// block > mask > allow.
|
||||
sort.Slice(hits, func(i, j int) bool {
|
||||
if hits[i].start != hits[j].start {
|
||||
return hits[i].start < hits[j].start
|
||||
@@ -270,6 +226,7 @@ func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
if actionRank(h.action) > actionRank(last.action) {
|
||||
last.action = h.action
|
||||
last.patternID = h.patternID
|
||||
last.score = h.score
|
||||
}
|
||||
if h.end > last.end {
|
||||
last.end = h.end
|
||||
@@ -291,6 +248,8 @@ func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
End: h.end,
|
||||
Pattern: h.patternID,
|
||||
HashPrefix: hashPrefix(matched),
|
||||
Action: h.action,
|
||||
Score: h.score,
|
||||
}
|
||||
res.Spans = append(res.Spans, span)
|
||||
|
||||
@@ -315,17 +274,15 @@ func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
|
||||
// maskFor returns the placeholder that replaces a matched span. The
|
||||
// shape "[REDACTED:<id>]" is intentionally stable — it surfaces the
|
||||
// pattern id back to the model, which is sometimes useful (e.g., the
|
||||
// model can say "I see you redacted an email"). Admins who want a
|
||||
// less informative replacement can build one in front of this.
|
||||
// detector group back to the model (e.g. "I see you redacted an email").
|
||||
func maskFor(patternID string) string {
	// The placeholder is the pattern id wrapped in a fixed prefix/suffix.
	const (
		prefix = "[REDACTED:"
		suffix = "]"
	)
	return prefix + patternID + suffix
}
|
||||
|
||||
// hashPrefix returns the first 8 chars of sha256(value). Two calls
|
||||
// with the same input produce the same prefix so an admin auditing
|
||||
// the PIIEvent log can spot a recurring leak ("the same SSN appears
|
||||
// 200 times this hour") without ever recovering the value.
|
||||
// hashPrefix returns the first 8 chars of sha256(value). Two calls with
|
||||
// the same input produce the same prefix so an admin auditing the
|
||||
// PIIEvent log can spot a recurring leak without ever recovering the
|
||||
// value.
|
||||
func hashPrefix(value string) string {
|
||||
sum := sha256.Sum256([]byte(value))
|
||||
return hex.EncodeToString(sum[:])[:8]
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Redactor_SetActionConcurrentRedact pins the SetAction copy-on-
|
||||
// write contract: concurrent SetAction must not race with readers
|
||||
// iterating an older patterns snapshot. Run with -race to surface the
|
||||
// regression that motivated the COW (in-place mutation of the
|
||||
// per-element Action string is not atomic).
|
||||
var _ = Describe("Redactor", func() {
|
||||
It("SetAction concurrent with Redact", func() {
|
||||
patterns, err := Compile(DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred(), "compile")
|
||||
r := NewRedactor(patterns)
|
||||
|
||||
const writers = 4
|
||||
const readers = 8
|
||||
const iter = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
for w := 0; w < writers; w++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iter; i++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
action := ActionMask
|
||||
if i%2 == 0 {
|
||||
action = ActionBlock
|
||||
}
|
||||
_ = r.SetAction("email", action)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for rd := 0; rd < readers; rd++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
text := "contact alice@example.com please"
|
||||
for i := 0; i < iter*2; i++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
_ = r.Redact(text)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(stop)
|
||||
})
|
||||
})
|
||||
@@ -1,186 +1,182 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func mustCompile(ids ...string) []Pattern {
|
||||
all := DefaultPatterns()
|
||||
if len(ids) == 0 {
|
||||
out, err := Compile(all)
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return out
|
||||
}
|
||||
pickP := pick(all, ids)
|
||||
out, err := Compile(pickP)
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return out
|
||||
// oneShot builds a single-detector []NERConfig whose stub detector reports
// one entity spanning [start, end) under the given group, mapped to action.
|
||||
func oneShot(group string, action Action, start, end int) []NERConfig {
|
||||
return []NERConfig{{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: group, Start: start, End: end, Score: 1}}},
|
||||
EntityActions: map[string]Action{group: action},
|
||||
}}
|
||||
}
|
||||
|
||||
func pick(all []Pattern, ids []string) []Pattern {
|
||||
keep := map[string]bool{}
|
||||
for _, id := range ids {
|
||||
keep[id] = true
|
||||
}
|
||||
var out []Pattern
|
||||
for _, p := range all {
|
||||
if keep[p.ID] {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
var _ = Describe("RedactNER emission", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
var _ = Describe("Redactor", func() {
|
||||
It("masks email", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
res := r.Redact("Contact me at alice@example.com any time.")
|
||||
Expect(res.Blocked).To(BeFalse(), "email is mask-action by default, should not block")
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"))
|
||||
It("masks with a [REDACTED:ner:GROUP] placeholder and records a hash prefix", func() {
|
||||
res, err := RedactNER(ctx, "Contact me at alice@example.com any time.", oneShot("EMAIL", ActionMask, 14, 31))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Masked).To(BeTrue())
|
||||
Expect(res.Blocked).To(BeFalse())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:EMAIL]"))
|
||||
Expect(res.Redacted).NotTo(ContainSubstring("alice@example.com"))
|
||||
Expect(res.Spans).To(HaveLen(1))
|
||||
Expect(res.Spans[0].HashPrefix).NotTo(BeEmpty(), "hash prefix must be set so audits can dedupe leaks")
|
||||
})
|
||||
|
||||
It("masks SSN", func() {
|
||||
r := NewRedactor(mustCompile("ssn"))
|
||||
res := r.Redact("call me about SSN 123-45-6789 please")
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ssn]"))
|
||||
It("labels pattern-detector hits with the pattern source, not ner", func() {
|
||||
cfgs := []NERConfig{{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: "ANTHROPIC_KEY", Start: 4, End: 24, Score: 1}}},
|
||||
EntityActions: map[string]Action{"ANTHROPIC_KEY": ActionMask},
|
||||
Source: SourcePattern,
|
||||
}}
|
||||
res, err := RedactNER(ctx, "use sk-ant-aaaaaaaaaaaaaaaa now", cfgs)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:pattern:ANTHROPIC_KEY]"))
|
||||
Expect(res.Redacted).NotTo(ContainSubstring("[REDACTED:ner:"))
|
||||
Expect(res.Spans).To(HaveLen(1))
|
||||
Expect(res.Spans[0].Pattern).To(Equal("pattern:ANTHROPIC_KEY"))
|
||||
})
|
||||
|
||||
It("uses Luhn for credit card", func() {
|
||||
r := NewRedactor(mustCompile("credit_card"))
|
||||
|
||||
// 4111 1111 1111 1111 — canonical Luhn-valid Visa test number.
|
||||
good := r.Redact("card: 4111 1111 1111 1111")
|
||||
Expect(good.Spans).To(HaveLen(1))
|
||||
Expect(good.Redacted).To(ContainSubstring("[REDACTED:credit_card]"))
|
||||
|
||||
// 4111 1111 1111 1112 — same shape, fails Luhn. Must NOT match.
|
||||
bad := r.Redact("card: 4111 1111 1111 1112")
|
||||
Expect(bad.Spans).To(BeEmpty(), "Luhn-invalid 16-digit run must not be redacted")
|
||||
Expect(bad.Redacted).To(ContainSubstring("1112"), "Luhn-invalid input should pass through untouched")
|
||||
It("block leaves the matched span intact and sets Blocked", func() {
|
||||
res, err := RedactNER(ctx, "token sk-abcdef here", oneShot("PASSWORD", ActionBlock, 6, 15))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Blocked).To(BeTrue())
|
||||
Expect(res.Redacted).To(ContainSubstring("sk-abcdef"), "block leaves the value intact for the caller to discard")
|
||||
Expect(res.Spans[0].Action).To(Equal(ActionBlock))
|
||||
})
|
||||
|
||||
It("validates IPv4 octets", func() {
|
||||
r := NewRedactor(mustCompile("ipv4"))
|
||||
|
||||
good := r.Redact("server at 192.168.1.10 is up")
|
||||
Expect(good.Spans).To(HaveLen(1))
|
||||
|
||||
// 999.999.999.999 — regex matches but octet > 255 must reject.
|
||||
bad := r.Redact("not an ip: 999.999.999.999")
|
||||
Expect(bad.Spans).To(BeEmpty(), "ipv4 with octet>255 must not match")
|
||||
})
|
||||
|
||||
It("api_key defaults to block", func() {
|
||||
r := NewRedactor(mustCompile("api_key_prefix"))
|
||||
res := r.Redact("here's a token sk-abcdefghijklmnopqrstuvwxyz0123456789 to use")
|
||||
Expect(res.Blocked).To(BeTrue(), "api_key default action is block; Result.Blocked must be true")
|
||||
// The redacted output keeps the matched value when blocking — the
|
||||
// caller is expected to refuse the request, not to forward a partial.
|
||||
Expect(res.Redacted).To(ContainSubstring("sk-abcdefghijklmn"), "blocked actions leave the matched span intact for caller inspection")
|
||||
})
|
||||
|
||||
It("preserves non-matching text", func() {
|
||||
r := NewRedactor(mustCompile()) // all default patterns
|
||||
in := "no PII here at all, just words and numbers like 42 and 1.5"
|
||||
res := r.Redact(in)
|
||||
Expect(res.Redacted).To(Equal(in), "non-PII input should pass through unchanged")
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("handles empty input", func() {
|
||||
r := NewRedactor(mustCompile())
|
||||
res := r.Redact("")
|
||||
Expect(res.Redacted).To(BeEmpty())
|
||||
Expect(res.Blocked).To(BeFalse())
|
||||
It("allow leaves text intact but still records the span", func() {
|
||||
res, err := RedactNER(ctx, "Hello Acme!", oneShot("ORG", ActionAllow, 6, 10))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Masked).To(BeFalse())
|
||||
Expect(res.Blocked).To(BeFalse())
|
||||
Expect(res.Redacted).To(Equal("Hello Acme!"))
|
||||
Expect(res.Spans).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("passes non-matching text through unchanged", func() {
|
||||
det := &stubNERDetector{} // no entities
|
||||
res, err := RedactNER(ctx, "no PII here, just words", []NERConfig{{Detector: det, DefaultAction: ActionMask}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(Equal("no PII here, just words"))
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("nil patterns is a no-op", func() {
|
||||
// Disabled-PII deployment: pii.NewRedactor(nil) is a no-op.
|
||||
r := NewRedactor(nil)
|
||||
res := r.Redact("alice@example.com sent it")
|
||||
Expect(res.Redacted).To(Equal("alice@example.com sent it"))
|
||||
It("handles empty input without calling the detector", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{{Group: "X", Start: 0, End: 1, Score: 1}}}
|
||||
res, err := RedactNER(ctx, "", []NERConfig{{Detector: det, DefaultAction: ActionMask}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(BeEmpty())
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
Expect(det.calls).To(Equal(0))
|
||||
})
|
||||
|
||||
It("hash prefix is stable", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
a := r.Redact("a@b.com")
|
||||
b := r.Redact("hi a@b.com again")
|
||||
It("produces a stable hash prefix for the same matched value", func() {
|
||||
a, _ := RedactNER(ctx, "a@b.com", oneShot("EMAIL", ActionMask, 0, 7))
|
||||
b, _ := RedactNER(ctx, "hi a@b.com", oneShot("EMAIL", ActionMask, 3, 10))
|
||||
Expect(a.Spans).To(HaveLen(1))
|
||||
Expect(b.Spans).To(HaveLen(1))
|
||||
Expect(a.Spans[0].HashPrefix).To(Equal(b.Spans[0].HashPrefix), "same matched value must produce same hash prefix")
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Compile", func() {
|
||||
It("rejects unknown pattern id", func() {
|
||||
_, err := Compile([]Pattern{{ID: "nonexistent", Action: ActionMask}})
|
||||
Expect(err).To(HaveOccurred(), "Compile must error on unknown pattern id")
|
||||
// funcNERDetector computes entities from the text it is handed — used to
|
||||
// prove the segment scan gives the detector the JOINED document, the way a
|
||||
// context-sensitive encoder behaves.
|
||||
type funcNERDetector struct {
|
||||
fn func(text string) ([]NEREntity, error)
|
||||
}
|
||||
|
||||
func (f *funcNERDetector) Detect(_ context.Context, text string) ([]NEREntity, error) {
|
||||
return f.fn(text)
|
||||
}
|
||||
|
||||
// pinAfterCard mimics the real encoder's context sensitivity: "4421" is a
|
||||
// PIN only when "card" appears earlier in the same document (measured on
|
||||
// privacy-filter-multilingual: alone it detects nothing, with the eliciting
|
||||
// question it detects PIN).
|
||||
func pinAfterCard(text string) ([]NEREntity, error) {
|
||||
i := strings.Index(text, "4421")
|
||||
if i < 0 || !strings.Contains(text[:i], "card") {
|
||||
return nil, nil
|
||||
}
|
||||
return []NEREntity{{Group: "PIN", Start: i, End: i + 4, Score: 0.9}}, nil
|
||||
}
|
||||
|
||||
var _ = Describe("RedactNERSegments", func() {
|
||||
ctx := context.Background()
|
||||
maskCfg := func(d NERDetector) []NERConfig {
|
||||
return []NERConfig{{Detector: d, DefaultAction: ActionMask}}
|
||||
}
|
||||
|
||||
It("scans segments as one document so context crosses messages", func() {
|
||||
det := &funcNERDetector{fn: pinAfterCard}
|
||||
|
||||
// Scanned alone the digits are invisible...
|
||||
alone, err := RedactNER(ctx, "it is 4421 ok", maskCfg(det))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(alone.Spans).To(BeEmpty())
|
||||
|
||||
// ...as a segment after the eliciting question they are detected,
|
||||
// and the span maps back to the second segment with local offsets.
|
||||
res, err := RedactNERSegments(ctx,
|
||||
[]string{"What are the last four digits of your card?", "it is 4421 ok"},
|
||||
maskCfg(det))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res).To(HaveLen(2))
|
||||
Expect(res[0].Spans).To(BeEmpty())
|
||||
Expect(res[0].Redacted).To(Equal("What are the last four digits of your card?"))
|
||||
Expect(res[1].Spans).To(HaveLen(1))
|
||||
Expect(res[1].Spans[0].Start).To(Equal(6))
|
||||
Expect(res[1].Spans[0].End).To(Equal(10))
|
||||
Expect(res[1].Masked).To(BeTrue())
|
||||
Expect(res[1].Redacted).To(Equal("it is [REDACTED:ner:PIN] ok"))
|
||||
})
|
||||
|
||||
It("splits a hit crossing a segment boundary, masking both fragments", func() {
|
||||
det := &funcNERDetector{fn: func(text string) ([]NEREntity, error) {
|
||||
i := strings.Index(text, "22 Baker")
|
||||
j := strings.Index(text, "Street")
|
||||
if i < 0 || j < 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []NEREntity{{Group: "STREET", Start: i, End: j + len("Street"), Score: 0.9}}, nil
|
||||
}}
|
||||
res, err := RedactNERSegments(ctx, []string{"22 Baker", "Street"}, maskCfg(det))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res[0].Redacted).To(Equal("[REDACTED:ner:STREET]"))
|
||||
Expect(res[1].Redacted).To(Equal("[REDACTED:ner:STREET]"))
|
||||
})
|
||||
|
||||
It("returns best-effort results with the first detector error", func() {
|
||||
bad := NERConfig{Detector: &stubNERDetector{err: errors.New("backend down")}, DefaultAction: ActionMask}
|
||||
good := NERConfig{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: "PER", Start: 0, End: 5, Score: 0.9}}},
|
||||
DefaultAction: ActionMask,
|
||||
}
|
||||
res, err := RedactNERSegments(ctx, []string{"Alice", "rest"}, []NERConfig{bad, good})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(res[0].Spans).To(HaveLen(1), "healthy detector's hits still apply")
|
||||
})
|
||||
|
||||
It("is a per-text no-op without detectors or texts", func() {
|
||||
res, err := RedactNERSegments(ctx, []string{"a", ""}, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res).To(HaveLen(2))
|
||||
Expect(res[0].Redacted).To(Equal("a"))
|
||||
Expect(res[1].Redacted).To(Equal(""))
|
||||
|
||||
res, err = RedactNERSegments(ctx, nil, maskCfg(&stubNERDetector{}))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("MaxPatternLength", func() {
|
||||
It("returns the longest pattern's max length", func() {
|
||||
patterns := mustCompile("email", "ssn")
|
||||
got := MaxPatternLength(patterns)
|
||||
// email is the longer of the two (254). The streaming filter
|
||||
// will use this to size its tail buffer.
|
||||
Expect(got).To(Equal(254))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("RedactWithOverrides", func() {
|
||||
It("upgrades action", func() {
|
||||
// email is mask by default; the per-model override turns it into a
|
||||
// hard block for one request without mutating the redactor.
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
res := r.RedactWithOverrides("contact alice@example.com",
|
||||
map[string]Action{"email": ActionBlock})
|
||||
Expect(res.Blocked).To(BeTrue(), "override should have set Blocked")
|
||||
// Block leaves the value intact (the caller short-circuits the
|
||||
// request) — the redactor never echoes the matched text.
|
||||
Expect(res.Redacted).To(ContainSubstring("alice@example.com"), "block leaves text intact for the caller to discard")
|
||||
// Stored action is unchanged so a subsequent default Redact still
|
||||
// masks rather than blocks.
|
||||
res2 := r.Redact("contact alice@example.com")
|
||||
Expect(res2.Blocked).To(BeFalse(), "override must not mutate stored action")
|
||||
})
|
||||
|
||||
It("ignores unknown IDs", func() {
|
||||
// An override for a pattern this redactor doesn't know about is a
|
||||
// no-op rather than an error — per-model configs may reference
|
||||
// patterns from a wider catalogue than the active redactor holds.
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
res := r.RedactWithOverrides("contact alice@example.com",
|
||||
map[string]Action{"ssn": ActionBlock})
|
||||
Expect(res.Blocked).To(BeFalse(), "ssn override against email-only redactor must be no-op")
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("SetAction", func() {
|
||||
It("swaps in place", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
Expect(r.SetAction("email", ActionAllow)).To(Succeed())
|
||||
res := r.Redact("contact alice@example.com")
|
||||
Expect(res.Masked).To(BeFalse(), "allow leaves text intact, so nothing is masked")
|
||||
Expect(res.Redacted).To(ContainSubstring("alice@example.com"), "allow should leave the match in place")
|
||||
Expect(res.Spans).To(HaveLen(1), "allow still records the match")
|
||||
Expect(res.Blocked).To(BeFalse(), "SetAction(allow) should not block")
|
||||
})
|
||||
|
||||
It("rejects unknown id", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
Expect(r.SetAction("nonexistent", ActionMask)).NotTo(Succeed(), "expected error for unknown pattern id")
|
||||
})
|
||||
|
||||
It("rejects unknown action", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
Expect(r.SetAction("email", Action("frobnicate"))).NotTo(Succeed(), "expected error for unknown action")
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -27,7 +27,10 @@ type ListQuery struct {
|
||||
UserID string
|
||||
PatternID string
|
||||
Kind EventKind
|
||||
Limit int
|
||||
// Origin scopes the search to redaction events from one surface
|
||||
// (middleware | proxy | pii_analyze | pii_redact); empty matches any.
|
||||
Origin Origin
|
||||
Limit int
|
||||
}
|
||||
|
||||
// NewMemoryEventStore returns an in-memory ring-buffer event store.
|
||||
@@ -91,6 +94,9 @@ func (s *memoryEventStore) List(_ context.Context, q ListQuery) ([]PIIEvent, err
|
||||
if q.Kind != "" && e.ResolvedKind() != q.Kind {
|
||||
return false
|
||||
}
|
||||
if q.Origin != "" && e.Origin != q.Origin {
|
||||
return false
|
||||
}
|
||||
out = append(out, e)
|
||||
return len(out) >= limit
|
||||
}
|
||||
|
||||
48
core/services/routing/pii/store_test.go
Normal file
48
core/services/routing/pii/store_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("EventStore Origin filter", func() {
|
||||
var store EventStore
|
||||
ctx := context.Background()
|
||||
|
||||
BeforeEach(func() {
|
||||
store = NewMemoryEventStore(0)
|
||||
// Three redaction events from three different surfaces.
|
||||
for _, o := range []Origin{OriginMiddleware, OriginRedactAPI, OriginAnalyzeAPI} {
|
||||
Expect(store.Record(ctx, PIIEvent{
|
||||
ID: NewEventID(),
|
||||
Kind: KindPII,
|
||||
Origin: o,
|
||||
PatternID: "ner:EMAIL",
|
||||
})).To(Succeed())
|
||||
}
|
||||
// An older row with no Origin (pre-field) must not match any origin filter.
|
||||
Expect(store.Record(ctx, PIIEvent{ID: NewEventID(), Kind: KindPII, PatternID: "ner:EMAIL"})).To(Succeed())
|
||||
})
|
||||
|
||||
It("returns only events from the requested origin", func() {
|
||||
got, err := store.List(ctx, ListQuery{Origin: OriginRedactAPI})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Origin).To(Equal(OriginRedactAPI))
|
||||
})
|
||||
|
||||
It("an empty origin matches every event (including pre-field rows)", func() {
|
||||
got, err := store.List(ctx, ListQuery{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(HaveLen(4))
|
||||
})
|
||||
|
||||
It("does not match a pre-field (empty-origin) row against a concrete origin", func() {
|
||||
got, err := store.List(ctx, ListQuery{Origin: OriginMiddleware})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Origin).To(Equal(OriginMiddleware))
|
||||
})
|
||||
})
|
||||
@@ -1,198 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// StreamFilter applies the regex PII tier to a streaming response,
|
||||
// chunk by chunk, with a buffered-emit invariant: for any active
|
||||
// pattern with bounded max-length L, the filter never emits the
|
||||
// trailing L-1 characters of the cumulative input until either
|
||||
//
|
||||
// (a) more text arrives that disambiguates the boundary, or
|
||||
// (b) the stream closes (Drain).
|
||||
//
|
||||
// That keeps the redactor honest across chunk splits — an email
|
||||
// arriving as "alice@" + "example.com" still masks the same way as
|
||||
// "alice@example.com" arriving in one piece.
|
||||
//
|
||||
// Action handling in stream mode differs from the request-side
|
||||
// middleware. Earlier chunks of the response are already on the wire
|
||||
// by the time later chunks are scanned, so a "block" can't actually
|
||||
// reject the request. We remap block → mask for redaction purposes
|
||||
// while still recording PIIEvent rows with action="block" so audits
|
||||
// surface the original intent ("the model would have leaked X here,
|
||||
// suppressed in flight"). allow on the output side is a no-op — the
|
||||
// text is left intact, matching its request-side detect-and-log
|
||||
// behaviour.
|
||||
//
|
||||
// StreamFilter is NOT safe for concurrent use across goroutines; one
|
||||
// instance per response stream.
|
||||
type StreamFilter struct {
|
||||
redactor *Redactor
|
||||
maskOverrides map[string]Action // block → mask map used for redaction
|
||||
auditActions map[string]Action // original action per pattern, used for events
|
||||
store EventStore
|
||||
correlationID string
|
||||
userID string
|
||||
holdLen int
|
||||
buffer strings.Builder
|
||||
emittedBytes int
|
||||
}
|
||||
|
||||
// NewStreamFilter constructs a per-response filter. modelOverrides is
|
||||
// the per-model action override map (same shape the request-side
|
||||
// middleware uses); it can be nil when the model only accepts global
|
||||
// defaults.
|
||||
//
|
||||
// store may be nil — events are then computed but not persisted, which
|
||||
// is what the chat handler does when --disable-stats is set.
|
||||
func NewStreamFilter(redactor *Redactor, modelOverrides map[string]Action, store EventStore, correlationID, userID string) *StreamFilter {
|
||||
if redactor == nil {
|
||||
return &StreamFilter{}
|
||||
}
|
||||
|
||||
patterns := redactor.Patterns()
|
||||
|
||||
// auditActions: the action we *would* have applied if this match
|
||||
// occurred on the request side. Honours the per-model override.
|
||||
auditActions := make(map[string]Action, len(patterns))
|
||||
for _, p := range patterns {
|
||||
auditActions[p.ID] = p.Action
|
||||
}
|
||||
for id, action := range modelOverrides {
|
||||
auditActions[id] = action
|
||||
}
|
||||
|
||||
// maskOverrides: the action we actually apply to the stream. Same
|
||||
// as auditActions, but with every block remapped to mask.
|
||||
maskOverrides := make(map[string]Action, len(auditActions))
|
||||
for id, action := range auditActions {
|
||||
if action == ActionBlock {
|
||||
maskOverrides[id] = ActionMask
|
||||
} else {
|
||||
maskOverrides[id] = action
|
||||
}
|
||||
}
|
||||
|
||||
return &StreamFilter{
|
||||
redactor: redactor,
|
||||
maskOverrides: maskOverrides,
|
||||
auditActions: auditActions,
|
||||
store: store,
|
||||
correlationID: correlationID,
|
||||
userID: userID,
|
||||
holdLen: redactor.MaxPatternLength() - 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Push appends new text to the filter's buffer and returns the prefix
|
||||
// safe to emit downstream — the cumulative input minus a tail of
|
||||
// holdLen characters that might still be the start of a longer match.
|
||||
// Returned text has masks already applied.
|
||||
//
|
||||
// Returns an empty string when not enough text has arrived to clear
|
||||
// the hold window.
|
||||
func (sf *StreamFilter) Push(text string) string {
|
||||
if sf.redactor == nil || sf.holdLen <= 0 {
|
||||
return text
|
||||
}
|
||||
sf.buffer.WriteString(text)
|
||||
bufStr := sf.buffer.String()
|
||||
n := len(bufStr)
|
||||
|
||||
if n <= sf.holdLen {
|
||||
return ""
|
||||
}
|
||||
|
||||
emitBoundary := n - sf.holdLen
|
||||
|
||||
// Scan the entire buffer. A match whose start is before the
|
||||
// boundary but whose end runs past it crosses the window — pull
|
||||
// the boundary back to match.start so the pattern stays whole in
|
||||
// the buffer for the next Push to scan again.
|
||||
full := sf.redactor.RedactWithOverrides(bufStr, sf.maskOverrides)
|
||||
for _, span := range full.Spans {
|
||||
if span.Start < emitBoundary && span.End > emitBoundary {
|
||||
emitBoundary = span.Start
|
||||
}
|
||||
}
|
||||
|
||||
// holdLen is byte-sized but a chunk boundary may land mid-codepoint.
|
||||
// Snap back to the nearest rune start so neither the emitted prefix
|
||||
// nor the retained tail contains a split codepoint — otherwise the
|
||||
// next regex scan over an invalid-UTF-8 prefix could mis-match.
|
||||
for emitBoundary > 0 && emitBoundary < n && !utf8.RuneStart(bufStr[emitBoundary]) {
|
||||
emitBoundary--
|
||||
}
|
||||
|
||||
if emitBoundary <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
emitted := sf.applyAndEmit(bufStr[:emitBoundary])
|
||||
sf.buffer.Reset()
|
||||
sf.buffer.WriteString(bufStr[emitBoundary:])
|
||||
return emitted
|
||||
}
|
||||
|
||||
// Drain emits whatever's left in the buffer with all matches applied.
|
||||
// Call exactly once when the stream closes — repeat calls return the
|
||||
// empty string.
|
||||
func (sf *StreamFilter) Drain() string {
|
||||
if sf.redactor == nil {
|
||||
return sf.buffer.String()
|
||||
}
|
||||
bufStr := sf.buffer.String()
|
||||
if bufStr == "" {
|
||||
return ""
|
||||
}
|
||||
emitted := sf.applyAndEmit(bufStr)
|
||||
sf.buffer.Reset()
|
||||
return emitted
|
||||
}
|
||||
|
||||
// applyAndEmit runs the redactor over a committed-for-emit fragment,
|
||||
// substitutes mask/block placeholders inline, and records one
|
||||
// PIIEvent per matched span (with the audit action, not the masked
|
||||
// one). ByteOffset is referenced to the cumulative emitted output so
|
||||
// admins can correlate event positions against the streamed body.
|
||||
func (sf *StreamFilter) applyAndEmit(fragment string) string {
|
||||
res := sf.redactor.RedactWithOverrides(fragment, sf.maskOverrides)
|
||||
output := res.Redacted
|
||||
|
||||
if len(res.Spans) > 0 {
|
||||
now := time.Now().UTC()
|
||||
for _, span := range res.Spans {
|
||||
ev := PIIEvent{
|
||||
ID: newStreamEventID(),
|
||||
CorrelationID: sf.correlationID,
|
||||
UserID: sf.userID,
|
||||
Direction: DirectionOut,
|
||||
PatternID: span.Pattern,
|
||||
ByteOffset: sf.emittedBytes + span.Start,
|
||||
Length: span.End - span.Start,
|
||||
HashPrefix: span.HashPrefix,
|
||||
Action: sf.auditActions[span.Pattern],
|
||||
CreatedAt: now,
|
||||
}
|
||||
if sf.store != nil {
|
||||
_ = sf.store.Record(context.Background(), ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sf.emittedBytes += len(fragment)
|
||||
return output
|
||||
}
|
||||
|
||||
// newStreamEventID returns "pii_" followed by the lowercase hex encoding of
// 12 bytes read from crypto/rand (24 hex characters). The error from
// rand.Read is deliberately discarded.
func newStreamEventID() string {
	raw := make([]byte, 12)
	_, _ = rand.Read(raw)
	return "pii_" + hex.EncodeToString(raw)
}
|
||||
@@ -1,184 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func newStreamRedactor(ids ...string) *Redactor {
|
||||
all := DefaultPatterns()
|
||||
chosen := all
|
||||
if len(ids) > 0 {
|
||||
chosen = pick(all, ids)
|
||||
}
|
||||
patterns, err := Compile(chosen)
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return NewRedactor(patterns)
|
||||
}
|
||||
|
||||
var _ = Describe("StreamFilter", func() {
|
||||
It("masks across chunks", func() {
|
||||
// The most important streaming test: an email split arbitrarily
|
||||
// across chunk boundaries must mask exactly the same way as one
|
||||
// arriving in a single Push.
|
||||
red := newStreamRedactor("email")
|
||||
sf := NewStreamFilter(red, nil, nil, "", "")
|
||||
|
||||
// "alice@example.com" (17 bytes) split between '@' and 'e'.
|
||||
out := ""
|
||||
out += sf.Push("hi alice@")
|
||||
out += sf.Push("example.com! end")
|
||||
out += sf.Drain()
|
||||
|
||||
Expect(out).NotTo(ContainSubstring("alice@example.com"), "stream leaked email across chunk boundary")
|
||||
Expect(out).To(ContainSubstring("[REDACTED:email]"))
|
||||
})
|
||||
|
||||
It("block becomes mask", func() {
|
||||
// api_key_prefix is block by default. In stream mode the earlier
|
||||
// chunks are already on the wire so block is impossible — the
|
||||
// filter remaps to mask while still recording action="block" so
|
||||
// the audit log keeps the original intent.
|
||||
red := newStreamRedactor("api_key_prefix")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
sf := NewStreamFilter(red, nil, store, "corr-1", "user-1")
|
||||
|
||||
out := sf.Push("here is your token: sk-abcdefghijklmnopqrstuvwxyz0123456789 done")
|
||||
out += sf.Drain()
|
||||
|
||||
Expect(out).NotTo(ContainSubstring("abcdefghijklmnopqrstuvwxyz0123456789"), "block-in-stream must mask, leaked the value")
|
||||
Expect(out).To(ContainSubstring("[REDACTED:api_key_prefix]"))
|
||||
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 10})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].Action).To(Equal(ActionBlock), "audit must record original block action")
|
||||
Expect(events[0].Direction).To(Equal(DirectionOut), "stream events must be DirectionOut")
|
||||
})
|
||||
|
||||
It("no match passthrough", func() {
|
||||
red := newStreamRedactor("email")
|
||||
sf := NewStreamFilter(red, nil, nil, "", "")
|
||||
out := sf.Push("perfectly clean text that should") + sf.Push(" pass through unchanged.") + sf.Drain()
|
||||
Expect(out).To(Equal("perfectly clean text that should pass through unchanged."))
|
||||
})
|
||||
|
||||
It("nil redactor passthrough", func() {
|
||||
// --disable-pii path: NewStreamFilter(nil, ...) returns a filter
|
||||
// that just forwards Push input verbatim.
|
||||
sf := NewStreamFilter(nil, nil, nil, "", "")
|
||||
out := sf.Push("any old text including alice@example.com") + sf.Drain()
|
||||
Expect(out).To(Equal("any old text including alice@example.com"))
|
||||
})
|
||||
|
||||
It("per-model overrides", func() {
|
||||
// email defaults to mask; per-model override upgrades to block.
|
||||
// In stream mode the override still maps to mask placeholder, but
|
||||
// the audit event records action="block".
|
||||
red := newStreamRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
sf := NewStreamFilter(red, map[string]Action{"email": ActionBlock}, store, "corr-2", "user-2")
|
||||
|
||||
out := sf.Push("contact alice@example.com please") + sf.Drain()
|
||||
Expect(out).NotTo(ContainSubstring("alice@example.com"), "override block-in-stream must mask")
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 10})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].Action).To(Equal(ActionBlock))
|
||||
})
|
||||
|
||||
// StreamFilter_BufferedEmitInvariant feeds the redactor a corpus
|
||||
// one rune at a time, randomly chunked, and asserts:
|
||||
//
|
||||
// 1. Across all (input, splitting) pairs, the cumulative emitted
|
||||
// output never contains any of the secret values that were
|
||||
// embedded in the input.
|
||||
// 2. The output, fully drained, equals what Redact would have
|
||||
// produced on the unsplit input.
|
||||
//
|
||||
// This is the load-bearing property of streaming PII: regardless of
|
||||
// where chunks split, the emitted bytes cannot contain a value that a
|
||||
// single-shot redactor would have masked.
|
||||
It("buffered emit invariant", func() {
|
||||
corpus := []struct {
|
||||
text string
|
||||
secrets []string
|
||||
}{
|
||||
{"contact alice@example.com or bob@example.org", []string{"alice@example.com", "bob@example.org"}},
|
||||
{"my SSN is 123-45-6789 and his is 987-65-4321", []string{"123-45-6789", "987-65-4321"}},
|
||||
{"sk-abcdefghijklmnopqrstuvwxyz0123456789 leaked", []string{"sk-abcdefghijklmnopqrstuvwxyz0123456789"}},
|
||||
{"repeats: alice@example.com / alice@example.com / alice@example.com", []string{"alice@example.com"}},
|
||||
// Multibyte UTF-8 corpora pin the rune-boundary snap in
|
||||
// StreamFilter.Push: holdLen is byte-sized, so a chunk boundary
|
||||
// may land mid-codepoint. Without the snap, the retained tail
|
||||
// has a partial codepoint and the next regex scan can mis-align.
|
||||
// Each entry mixes ASCII secrets with surrounding multibyte text
|
||||
// so a byte-aligned cut would land inside a CJK or accented
|
||||
// character on at least some splits.
|
||||
{"こんにちは alice@example.com さようなら", []string{"alice@example.com"}},
|
||||
{"クレジットカード: 4111-1111-1111-1111 終わり", []string{"4111-1111-1111-1111"}},
|
||||
{"naïve résumé: alice@example.com, façade", []string{"alice@example.com"}},
|
||||
}
|
||||
|
||||
red := newStreamRedactor() // all default patterns
|
||||
rng := rand.New(rand.NewSource(1)) // seeded for reproducibility
|
||||
|
||||
for _, tc := range corpus {
|
||||
for trial := 0; trial < 10; trial++ {
|
||||
sf := NewStreamFilter(red, nil, nil, "", "")
|
||||
var out strings.Builder
|
||||
for i := 0; i < utf8.RuneCountInString(tc.text); {
|
||||
// Random chunk size 1-8 runes, never crossing the end.
|
||||
chunk := 1 + rng.Intn(8)
|
||||
if i+chunk > utf8.RuneCountInString(tc.text) {
|
||||
chunk = utf8.RuneCountInString(tc.text) - i
|
||||
}
|
||||
out.WriteString(sf.Push(stringSlice(tc.text, i, i+chunk)))
|
||||
i += chunk
|
||||
}
|
||||
out.WriteString(sf.Drain())
|
||||
result := out.String()
|
||||
|
||||
// Property 1: no secret value appears anywhere in the
|
||||
// output.
|
||||
for _, secret := range tc.secrets {
|
||||
Expect(result).NotTo(ContainSubstring(secret),
|
||||
fmt.Sprintf("trial %d: secret %q leaked through streaming\n input: %q\n output: %q", trial, secret, tc.text, result))
|
||||
}
|
||||
|
||||
// Property 2: the streamed output equals what a single-shot
|
||||
// Redact would have produced on the same input. (Block
|
||||
// patterns get masked in stream mode, so we compare against
|
||||
// a remapped redaction.)
|
||||
expected := singleShotMaskAll(red, tc.text)
|
||||
Expect(result).To(Equal(expected),
|
||||
fmt.Sprintf("trial %d: stream != single-shot\n input: %q", trial, tc.text))
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// singleShotMaskAll runs the redactor in one pass with all blocks
|
||||
// remapped to mask — the same view the StreamFilter produces.
|
||||
func singleShotMaskAll(red *Redactor, text string) string {
|
||||
patterns := red.Patterns()
|
||||
overrides := make(map[string]Action, len(patterns))
|
||||
for _, p := range patterns {
|
||||
if p.Action == ActionBlock {
|
||||
overrides[p.ID] = ActionMask
|
||||
}
|
||||
}
|
||||
res := red.RedactWithOverrides(text, overrides)
|
||||
return res.Redacted
|
||||
}
|
||||
|
||||
// stringSlice returns the runes of s in the half-open range
// [fromRune, toRune) as a string; the indices count runes, not bytes.
func stringSlice(s string, fromRune, toRune int) string {
	return string([]rune(s)[fromRune:toRune])
}
|
||||
@@ -62,10 +62,12 @@ const (
|
||||
// substring slicing; call sites that need to log it strip it via
|
||||
// HashPrefix.
|
||||
type Span struct {
|
||||
Start int
|
||||
End int
|
||||
Pattern string // matches Pattern.ID
|
||||
HashPrefix string // first 8 chars of sha256(matched value); audit-safe
|
||||
Start int
|
||||
End int
|
||||
Pattern string // synthetic detector id, "<source>:<GROUP>" (e.g. "ner:EMAIL", "pattern:ANTHROPIC_KEY")
|
||||
HashPrefix string // first 8 chars of sha256(matched value); audit-safe
|
||||
Action Action // the action that fired for this span (after merge)
|
||||
Score float32 // detector confidence for the (winning) hit, 0..1
|
||||
}
|
||||
|
||||
// Result is what Redact returns. Redacted is the input string after
|
||||
@@ -88,30 +90,6 @@ type Result struct {
|
||||
Masked bool
|
||||
}
|
||||
|
||||
// Pattern is one configurable rule. Description is shown in the admin
|
||||
// UI alongside the pattern; the regex itself stays an implementation
|
||||
// detail (a leak-prone admin showing an SSN regex with a sample value
|
||||
// in the field is a risk we deliberately design around).
|
||||
type Pattern struct {
|
||||
ID string
|
||||
Description string
|
||||
Action Action
|
||||
// Disabled skips the pattern entirely when true — useful for
|
||||
// admins who want to keep a regex around (visible in the UI) but
|
||||
// turn it off without removing the YAML entry. Default-false so
|
||||
// every existing pattern stays active without touching its config.
|
||||
Disabled bool
|
||||
// MaxMatchLength is the longest possible match in characters. The
|
||||
// streaming filter (subsystem 3, follow-up commit) uses this to
|
||||
// size its tail buffer. For regex patterns we compute it at
|
||||
// compile time from the pattern's structure when possible, or set
|
||||
// a conservative upper bound otherwise.
|
||||
MaxMatchLength int
|
||||
|
||||
// internal — populated by Compile().
|
||||
regex regexpMatcher
|
||||
}
|
||||
|
||||
// EventKind classifies a stored audit event. The store is shared by the
|
||||
// PII filter (its original use), the MITM proxy (connect decisions and
|
||||
// per-request traffic counters), and — when subsystem 2 lands — the
|
||||
@@ -135,6 +113,20 @@ const (
|
||||
KindAdmission EventKind = "admission"
|
||||
)
|
||||
|
||||
// Origin names the surface that produced a redaction event, so the events
// log can tell an inline chat redaction from a MITM-proxy one and from an
// explicit /api/pii/{analyze,redact} call. Only PII redaction events
// (Kind KindPII) set it; connection/admission events leave it empty, and an
// empty Origin on an older row reads as "unknown".
type Origin = string

const (
	// OriginMiddleware: in-band chat/completions PII middleware.
	OriginMiddleware Origin = "middleware"
	// OriginProxy: cloud-proxy MITM input path.
	OriginProxy Origin = "proxy"
	// OriginAnalyzeAPI: POST /api/pii/analyze.
	OriginAnalyzeAPI Origin = "pii_analyze"
	// OriginRedactAPI: POST /api/pii/redact.
	OriginRedactAPI Origin = "pii_redact"
)
|
||||
|
||||
// PIIEvent is the persisted record. The HashPrefix field is the first 8 chars
|
||||
// of sha256(matched value) — enough to deduplicate "is this the same
|
||||
// thing as last time" without ever storing the value itself.
|
||||
@@ -146,6 +138,7 @@ const (
|
||||
type PIIEvent struct {
|
||||
ID string `json:"id"`
|
||||
Kind EventKind `json:"kind,omitempty"`
|
||||
Origin Origin `json:"origin,omitempty"`
|
||||
CorrelationID string `json:"correlation_id,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
Direction Direction `json:"direction,omitempty"`
|
||||
@@ -154,7 +147,11 @@ type PIIEvent struct {
|
||||
Length int `json:"length,omitempty"`
|
||||
HashPrefix string `json:"hash_prefix,omitempty"`
|
||||
Action Action `json:"action,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// Score is the detector confidence (0..1) for an NER PII hit. Metadata
|
||||
// only — never the matched value. Lets admins see how sure the model was
|
||||
// about a (possibly false-positive) detection without re-running it.
|
||||
Score float32 `json:"score,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
Host string `json:"host,omitempty"`
|
||||
Intercepted *bool `json:"intercepted,omitempty"`
|
||||
|
||||
119
core/services/routing/piiadapter/ollama.go
Normal file
119
core/services/routing/piiadapter/ollama.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// OllamaChat returns a pii.Adapter for *schema.OllamaChatRequest (POST
|
||||
// /api/chat). It scans each message's text content (Ollama messages carry a
|
||||
// plain string, no multimodal block form) and writes redacted text back.
|
||||
func OllamaChat() pii.Adapter {
|
||||
return pii.Adapter{
|
||||
Scan: func(parsed any) []pii.ScannedText {
|
||||
req, ok := parsed.(*schema.OllamaChatRequest)
|
||||
if !ok || req == nil {
|
||||
return nil
|
||||
}
|
||||
var out []pii.ScannedText
|
||||
for i := range req.Messages {
|
||||
if req.Messages[i].Content != "" {
|
||||
out = append(out, pii.ScannedText{Index: i, Text: req.Messages[i].Content})
|
||||
}
|
||||
}
|
||||
return out
|
||||
},
|
||||
Apply: func(parsed any, updates []pii.ScannedText) {
|
||||
req, ok := parsed.(*schema.OllamaChatRequest)
|
||||
if !ok || req == nil {
|
||||
return
|
||||
}
|
||||
for _, u := range updates {
|
||||
if u.Index >= 0 && u.Index < len(req.Messages) {
|
||||
req.Messages[u.Index].Content = u.Text
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Field selectors for OllamaGenerate (Prompt + System).
|
||||
const (
|
||||
ollamaGenPrompt = iota
|
||||
ollamaGenSystem
|
||||
)
|
||||
|
||||
// OllamaGenerate returns a pii.Adapter for *schema.OllamaGenerateRequest (POST
|
||||
// /api/generate). It scans the Prompt and System strings.
|
||||
func OllamaGenerate() pii.Adapter {
|
||||
return pii.Adapter{
|
||||
Scan: func(parsed any) []pii.ScannedText {
|
||||
req, ok := parsed.(*schema.OllamaGenerateRequest)
|
||||
if !ok || req == nil {
|
||||
return nil
|
||||
}
|
||||
var out []pii.ScannedText
|
||||
if req.Prompt != "" {
|
||||
out = append(out, pii.ScannedText{Index: ollamaGenPrompt, Text: req.Prompt})
|
||||
}
|
||||
if req.System != "" {
|
||||
out = append(out, pii.ScannedText{Index: ollamaGenSystem, Text: req.System})
|
||||
}
|
||||
return out
|
||||
},
|
||||
Apply: func(parsed any, updates []pii.ScannedText) {
|
||||
req, ok := parsed.(*schema.OllamaGenerateRequest)
|
||||
if !ok || req == nil {
|
||||
return
|
||||
}
|
||||
for _, u := range updates {
|
||||
switch u.Index {
|
||||
case ollamaGenPrompt:
|
||||
req.Prompt = u.Text
|
||||
case ollamaGenSystem:
|
||||
req.System = u.Text
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Field selectors for OllamaEmbed (Input + its Prompt alias). Reuses the
|
||||
// shared encField/decField packing.
|
||||
const (
|
||||
ollamaEmbInput = iota
|
||||
ollamaEmbPrompt
|
||||
)
|
||||
|
||||
// OllamaEmbed returns a pii.Adapter for *schema.OllamaEmbedRequest (POST
|
||||
// /api/embed, /api/embeddings). Input and its Prompt alias may be a string or
|
||||
// a []any of strings; non-string elements are skipped.
|
||||
func OllamaEmbed() pii.Adapter {
|
||||
return pii.Adapter{
|
||||
Scan: func(parsed any) []pii.ScannedText {
|
||||
req, ok := parsed.(*schema.OllamaEmbedRequest)
|
||||
if !ok || req == nil {
|
||||
return nil
|
||||
}
|
||||
var out []pii.ScannedText
|
||||
scanAnyText(ollamaEmbInput, req.Input, &out)
|
||||
scanAnyText(ollamaEmbPrompt, req.Prompt, &out)
|
||||
return out
|
||||
},
|
||||
Apply: func(parsed any, updates []pii.ScannedText) {
|
||||
req, ok := parsed.(*schema.OllamaEmbedRequest)
|
||||
if !ok || req == nil {
|
||||
return
|
||||
}
|
||||
for _, u := range updates {
|
||||
field, elem := decField(u.Index)
|
||||
switch field {
|
||||
case ollamaEmbInput:
|
||||
req.Input = applyAnyText(req.Input, elem, u.Text)
|
||||
case ollamaEmbPrompt:
|
||||
req.Prompt = applyAnyText(req.Prompt, elem, u.Text)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
46
core/services/routing/piiadapter/ollama_test.go
Normal file
46
core/services/routing/piiadapter/ollama_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Ollama adapters", func() {
|
||||
It("OllamaChat scans and rewrites message content", func() {
|
||||
req := &schema.OllamaChatRequest{Messages: []schema.OllamaMessage{
|
||||
{Role: "user", Content: "I'm alice@example.com"},
|
||||
{Role: "assistant", Content: ""},
|
||||
}}
|
||||
a := OllamaChat()
|
||||
Expect(a.Scan(req)).To(HaveLen(1))
|
||||
applyAll(a, req, func(string) string { return "X" })
|
||||
Expect(req.Messages[0].Content).To(Equal("X"))
|
||||
Expect(req.Messages[1].Content).To(Equal(""))
|
||||
})
|
||||
|
||||
It("OllamaGenerate scans Prompt and System", func() {
|
||||
req := &schema.OllamaGenerateRequest{Prompt: "ssn 123", System: "be terse"}
|
||||
a := OllamaGenerate()
|
||||
Expect(a.Scan(req)).To(HaveLen(2))
|
||||
applyAll(a, req, func(string) string { return "Y" })
|
||||
Expect(req.Prompt).To(Equal("Y"))
|
||||
Expect(req.System).To(Equal("Y"))
|
||||
})
|
||||
|
||||
It("OllamaEmbed scans string and array Input, skipping non-strings", func() {
|
||||
a := OllamaEmbed()
|
||||
|
||||
s := &schema.OllamaEmbedRequest{Input: "secret email"}
|
||||
Expect(a.Scan(s)).To(HaveLen(1))
|
||||
applyAll(a, s, func(string) string { return "Z" })
|
||||
Expect(s.Input).To(Equal("Z"))
|
||||
|
||||
arr := &schema.OllamaEmbedRequest{Input: []any{"a secret", float64(1)}}
|
||||
Expect(a.Scan(arr)).To(HaveLen(1))
|
||||
applyAll(a, arr, func(string) string { return "Z" })
|
||||
got, _ := arr.Input.([]any)
|
||||
Expect(got).To(Equal([]any{"Z", float64(1)}))
|
||||
})
|
||||
})
|
||||
@@ -6,6 +6,8 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
@@ -74,17 +76,35 @@ func OpenAI() pii.Adapter {
|
||||
}
|
||||
msg := &req.Messages[msgIdx]
|
||||
if blockIdx < 0 {
|
||||
// Whole-string content.
|
||||
// Whole-string content. Write BOTH the serializable
|
||||
// Content and the StringContent staging buffer: the
|
||||
// rendered-template path (evaluator.TemplateMessages,
|
||||
// taken whenever use_tokenizer_template is off — e.g.
|
||||
// cloud-proxy translate and Go-templated chat models)
|
||||
// reads StringContent, not Content. Masking only Content
|
||||
// would leave the original in StringContent and leak it
|
||||
// to the backend/upstream.
|
||||
msg.Content = u.Text
|
||||
msg.StringContent = u.Text
|
||||
continue
|
||||
}
|
||||
blocks, ok := msg.Content.([]any)
|
||||
if !ok || blockIdx >= len(blocks) {
|
||||
continue
|
||||
}
|
||||
if blockMap, ok := blocks[blockIdx].(map[string]any); ok {
|
||||
blockMap["text"] = u.Text
|
||||
blockMap, ok := blocks[blockIdx].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// Keep the StringContent projection in sync. For multimodal
|
||||
// messages StringContent is the text blocks flattened with
|
||||
// media markers injected (see middleware/request.go), so we
|
||||
// can't just overwrite it — replace this block's original text
|
||||
// run in place, preserving the markers around it.
|
||||
if orig, ok := blockMap["text"].(string); ok && orig != "" && msg.StringContent != "" {
|
||||
msg.StringContent = strings.Replace(msg.StringContent, orig, u.Text, 1)
|
||||
}
|
||||
blockMap["text"] = u.Text
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
91
core/services/routing/piiadapter/openai_completion.go
Normal file
91
core/services/routing/piiadapter/openai_completion.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// Selectors for the text-carrying fields of the prompt-style OpenAI requests
// (/v1/completions, /v1/embeddings, /v1/edits). Those endpoints put user text
// in Prompt / Input / Instruction instead of Messages.
const (
	fldPrompt = iota
	fldInput
	fldInstruction
)

// encField folds a (field, element) pair into a single ScannedText.Index.
//
// elem == -1 means "the field is one whole string"; elem >= 0 is a position
// inside a []any value. The element is stored shifted by one in the low 24
// bits (so -1 becomes 0) and the field selector sits above it. Because the
// low part is only 24 bits wide, elem+1 must stay below 1<<24 or it spills
// into the field bits.
func encField(field, elem int) int {
	const elemBits = 24
	slot := elem + 1
	return (field << elemBits) | slot
}

// decField is the inverse of encField: it splits a packed index back into the
// field selector (high bits) and the element position (low 24 bits minus one,
// so a stored 0 comes back as -1).
func decField(p int) (field, elem int) {
	const (
		elemBits = 24
		elemMask = 0xFFFFFF
	)
	field = p >> elemBits
	elem = (p & elemMask) - 1
	return field, elem
}
|
||||
|
||||
// scanAnyText appends scannable strings from a string-or-[]any field. Non-string
|
||||
// array elements (token-id arrays, numbers) are skipped — only human text is
|
||||
// redacted.
|
||||
func scanAnyText(field int, v any, out *[]pii.ScannedText) {
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
if t != "" {
|
||||
*out = append(*out, pii.ScannedText{Index: encField(field, -1), Text: t})
|
||||
}
|
||||
case []any:
|
||||
for k, e := range t {
|
||||
if s, ok := e.(string); ok && s != "" {
|
||||
*out = append(*out, pii.ScannedText{Index: encField(field, k), Text: s})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyAnyText stores redacted text into a string-or-[]any field and returns
// the value the caller should assign back to the struct field.
//
// A negative elem replaces the whole field with text. Otherwise, when v is a
// []any and elem is in range, that element is overwritten in place and the
// same slice is returned. In every other case v comes back untouched.
func applyAnyText(v any, elem int, text string) any {
	if elem < 0 {
		return text
	}
	items, isSlice := v.([]any)
	if !isSlice || elem >= len(items) {
		return v
	}
	items[elem] = text
	return v
}
|
||||
|
||||
// OpenAICompletion returns a pii.Adapter for the prompt-style OpenAI requests
|
||||
// (completions, embeddings, edits) on *schema.OpenAIRequest. It scans Prompt,
|
||||
// Input and Instruction — the string form and the string elements of an array
|
||||
// form — and writes redacted text back. Chat uses the separate OpenAI()
|
||||
// adapter (Messages); these endpoints leave Messages empty and vice versa.
|
||||
func OpenAICompletion() pii.Adapter {
|
||||
return pii.Adapter{
|
||||
Scan: func(parsed any) []pii.ScannedText {
|
||||
req, ok := parsed.(*schema.OpenAIRequest)
|
||||
if !ok || req == nil {
|
||||
return nil
|
||||
}
|
||||
var out []pii.ScannedText
|
||||
scanAnyText(fldPrompt, req.Prompt, &out)
|
||||
scanAnyText(fldInput, req.Input, &out)
|
||||
if req.Instruction != "" {
|
||||
out = append(out, pii.ScannedText{Index: encField(fldInstruction, -1), Text: req.Instruction})
|
||||
}
|
||||
return out
|
||||
},
|
||||
Apply: func(parsed any, updates []pii.ScannedText) {
|
||||
req, ok := parsed.(*schema.OpenAIRequest)
|
||||
if !ok || req == nil {
|
||||
return
|
||||
}
|
||||
for _, u := range updates {
|
||||
field, elem := decField(u.Index)
|
||||
switch field {
|
||||
case fldPrompt:
|
||||
req.Prompt = applyAnyText(req.Prompt, elem, u.Text)
|
||||
case fldInput:
|
||||
req.Input = applyAnyText(req.Input, elem, u.Text)
|
||||
case fldInstruction:
|
||||
req.Instruction = u.Text
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
59
core/services/routing/piiadapter/openai_completion_test.go
Normal file
59
core/services/routing/piiadapter/openai_completion_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// applyAll feeds every scanned span back through Apply with the text
|
||||
// transformed by fn — the shape the middleware uses (scan, redact, apply).
|
||||
func applyAll(a pii.Adapter, parsed any, fn func(string) string) {
|
||||
scanned := a.Scan(parsed)
|
||||
updates := make([]pii.ScannedText, 0, len(scanned))
|
||||
for _, s := range scanned {
|
||||
updates = append(updates, pii.ScannedText{Index: s.Index, Text: fn(s.Text)})
|
||||
}
|
||||
a.Apply(parsed, updates)
|
||||
}
|
||||
|
||||
var _ = Describe("OpenAICompletion adapter", func() {
|
||||
a := OpenAICompletion()
|
||||
|
||||
It("scans and rewrites a string prompt", func() {
|
||||
req := &schema.OpenAIRequest{}
|
||||
req.Prompt = "contact alice@example.com"
|
||||
got := a.Scan(req)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Text).To(Equal("contact alice@example.com"))
|
||||
applyAll(a, req, func(string) string { return "REDACTED" })
|
||||
Expect(req.Prompt).To(Equal("REDACTED"))
|
||||
})
|
||||
|
||||
It("scans array prompt elements and skips non-strings (token ids)", func() {
|
||||
req := &schema.OpenAIRequest{}
|
||||
req.Prompt = []any{"first secret", float64(42), "second secret"}
|
||||
got := a.Scan(req)
|
||||
Expect(got).To(HaveLen(2))
|
||||
applyAll(a, req, func(s string) string { return "[X]" })
|
||||
arr, _ := req.Prompt.([]any)
|
||||
Expect(arr).To(Equal([]any{"[X]", float64(42), "[X]"}))
|
||||
})
|
||||
|
||||
It("scans Input and Instruction (the edit/embeddings shape)", func() {
|
||||
req := &schema.OpenAIRequest{Instruction: "fix the SSN 123-45-6789"}
|
||||
req.Input = "my email is bob@example.com"
|
||||
got := a.Scan(req)
|
||||
Expect(got).To(HaveLen(2))
|
||||
applyAll(a, req, func(string) string { return "*" })
|
||||
Expect(req.Input).To(Equal("*"))
|
||||
Expect(req.Instruction).To(Equal("*"))
|
||||
})
|
||||
|
||||
It("returns nothing for an empty / non-matching request", func() {
|
||||
Expect(a.Scan(&schema.OpenAIRequest{})).To(BeEmpty())
|
||||
Expect(a.Scan(nil)).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -54,6 +54,55 @@ var _ = Describe("OpenAI adapter", func() {
|
||||
Expect(req.Messages[1].Content.(string)).To(Equal("REDACTED-1"))
|
||||
})
|
||||
|
||||
It("Apply keeps StringContent in sync for string content", func() {
|
||||
// Regression: the request middleware fills StringContent from Content
|
||||
// at parse time, and the rendered-template path (TemplateMessages)
|
||||
// reads StringContent, not Content. Apply must redact both or the
|
||||
// original leaks to the backend/upstream (e.g. cloud-proxy translate).
|
||||
req := &schema.OpenAIRequest{
|
||||
Messages: []schema.Message{
|
||||
{Role: "user", Content: "my key is sk-secret", StringContent: "my key is sk-secret"},
|
||||
},
|
||||
}
|
||||
adapter := OpenAI()
|
||||
scans := adapter.Scan(req)
|
||||
Expect(scans).To(HaveLen(1))
|
||||
scans[0].Text = "my key is [REDACTED]"
|
||||
adapter.Apply(req, scans)
|
||||
|
||||
Expect(req.Messages[0].Content.(string)).To(Equal("my key is [REDACTED]"))
|
||||
Expect(req.Messages[0].StringContent).To(Equal("my key is [REDACTED]"),
|
||||
"StringContent (what TemplateMessages renders) must be redacted too")
|
||||
})
|
||||
|
||||
It("Apply keeps StringContent in sync for content blocks, preserving media markers", func() {
|
||||
// For multimodal content StringContent is the flattened text with
|
||||
// media markers injected (request.go), so Apply must redact the text
|
||||
// run in place rather than clobber the whole buffer.
|
||||
req := &schema.OpenAIRequest{
|
||||
Messages: []schema.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "leak sk-secret here"},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,xyz"}},
|
||||
},
|
||||
StringContent: "leak sk-secret here<__media__>",
|
||||
},
|
||||
},
|
||||
}
|
||||
adapter := OpenAI()
|
||||
scans := adapter.Scan(req)
|
||||
Expect(scans).To(HaveLen(1))
|
||||
scans[0].Text = "leak [REDACTED] here"
|
||||
adapter.Apply(req, scans)
|
||||
|
||||
blocks := req.Messages[0].Content.([]any)
|
||||
Expect(blocks[0].(map[string]any)["text"]).To(Equal("leak [REDACTED] here"))
|
||||
Expect(req.Messages[0].StringContent).To(Equal("leak [REDACTED] here<__media__>"),
|
||||
"StringContent must be redacted in place, keeping the media marker")
|
||||
})
|
||||
|
||||
It("Apply mutates content block selectively", func() {
|
||||
req := &schema.OpenAIRequest{
|
||||
Messages: []schema.Message{
|
||||
|
||||
86
core/services/routing/piidetector/detector.go
Normal file
86
core/services/routing/piidetector/detector.go
Normal file
@@ -0,0 +1,86 @@
|
||||
// Package piidetector adapts the core/backend token-classification
|
||||
// wrapper to the PII redactor's pii.NERDetector seam. It lives outside
|
||||
// the pii package so pii stays free of core/backend imports (the
|
||||
// redactor is unit-tested with stub detectors). The dependency runs one
|
||||
// way: piidetector -> {core/backend, pii}.
|
||||
package piidetector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// New builds a pii.NERDetector backed by the token-classification model
|
||||
// in modelConfig. Phase 0: the Python `transformers` backend loaded with
|
||||
// Type=TokenClassification; Phase 2: the GGML privacy-filter backend —
|
||||
// both speak the same gRPC TokenClassify contract, so this adapter is
|
||||
// unchanged across the swap. The model is resolved lazily on first
|
||||
// Detect, so building a detector for a not-yet-loaded model is cheap and
|
||||
// never blocks startup.
|
||||
func New(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) pii.NERDetector {
|
||||
return &nerDetector{
|
||||
classifier: backend.NewTokenClassifier(loader, modelConfig, appConfig, backend.TokenClassifyOptions{}),
|
||||
modelName: modelConfig.Name,
|
||||
}
|
||||
}
|
||||
|
||||
type nerDetector struct {
|
||||
classifier backend.TokenClassifier
|
||||
modelName string
|
||||
}
|
||||
|
||||
// Detect runs the model and maps its spans onto pii.NEREntity. Offsets
|
||||
// pass through as BYTE offsets per the TokenClassify proto contract.
|
||||
// Spans whose offsets fall outside the text or land off a UTF-8 rune
|
||||
// boundary are dropped: a bad offset must never reach the redactor,
|
||||
// which splices text[Start:End] and would otherwise corrupt output or
|
||||
// panic. The redactor applies NERConfig.MinScore and the entity->action
|
||||
// map itself, so we deliberately return every (validated) span here.
|
||||
//
|
||||
// CONTRACT NOTE: the proto defines start/end as UTF-8 byte offsets. The
|
||||
// Python transformers backend converts HuggingFace's codepoint offsets to
|
||||
// bytes before responding (see TokenClassify in backend.py), and the GGML
|
||||
// privacy-filter backend will emit bytes natively. The boundary check
|
||||
// below is defense-in-depth against a backend that regresses to codepoint
|
||||
// offsets: it downgrades the bug from "corrupted redaction / panic" to
|
||||
// "dropped span + warning" rather than trusting the wire blindly.
|
||||
func (d *nerDetector) Detect(ctx context.Context, text string) ([]pii.NEREntity, error) {
|
||||
ents, err := d.classifier.TokenClassify(ctx, text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n := len(text)
|
||||
out := make([]pii.NEREntity, 0, len(ents))
|
||||
for _, e := range ents {
|
||||
if e.Group == "" || e.Start < 0 || e.Start >= e.End || e.End > n {
|
||||
xlog.Warn("pii NER: dropping span with invalid byte range",
|
||||
"model", d.modelName, "group", e.Group, "start", e.Start, "end", e.End, "len", n)
|
||||
continue
|
||||
}
|
||||
// text[e.Start] is safe (Start < End <= n => Start < n). End is
|
||||
// exclusive: when End < n, text[End] is the first byte past the
|
||||
// span and must itself start a rune. Off-boundary offsets are the
|
||||
// signature of codepoint-vs-byte offset confusion.
|
||||
if !utf8.RuneStart(text[e.Start]) || (e.End < n && !utf8.RuneStart(text[e.End])) {
|
||||
xlog.Warn("pii NER: dropping span off UTF-8 boundary (offset units mismatch?)",
|
||||
"model", d.modelName, "group", e.Group, "start", e.Start, "end", e.End)
|
||||
continue
|
||||
}
|
||||
out = append(out, pii.NEREntity{
|
||||
Group: e.Group,
|
||||
Start: e.Start,
|
||||
End: e.End,
|
||||
Score: e.Score,
|
||||
Text: e.Text,
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
80
core/services/routing/piidetector/pattern.go
Normal file
80
core/services/routing/piidetector/pattern.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package piidetector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piipattern"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
)
|
||||
|
||||
// NewPattern builds a pii.NERDetector that matches secrets with the restricted
|
||||
// regex tier (built-ins + operator-defined patterns) instead of a neural model.
|
||||
// It runs entirely in-process — no backend, GGUF, or VRAM — and the patterns
|
||||
// compile once here, so an invalid pattern is reported now (the resolver fails
|
||||
// closed) rather than per request. Matches are reported under their group with
|
||||
// a deterministic Score of 1.0.
|
||||
func NewPattern(modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (pii.NERDetector, error) {
|
||||
custom := make([]piipattern.Pattern, 0, len(modelConfig.PIIDetection.Patterns))
|
||||
for _, p := range modelConfig.PIIDetection.Patterns {
|
||||
custom = append(custom, piipattern.Pattern{Group: p.Name, Pattern: p.Match, MinLen: p.MinLen})
|
||||
}
|
||||
m, err := piipattern.NewMatcher(modelConfig.PIIDetection.Builtins, custom)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &patternDetector{matcher: m, modelName: modelConfig.Name, appConfig: appConfig}, nil
|
||||
}
|
||||
|
||||
type patternDetector struct {
|
||||
matcher *piipattern.Matcher
|
||||
modelName string
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
// Detect runs the compiled patterns and maps each match onto a pii.NEREntity.
|
||||
// When tracing is enabled it records a pattern_pii BackendTrace so the matches
|
||||
// (group, byte range, text) show in the Traces UI alongside NER detections.
|
||||
func (d *patternDetector) Detect(_ context.Context, text string) ([]pii.NEREntity, error) {
|
||||
var start time.Time
|
||||
if d.appConfig != nil && d.appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(d.appConfig.TracingMaxItems, d.appConfig.TracingMaxBodyBytes)
|
||||
start = time.Now()
|
||||
}
|
||||
|
||||
matches := d.matcher.Find(text)
|
||||
out := make([]pii.NEREntity, 0, len(matches))
|
||||
var traceEnts []backend.TokenEntity
|
||||
for _, mt := range matches {
|
||||
out = append(out, pii.NEREntity{Group: mt.Group, Start: mt.Start, End: mt.End, Score: 1.0, Text: mt.Text})
|
||||
if d.appConfig != nil && d.appConfig.EnableTracing {
|
||||
traceEnts = append(traceEnts, backend.TokenEntity{Group: mt.Group, Start: mt.Start, End: mt.End, Score: 1.0, Text: mt.Text})
|
||||
}
|
||||
}
|
||||
|
||||
if d.appConfig != nil && d.appConfig.EnableTracing {
|
||||
trace.RecordBackendTrace(patternPIITrace(d.modelName, text, traceEnts, start))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// patternPIITrace assembles the Traces-UI row for one pattern-detector run.
|
||||
// Split out so the Data assembly is unit-testable without a request.
|
||||
func patternPIITrace(modelName, text string, entities []backend.TokenEntity, start time.Time) trace.BackendTrace {
|
||||
return trace.BackendTrace{
|
||||
Timestamp: start,
|
||||
Duration: time.Since(start),
|
||||
Type: trace.BackendTracePatternPII,
|
||||
ModelName: modelName,
|
||||
Backend: "pattern",
|
||||
Summary: trace.TruncateString(text, 200),
|
||||
Data: map[string]any{
|
||||
"input_chars": len(text),
|
||||
"matches": len(entities),
|
||||
"entities": entities,
|
||||
},
|
||||
}
|
||||
}
|
||||
61
core/services/routing/piidetector/pattern_test.go
Normal file
61
core/services/routing/piidetector/pattern_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package piidetector_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piidetector"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestPiidetector(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "piidetector suite")
|
||||
}
|
||||
|
||||
func patternModel() config.ModelConfig {
|
||||
c := config.ModelConfig{Name: "secret-filter", Backend: "pattern"}
|
||||
c.PIIDetection.Builtins = []string{"anthropic_api_key"}
|
||||
c.PIIDetection.Patterns = []config.PIIPattern{{Name: "INTERNAL_TOKEN", Match: `tok-[A-Za-z0-9]{8,}`}}
|
||||
return c
|
||||
}
|
||||
|
||||
var _ = Describe("pattern detector", func() {
|
||||
It("matches built-in and custom secrets as whole-span deterministic hits", func() {
|
||||
det, err := piidetector.NewPattern(patternModel(), &config.ApplicationConfig{})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
ents, err := det.Detect(context.Background(), "use sk-ant-api03-AAAABBBBCCCCDDDDEEEE and tok-ABCD1234 ok")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
byGroup := map[string]pii.NEREntity{}
|
||||
for _, e := range ents {
|
||||
byGroup[e.Group] = e
|
||||
Expect(e.Score).To(BeEquivalentTo(float32(1.0)), "pattern matches are deterministic")
|
||||
}
|
||||
Expect(byGroup).To(HaveKey("ANTHROPIC_KEY"))
|
||||
Expect(byGroup["INTERNAL_TOKEN"].Text).To(Equal("tok-ABCD1234"))
|
||||
})
|
||||
|
||||
It("still detects (and exercises the trace path) with tracing enabled", func() {
|
||||
det, err := piidetector.NewPattern(patternModel(), &config.ApplicationConfig{
|
||||
EnableTracing: true, TracingMaxItems: 8,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
ents, err := det.Detect(context.Background(), "sk-ant-api03-AAAABBBBCCCCDDDDEEEE")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(ents).To(HaveLen(1))
|
||||
Expect(ents[0].Group).To(Equal("ANTHROPIC_KEY"))
|
||||
})
|
||||
|
||||
It("fails to build on an invalid (unanchored) custom pattern", func() {
|
||||
c := config.ModelConfig{Name: "bad", Backend: "pattern"}
|
||||
c.PIIDetection.Patterns = []config.PIIPattern{{Name: "X", Match: `.*`}}
|
||||
_, err := piidetector.NewPattern(c, &config.ApplicationConfig{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
61
core/services/routing/piipattern/builtins.go
Normal file
61
core/services/routing/piipattern/builtins.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package piipattern
|
||||
|
||||
import "sort"
|
||||
|
||||
// Builtin describes one ready-made secret pattern.
//
// Every Builtin pattern is written in the restricted subset and is verified
// at test time to pass ValidatePattern and compile.
type Builtin struct {
	// Name is the stable identifier listed in a model config's
	// pii_detection.builtins.
	Name string
	// Group is the uppercase entity label a match is reported under, so it
	// keys into a detector model's pii_detection.entity_actions exactly like
	// an NER group.
	Group string
	// Pattern is the regular-expression source.
	Pattern string
	// Description is a short human-readable summary of what the pattern
	// matches.
	Description string
}

// builtins is the curated catalogue. Each pattern deliberately anchors on the
// provider's fixed prefix and demands a long high-entropy tail, so it fires on
// real credentials and not on ordinary prose. The names are stable identifiers
// referenced from a model config's pii_detection.builtins list.
var builtins = []Builtin{
	{
		Name:        "anthropic_api_key",
		Group:       "ANTHROPIC_KEY",
		Pattern:     `sk-ant-[A-Za-z0-9_-]{20,}`,
		Description: "Anthropic API key (sk-ant-…)",
	},
	{
		Name:        "openai_api_key",
		Group:       "OPENAI_KEY",
		Pattern:     `sk-(?:proj-)?[A-Za-z0-9_-]{20,}`,
		Description: "OpenAI API key (sk-… / sk-proj-…)",
	},
	{
		Name:        "github_token",
		Group:       "GITHUB_TOKEN",
		Pattern:     `(?:ghp|gho|ghs|ghr|ghu)_[A-Za-z0-9]{36,}`,
		Description: "GitHub access token (ghp_/gho_/ghs_/ghr_/ghu_)",
	},
	{
		Name:        "github_pat",
		Group:       "GITHUB_TOKEN",
		Pattern:     `github_pat_[A-Za-z0-9_]{20,}`,
		Description: "GitHub fine-grained personal access token",
	},
	{
		Name:        "aws_access_key",
		Group:       "AWS_ACCESS_KEY",
		Pattern:     `AKIA[0-9A-Z]{16}`,
		Description: "AWS access key ID (AKIA…)",
	},
	{
		Name:        "google_api_key",
		Group:       "GOOGLE_API_KEY",
		Pattern:     `AIza[0-9A-Za-z_-]{35}`,
		Description: "Google API key (AIza…)",
	},
	{
		Name:        "slack_token",
		Group:       "SLACK_TOKEN",
		Pattern:     `xox[baprs]-[0-9A-Za-z-]{10,}`,
		Description: "Slack token (xoxb-/xoxa-/xoxp-/xoxr-/xoxs-)",
	},
	{
		Name:        "stripe_key",
		Group:       "STRIPE_KEY",
		Pattern:     `(?:sk|rk)_live_[0-9A-Za-z]{16,}`,
		Description: "Stripe live secret/restricted key",
	},
	{
		Name:        "jwt",
		Group:       "JWT",
		Pattern:     `eyJ[A-Za-z0-9_-]{10,}\.eyJ[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}`,
		Description: "JSON Web Token (eyJ….eyJ….…)",
	},
	{
		Name:        "private_key_block",
		Group:       "PRIVATE_KEY",
		Pattern:     `-----BEGIN [A-Z ]*PRIVATE KEY-----`,
		Description: "PEM private-key header",
	},
}
|
||||
|
||||
// BuiltinCatalogue returns the built-in patterns sorted by name. Used by the
|
||||
// config-metadata registry to populate the editor's builtins checklist.
|
||||
func BuiltinCatalogue() []Builtin {
|
||||
out := make([]Builtin, len(builtins))
|
||||
copy(out, builtins)
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].Name < out[j].Name })
|
||||
return out
|
||||
}
|
||||
|
||||
// BuiltinNames returns the built-in pattern names, sorted.
|
||||
func BuiltinNames() []string {
|
||||
out := make([]string, 0, len(builtins))
|
||||
for _, b := range builtins {
|
||||
out = append(out, b.Name)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// LookupBuiltin finds a built-in by name.
|
||||
func LookupBuiltin(name string) (Builtin, bool) {
|
||||
for _, b := range builtins {
|
||||
if b.Name == name {
|
||||
return b, true
|
||||
}
|
||||
}
|
||||
return Builtin{}, false
|
||||
}
|
||||
20
core/services/routing/piipattern/compile.go
Normal file
20
core/services/routing/piipattern/compile.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package piipattern
|
||||
|
||||
import "regexp"
|
||||
|
||||
// Compile validates src against the restricted grammar and, if it passes,
|
||||
// compiles it to an RE2 program set to leftmost-longest matching so a hit grabs
|
||||
// the whole secret (the entire key) rather than the shortest prefix.
|
||||
func Compile(src string) (*regexp.Regexp, error) {
|
||||
if err := ValidatePattern(src); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
re, err := regexp.Compile(src)
|
||||
if err != nil {
|
||||
// ValidatePattern already parsed with the same flags, so this is
|
||||
// effectively unreachable, but surface it rather than panic.
|
||||
return nil, err
|
||||
}
|
||||
re.Longest()
|
||||
return re, nil
|
||||
}
|
||||
163
core/services/routing/piipattern/grammar.go
Normal file
163
core/services/routing/piipattern/grammar.go
Normal file
@@ -0,0 +1,163 @@
|
||||
// Package piipattern is a bounded, restricted-regex matcher for high-entropy,
|
||||
// highly-regular secrets (API keys, tokens, private-key blocks) that the NER
|
||||
// PII tier cannot catch — it has no credential class, so it fragments a key
|
||||
// into the nearest-looking trained categories and may leave the secret part
|
||||
// exposed.
|
||||
//
|
||||
// The language is a deliberately restricted subset of regular expressions
|
||||
// compiled to Go's RE2 engine (regexp), which is linear-time with no
|
||||
// backtracking — there is no ReDoS class of failure. On top of RE2 we cap the
|
||||
// pattern source length, the {n,m} expansion bound, the pattern count, and the
|
||||
// scanned input, and we require every pattern to carry a fixed literal
|
||||
// "anchor". The anchor rule is what admits `sk-ant-…` / `ghp_…` style keys
|
||||
// while rejecting open-ended shapes like an email address or a bare `\w+`
|
||||
// (which would match almost anything) — those stay with the NER tier.
|
||||
//
|
||||
// This package is a leaf: it imports only the standard library, so both
|
||||
// core/config (validation at load) and core/application (the resolver) can use
|
||||
// it without an import cycle.
|
||||
package piipattern
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp/syntax"
|
||||
)
|
||||
|
||||
// Limits enforced by ValidatePattern on a single pattern.
const (
	// MinAnchorLen is the smallest number of fixed literal characters a
	// pattern must guarantee before it counts as "anchored" to a recognisable
	// secret prefix/shape.
	MinAnchorLen = 3
	// MaxPatternLen is the upper bound on a pattern's source length. It is
	// roomy for a credential shape yet keeps the compiled program tiny.
	MaxPatternLen = 256
	// MaxAST limits how deep the validator recurses into the parsed pattern,
	// so a pathologically nested pattern cannot exhaust the stack.
	MaxAST = 64
	// MaxAlternation is the most arms a single `a|b|c` alternation may have.
	MaxAlternation = 64
	// MaxQuantifier is the largest explicit {n,m} bound accepted. RE2 expands
	// a bounded repeat into that many copies, so an uncapped {0,1000000}
	// would blow up the compiled program's memory. An open-ended {n,} has no
	// upper bound to expand (it is a loop) and is allowed.
	//
	// NOTE(review): the regexp/syntax documentation states that counted
	// repetitions above 1000 are rejected by the parser itself, so bounds
	// between 1000 and this cap presumably never reach the validator's own
	// check — confirm whether this value should simply be 1000.
	MaxQuantifier = 4096
)
|
||||
|
||||
// parseFlags is the flag set handed to syntax.Parse during validation. It is
// syntax.Perl — Perl character classes (\w \d \s) and word boundaries — the
// same flags regexp.Compile parses with, so a pattern that validates here
// is parsed identically when it is compiled.
const parseFlags syntax.Flags = syntax.Perl
|
||||
|
||||
// ValidatePattern reports whether src is an acceptable restricted-subset
|
||||
// pattern. It returns a descriptive error naming the offending construct so an
|
||||
// operator editing a model config gets actionable feedback (the error is
|
||||
// surfaced by config Validate at load and by the resolver, which fails closed).
|
||||
func ValidatePattern(src string) error {
|
||||
if src == "" {
|
||||
return fmt.Errorf("pattern is empty")
|
||||
}
|
||||
if len(src) > MaxPatternLen {
|
||||
return fmt.Errorf("pattern is too long (%d chars; max %d)", len(src), MaxPatternLen)
|
||||
}
|
||||
re, err := syntax.Parse(src, parseFlags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid pattern: %w", err)
|
||||
}
|
||||
if err := walk(re, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
if anchorLen(re) < MinAnchorLen {
|
||||
return fmt.Errorf("pattern must contain a fixed literal run of at least %d characters "+
|
||||
"(e.g. \"sk-ant-\", \"ghp_\", \"AKIA\") so it is anchored to a recognisable secret; "+
|
||||
"open-ended shapes like emails or bare \\w+ belong to the NER tier", MinAnchorLen)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// walk enforces the allow-list of regex constructs.
|
||||
func walk(re *syntax.Regexp, depth int) error {
|
||||
if depth > MaxAST {
|
||||
return fmt.Errorf("pattern is too deeply nested")
|
||||
}
|
||||
switch re.Op {
|
||||
case syntax.OpAnyChar, syntax.OpAnyCharNotNL:
|
||||
return fmt.Errorf("'.' (any character) is not allowed; use an explicit class like [A-Za-z0-9]")
|
||||
case syntax.OpCapture:
|
||||
return fmt.Errorf("capturing groups are not allowed; use a non-capturing group (?:…) if you need grouping")
|
||||
case syntax.OpRepeat:
|
||||
if re.Min > MaxQuantifier || (re.Max >= 0 && re.Max > MaxQuantifier) {
|
||||
return fmt.Errorf("{n,m} bound is too large (max %d)", MaxQuantifier)
|
||||
}
|
||||
case syntax.OpAlternate:
|
||||
if len(re.Sub) > MaxAlternation {
|
||||
return fmt.Errorf("too many alternation arms (%d; max %d)", len(re.Sub), MaxAlternation)
|
||||
}
|
||||
case syntax.OpLiteral, syntax.OpCharClass, syntax.OpConcat,
|
||||
syntax.OpStar, syntax.OpPlus, syntax.OpQuest,
|
||||
syntax.OpEmptyMatch,
|
||||
syntax.OpBeginLine, syntax.OpEndLine, syntax.OpBeginText, syntax.OpEndText,
|
||||
syntax.OpWordBoundary, syntax.OpNoWordBoundary:
|
||||
// allowed
|
||||
default:
|
||||
return fmt.Errorf("unsupported construct in pattern")
|
||||
}
|
||||
for _, sub := range re.Sub {
|
||||
if err := walk(sub, depth+1); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// anchorLen computes the pattern's "anchor strength": how many fixed literal
// characters, whitespace excluded, every match of re must contain.
//
//   - a literal counts its runes other than space, tab, newline and carriage
//     return;
//   - a concatenation adds up its parts;
//   - an alternation is only as strong as its weakest arm (every arm must
//     carry the anchor);
//   - `+`, and {n,}/{n,m} with n >= 1, count their body once;
//   - everything else — `*`, `?`, {0,m}, character classes, anchors —
//     counts 0 because it may contribute no literal at all.
//
// Literals are summed rather than measured as the longest contiguous run
// because the parser factors common prefixes: `(?:ghp|gho|ghs)_…` becomes
// `gh[ops]_…`, whose longest run is just "gh" (2) while its guaranteed
// literals are "gh"+"_" (3). Summing keeps such real key prefixes admissible
// and still rejects open-ended shapes: an email `[\w.]+@[\w.]+\.\w+`
// guarantees only `@` and `.` (2 < MinAnchorLen).
func anchorLen(re *syntax.Regexp) int {
	switch re.Op {
	case syntax.OpLiteral:
		// Start from the full rune count and discount the whitespace runes.
		count := len(re.Rune)
		for _, r := range re.Rune {
			switch r {
			case ' ', '\t', '\n', '\r':
				count--
			}
		}
		return count

	case syntax.OpConcat:
		total := 0
		for i := range re.Sub {
			total += anchorLen(re.Sub[i])
		}
		return total

	case syntax.OpAlternate:
		if len(re.Sub) == 0 {
			return 0
		}
		weakest := anchorLen(re.Sub[0])
		for _, arm := range re.Sub[1:] {
			if a := anchorLen(arm); a < weakest {
				weakest = a
			}
		}
		return weakest

	case syntax.OpPlus, syntax.OpRepeat:
		// A counted repeat that permits zero occurrences guarantees nothing.
		if re.Op == syntax.OpRepeat && re.Min < 1 {
			return 0
		}
		if len(re.Sub) != 1 {
			return 0
		}
		return anchorLen(re.Sub[0])
	}

	// Character classes, anchors, OpStar and OpQuest carry no guaranteed literal.
	return 0
}
|
||||
100
core/services/routing/piipattern/matcher.go
Normal file
100
core/services/routing/piipattern/matcher.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package piipattern
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// Limits that apply to a Matcher as a whole.
const (
	// MaxPatternsPerMatcher is the most patterns (built-in plus custom) that
	// a single detector may hold.
	MaxPatternsPerMatcher = 128
	// MaxMatchesPerPattern is the per-pattern, per-call limit handed to the
	// regexp search, so a pathological input cannot produce an unbounded
	// result set.
	MaxMatchesPerPattern = 1000
)
|
||||
|
||||
// Pattern describes one rule before compilation.
type Pattern struct {
	// Group is the label matches are reported under; Pattern is the
	// restricted-regex source text.
	Group, Pattern string
	// MinLen is a floor in bytes: shorter matches are dropped. Zero means
	// no floor.
	MinLen int
}
|
||||
|
||||
// Match describes a single detected span.
type Match struct {
	// Group is the label of the pattern that produced the span.
	Group string
	// Start and End delimit the half-open byte range [Start,End) within the
	// scanned text.
	Start, End int
	// Text is the matched text itself.
	Text string
}
|
||||
|
||||
// compiled is the ready-to-run form of one rule held inside a Matcher.
type compiled struct {
	group  string         // label reported on every match of re
	re     *regexp.Regexp // compiled program for the rule
	minLen int            // matches shorter than this many bytes are dropped
}
|
||||
|
||||
// Matcher holds a set of compiled patterns and scans text for all of them.
|
||||
type Matcher struct {
|
||||
pats []compiled
|
||||
}
|
||||
|
||||
// NewMatcher compiles the named built-ins plus the custom patterns into a
|
||||
// Matcher. Unknown built-in names and patterns that fail the restricted grammar
|
||||
// are reported as errors (the caller fails closed). Built-in and custom counts
|
||||
// together may not exceed MaxPatternsPerMatcher.
|
||||
func NewMatcher(builtinNames []string, custom []Pattern) (*Matcher, error) {
|
||||
if len(builtinNames)+len(custom) > MaxPatternsPerMatcher {
|
||||
return nil, fmt.Errorf("too many patterns (%d; max %d)", len(builtinNames)+len(custom), MaxPatternsPerMatcher)
|
||||
}
|
||||
m := &Matcher{}
|
||||
for _, name := range builtinNames {
|
||||
b, ok := LookupBuiltin(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown built-in pattern %q", name)
|
||||
}
|
||||
re, err := Compile(b.Pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("built-in %q: %w", name, err)
|
||||
}
|
||||
m.pats = append(m.pats, compiled{group: b.Group, re: re})
|
||||
}
|
||||
for _, p := range custom {
|
||||
if p.Group == "" {
|
||||
return nil, fmt.Errorf("custom pattern is missing a name/group")
|
||||
}
|
||||
re, err := Compile(p.Pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pattern %q: %w", p.Group, err)
|
||||
}
|
||||
m.pats = append(m.pats, compiled{group: p.Group, re: re, minLen: p.MinLen})
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Find returns every match of every pattern over text. Spans from different
|
||||
// patterns may overlap; the caller (the redactor) unions and resolves them.
|
||||
func (m *Matcher) Find(text string) []Match {
|
||||
if m == nil || text == "" {
|
||||
return nil
|
||||
}
|
||||
var out []Match
|
||||
for _, p := range m.pats {
|
||||
locs := p.re.FindAllStringIndex(text, MaxMatchesPerPattern)
|
||||
for _, loc := range locs {
|
||||
start, end := loc[0], loc[1]
|
||||
if end-start < p.minLen {
|
||||
continue
|
||||
}
|
||||
out = append(out, Match{
|
||||
Group: p.group,
|
||||
Start: start,
|
||||
End: end,
|
||||
Text: text[start:end],
|
||||
})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
105
core/services/routing/piipattern/piipattern_test.go
Normal file
105
core/services/routing/piipattern/piipattern_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package piipattern
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestPiipattern(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "piipattern suite")
|
||||
}
|
||||
|
||||
var _ = Describe("ValidatePattern", func() {
|
||||
DescribeTable("accepts anchored, bounded patterns",
|
||||
func(src string) { Expect(ValidatePattern(src)).To(Succeed()) },
|
||||
Entry("anthropic", `sk-ant-[A-Za-z0-9_-]{20,200}`),
|
||||
Entry("github via alternation", `(?:ghp|gho|ghs)_[A-Za-z0-9]{36,}`),
|
||||
Entry("custom token", `tok-\w{32,64}`),
|
||||
Entry("aws", `AKIA[0-9A-Z]{16}`),
|
||||
Entry("anchored by mid-literal", `(?:sk|rk)_live_[0-9A-Za-z]{16,}`),
|
||||
)
|
||||
|
||||
DescribeTable("rejects unanchored or unsafe patterns",
|
||||
func(src string) { Expect(ValidatePattern(src)).NotTo(Succeed()) },
|
||||
Entry("email (no fixed anchor)", `[\w.]+@[\w.]+\.\w+`),
|
||||
Entry("bare word run", `\w+`),
|
||||
Entry("any-char greedy", `sk-.*`),
|
||||
Entry("capturing group", `(sk-ant-[A-Za-z0-9]+)`),
|
||||
Entry("two fixed chars only", `ab[0-9]{8,}`),
|
||||
Entry("over-long source", "sk-ant-"+strings.Repeat("a", MaxPatternLen)),
|
||||
Entry("huge bounded repeat", `sk-ant-[A-Za-z0-9]{5000}`),
|
||||
Entry("empty", ``),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("Compile", func() {
|
||||
It("compiles a valid pattern with leftmost-longest semantics", func() {
|
||||
re, err := Compile(`sk-ant-[A-Za-z0-9_-]{4,}`)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
// Longest() makes the match span the whole key, not a shorter prefix.
|
||||
loc := re.FindString("key sk-ant-AAAA1111bbbb end")
|
||||
Expect(loc).To(Equal("sk-ant-AAAA1111bbbb"))
|
||||
})
|
||||
It("refuses an invalid pattern", func() {
|
||||
_, err := Compile(`.*`)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("builtins", func() {
|
||||
It("every built-in validates, compiles, and is uniquely named", func() {
|
||||
seen := map[string]bool{}
|
||||
for _, b := range BuiltinCatalogue() {
|
||||
Expect(seen[b.Name]).To(BeFalse(), "duplicate builtin %s", b.Name)
|
||||
seen[b.Name] = true
|
||||
Expect(ValidatePattern(b.Pattern)).To(Succeed(), "builtin %s pattern %q", b.Name, b.Pattern)
|
||||
}
|
||||
})
|
||||
|
||||
DescribeTable("matches a real sample and not a decoy",
|
||||
func(name, sample, decoy string) {
|
||||
b, ok := LookupBuiltin(name)
|
||||
Expect(ok).To(BeTrue())
|
||||
re, err := Compile(b.Pattern)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(re.MatchString(sample)).To(BeTrue(), "should match %q", sample)
|
||||
Expect(re.MatchString(decoy)).To(BeFalse(), "should not match %q", decoy)
|
||||
},
|
||||
Entry("anthropic", "anthropic_api_key", "sk-ant-api03-AbCdEf012345_-AbCdEf012345", "sk-ant-short"),
|
||||
Entry("aws", "aws_access_key", "AKIAIOSFODNN7EXAMPLE", "AKIAshort"),
|
||||
Entry("github", "github_token", "ghp_"+strings.Repeat("a", 36), "ghp_short"),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("Matcher", func() {
|
||||
It("reports the whole key as one span under its group", func() {
|
||||
m, err := NewMatcher([]string{"anthropic_api_key"}, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
got := m.Find("my key is sk-ant-api03-AbCdEf012345AbCdEf012345 thanks")
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Group).To(Equal("ANTHROPIC_KEY"))
|
||||
Expect(got[0].Text).To(Equal("sk-ant-api03-AbCdEf012345AbCdEf012345"))
|
||||
})
|
||||
|
||||
It("compiles custom patterns and honours MinLen", func() {
|
||||
m, err := NewMatcher(nil, []Pattern{{Group: "INTERNAL", Pattern: `tok-[A-Za-z0-9]{4,}`, MinLen: 12}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
// "tok-AAAA" (8 bytes) is below MinLen 12 and is dropped.
|
||||
Expect(m.Find("tok-AAAA")).To(BeEmpty())
|
||||
Expect(m.Find("tok-AAAABBBBCCCC")).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("fails closed on an unknown built-in", func() {
|
||||
_, err := NewMatcher([]string{"nope"}, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects an invalid custom pattern", func() {
|
||||
_, err := NewMatcher(nil, []Pattern{{Group: "X", Pattern: `.*`}})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user