mirror of
https://github.com/mudler/LocalAI.git
synced 2026-06-18 21:58:58 -04:00
feat(pii): NER tier engine — privacy-filter.cpp backend + NER-centric PII filter (#10360)
Squashed feat/pii-ner-tier-engine rebased onto master (was 45 commits; see backup/pii-ner-tier-engine-prerebase). Net change: - privacy-filter.cpp: standalone GGML engine for the openai-privacy-filter PII/NER token classifier, wired as a LocalAI gRPC backend (CPU/CUDA/Vulkan). TokenClassify moves off the patched llama.cpp path onto this backend. - PII filter reworked to be NER-centric (encoder/NER detection tier scanning whole conversations as one document), with a recreated bounded restricted- regex secret-matching pattern detector tier alongside it (per-model pii_detection.builtins / .patterns + core/services/routing/piipattern). - Detection labelled by source (ner vs pattern); backend trace / confidence / debug observability; analyze/redact exposed as a synchronous API. - Instance-wide default detector policy + per-usecase default-on; request filtering extended to completions, embeddings, edits & Ollama. - React UI: NER-centric PII editor, detector-models table, pattern/builtins editor, middleware default-policy UI. - Gallery: privacy-filter-multilingual token-classify model + NER install filter; token_classify known_usecase; batch sized to context for NER models. privacy-filter backend registered in the backend gallery (cpu/vulkan/cuda-13 meta + image entries with a capabilities map) matching its CI matrix jobs, and an /import-model auto-detect importer (PrivacyFilterImporter, narrow privacy-filter GGUF detection) replacing the prior pref-only registration. Reconciled against master's independent evolution: - Dropped master's PIIPatternOverrides feature (global-pattern runtime overrides + /api/pii/patterns API + runtime_settings.json persistence). The per-model NER + pattern-detector design supersedes it; it was built on the global redactor pattern set this branch replaced. - Reverted the llama.cpp Score carry-patch (0006-server-task-type-score): removed the patch and restored master's grpc-server.cpp Score RPC (direct llama_decode, slot-loop bypass) and LLAMA_VERSION pin, plus master's model_config validation forbidding score + chat/completion/embeddings on llama-cpp. token_classify is unaffected (it runs on the privacy-filter backend, not llama-cpp). Assisted-by: Claude:claude-opus-4-8 [Claude Code] Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
committed by
GitHub
parent
c133ca39dc
commit
3fa7b2955c
@@ -12,14 +12,15 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
"github.com/mudler/LocalAI/core/services/agentpool"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/LocalAI/core/services/facerecognition"
|
||||
"github.com/mudler/LocalAI/core/services/galleryop"
|
||||
"github.com/mudler/LocalAI/core/services/monitoring"
|
||||
"github.com/mudler/LocalAI/core/services/nodes"
|
||||
"github.com/mudler/LocalAI/core/services/routing/admission"
|
||||
"github.com/mudler/LocalAI/core/services/routing/billing"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piidetector"
|
||||
"github.com/mudler/LocalAI/core/services/routing/router"
|
||||
"github.com/mudler/LocalAI/core/services/voicerecognition"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
@@ -71,15 +72,15 @@ type Application struct {
|
||||
// 1-to-1 host↔model invariant the dispatcher relies on. Read by
|
||||
// /api/middleware/status so the admin UI can surface the cause.
|
||||
mitmHostConflicts atomic.Pointer[map[string][]string]
|
||||
routerDecisions router.DecisionStore
|
||||
routerRegistry *router.Registry
|
||||
admissionLimiter *admission.Limiter
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
p2pMutex sync.Mutex
|
||||
p2pCtx context.Context
|
||||
p2pCancel context.CancelFunc
|
||||
agentJobMutex sync.Mutex
|
||||
routerDecisions router.DecisionStore
|
||||
routerRegistry *router.Registry
|
||||
admissionLimiter *admission.Limiter
|
||||
watchdogMutex sync.Mutex
|
||||
watchdogStop chan bool
|
||||
p2pMutex sync.Mutex
|
||||
p2pCtx context.Context
|
||||
p2pCancel context.CancelFunc
|
||||
agentJobMutex sync.Mutex
|
||||
|
||||
// Distributed mode services (nil when not in distributed mode)
|
||||
distributed *DistributedServices
|
||||
@@ -254,6 +255,122 @@ func (a *Application) PIIEvents() pii.EventStore {
|
||||
return a.piiEvents
|
||||
}
|
||||
|
||||
// PIINERResolver returns the resolver the chat PII middleware uses to
|
||||
// turn a configured detector model name into a ready-to-use NERConfig:
|
||||
// a token-classifier bound over the shared model loader (lazy — the
|
||||
// model loads on first Detect) plus the detection policy read from that
|
||||
// model's own pii_detection block. Unknown names resolve to (zero,
|
||||
// false) so the middleware fails closed. Pass it via pii.WithNERResolver.
|
||||
func (a *Application) PIINERResolver() pii.NERDetectorResolver {
|
||||
return func(modelName string) (pii.NERConfig, bool) {
|
||||
if modelName == "" {
|
||||
return pii.NERConfig{}, false
|
||||
}
|
||||
cfg, ok := a.ModelConfigLoader().GetModelConfig(modelName)
|
||||
if !ok {
|
||||
return pii.NERConfig{}, false
|
||||
}
|
||||
|
||||
// Pattern detectors match secrets with the restricted-regex tier
|
||||
// in-process (no backend load). Build a pattern matcher instead of the
|
||||
// gRPC token-classifier; on a compile error fail closed with an error
|
||||
// detector so the request is blocked, not silently unscanned.
|
||||
if cfg.IsPatternDetector() {
|
||||
det, err := piidetector.NewPattern(cfg, a.ApplicationConfig())
|
||||
if err != nil {
|
||||
det = pii.NewErrNERDetector(err.Error())
|
||||
}
|
||||
return pii.NERConfigFromRaw(
|
||||
det,
|
||||
0, // patterns are deterministic — no confidence floor
|
||||
cfg.PIIDetectionDefaultAction(),
|
||||
patternEntityActions(cfg),
|
||||
pii.SourcePattern,
|
||||
), true
|
||||
}
|
||||
|
||||
det := piidetector.New(a.ModelLoader(), cfg, a.ApplicationConfig())
|
||||
return pii.NERConfigFromRaw(
|
||||
det,
|
||||
cfg.PIIDetectionMinScore(),
|
||||
cfg.PIIDetectionDefaultAction(),
|
||||
cfg.PIIDetectionEntityActions(),
|
||||
pii.SourceNER,
|
||||
), true
|
||||
}
|
||||
}
|
||||
|
||||
// patternEntityActions merges a pattern detector's per-pattern Action overrides
|
||||
// into its entity_actions map. A pattern reports matches under its Name, so a
|
||||
// per-pattern action is just an entity_actions[Name] entry; explicit
|
||||
// entity_actions still win if both are set.
|
||||
func patternEntityActions(cfg config.ModelConfig) map[string]string {
|
||||
out := cfg.PIIDetectionEntityActions()
|
||||
for _, p := range cfg.PIIDetection.Patterns {
|
||||
if p.Action == "" || p.Name == "" {
|
||||
continue
|
||||
}
|
||||
if out == nil {
|
||||
out = map[string]string{}
|
||||
}
|
||||
if _, exists := out[p.Name]; !exists {
|
||||
out[p.Name] = p.Action
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ResolvePIIPolicy resolves the effective request-side PII policy for a
|
||||
// consuming model, layering the instance-wide default detector
|
||||
// (PIIDefaultDetectors, set via POST /api/settings) on top of the per-model
|
||||
// config. It is the single decision point shared by the chat middleware (via
|
||||
// WithPolicyResolver) and the MITM listener so both agree.
|
||||
//
|
||||
// - enabled: an explicit pii.enabled on the model always wins (true OR
|
||||
// false). Otherwise PII is on when the backend defaults it on — today
|
||||
// that means cloud-proxy models, which cross the network to a third party.
|
||||
// - detectors: the model's own pii.detectors, or — when it lists none — the
|
||||
// global PIIDefaultDetectors fallback. This is what makes cloud-proxy/MITM
|
||||
// redaction work out of the box.
|
||||
//
|
||||
// appConfig is read live, so changes via the settings API take effect on the
|
||||
// next request without a restart.
|
||||
func (a *Application) ResolvePIIPolicy(cfg *config.ModelConfig) (enabled bool, detectors []string) {
|
||||
if cfg == nil {
|
||||
return false, nil
|
||||
}
|
||||
appCfg := a.ApplicationConfig()
|
||||
|
||||
if cfg.PII.Enabled != nil {
|
||||
enabled = *cfg.PII.Enabled
|
||||
} else {
|
||||
enabled = cfg.PIIIsEnabled() // backend default (cloud-proxy)
|
||||
}
|
||||
if !enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
detectors = cfg.PIIDetectors()
|
||||
if len(detectors) == 0 {
|
||||
detectors = append([]string(nil), appCfg.PIIDefaultDetectors...)
|
||||
}
|
||||
return enabled, detectors
|
||||
}
|
||||
|
||||
// PIIPolicyResolver adapts ResolvePIIPolicy to pii.PolicyResolver for
|
||||
// pii.WithPolicyResolver. The middleware carries the resolved model config as
|
||||
// `any` (the MODEL_CONFIG context value, a *config.ModelConfig); this asserts
|
||||
// it back and applies the instance-wide defaults.
|
||||
func (a *Application) PIIPolicyResolver() pii.PolicyResolver {
|
||||
return func(modelCfg any) (bool, []string) {
|
||||
cfg, ok := modelCfg.(*config.ModelConfig)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
return a.ResolvePIIPolicy(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// MITMCA returns the cloudproxy MITM proxy's CA, or nil when the
|
||||
// MITM listener is disabled.
|
||||
func (a *Application) MITMCA() *mitm.CA { return a.mitmCA.Load() }
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/mitm"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -91,25 +92,41 @@ func startMITMLocked(app *Application, options *config.ApplicationConfig) error
|
||||
}
|
||||
sort.Strings(effectiveHosts)
|
||||
|
||||
// Per-host PII gate inherits from the owning model's pii.enabled.
|
||||
// A non-cloud-proxy backend with no explicit pii.enabled resolves
|
||||
// to false → host is intercepted but the regex pass is skipped
|
||||
// (audit events still record).
|
||||
var piiDisabled []string
|
||||
// Per-host NER detectors come from the owning model's pii.detectors
|
||||
// (resolved against each detector model's pii_detection policy). A
|
||||
// host whose model has pii.enabled=false, lists no detectors, or
|
||||
// whose detectors can't be resolved gets no entry → it is intercepted
|
||||
// and forwarded unredacted (audit events still record traffic). An
|
||||
// unresolvable detector is recorded as an error-detector so the
|
||||
// request fails closed at request time rather than leaking.
|
||||
resolver := app.PIINERResolver()
|
||||
detectorsByHost := map[string][]pii.NERConfig{}
|
||||
for host, modelName := range ownership.Owners {
|
||||
cfg, exists := app.backendLoader.GetModelConfig(modelName)
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if !cfg.PIIIsEnabled() {
|
||||
piiDisabled = append(piiDisabled, host)
|
||||
// Resolve through the shared policy so cloud-proxy hosts inherit the
|
||||
// instance-wide default detector when they name none of their own.
|
||||
enabled, detectors := app.ResolvePIIPolicy(&cfg)
|
||||
if !enabled || len(detectors) == 0 {
|
||||
continue
|
||||
}
|
||||
cfgs := make([]pii.NERConfig, 0, len(detectors))
|
||||
for _, name := range detectors {
|
||||
nc, ok := resolver(name)
|
||||
if !ok {
|
||||
xlog.Error("mitm: detector model not resolvable; requests to host will fail closed", "host", host, "detector", name)
|
||||
nc = pii.NERConfig{Detector: pii.NewErrNERDetector("detector model '" + name + "' not resolvable")}
|
||||
}
|
||||
cfgs = append(cfgs, nc)
|
||||
}
|
||||
detectorsByHost[host] = cfgs
|
||||
}
|
||||
|
||||
handler := mitm.NewPIIHandler(mitm.PIIHandlerOptions{
|
||||
Redactor: app.piiRedactor,
|
||||
EventStore: app.piiEvents,
|
||||
HostsWithPIIDisabled: piiDisabled,
|
||||
EventStore: app.piiEvents,
|
||||
DetectorsByHost: detectorsByHost,
|
||||
})
|
||||
|
||||
srv, err := mitm.NewServer(mitm.Config{
|
||||
@@ -132,7 +149,7 @@ func startMITMLocked(app *Application, options *config.ApplicationConfig) error
|
||||
"ca_dir", caDir,
|
||||
"intercept_hosts", effectiveHosts,
|
||||
"model_owned_hosts", len(ownership.Owners),
|
||||
"pii_disabled_hosts", len(piiDisabled),
|
||||
"pii_detector_hosts", len(detectorsByHost),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
51
core/application/pii_policy_test.go
Normal file
51
core/application/pii_policy_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("ResolvePIIPolicy", func() {
|
||||
chat := config.FLAG_CHAT
|
||||
bp := func(b bool) *bool { return &b }
|
||||
mk := func(c *config.ApplicationConfig) *Application {
|
||||
return &Application{applicationConfig: c}
|
||||
}
|
||||
|
||||
It("lets an explicit pii.enabled=false win over the global default detector", func() {
|
||||
app := mk(&config.ApplicationConfig{PIIDefaultDetectors: []string{"pf"}})
|
||||
cfg := &config.ModelConfig{Backend: "cloud-proxy", KnownUsecases: &chat}
|
||||
cfg.PII.Enabled = bp(false)
|
||||
enabled, dets := app.ResolvePIIPolicy(cfg)
|
||||
Expect(enabled).To(BeFalse())
|
||||
Expect(dets).To(BeNil())
|
||||
})
|
||||
|
||||
It("enables a cloud-proxy model with the global default detector (closes the no-op gap)", func() {
|
||||
// cloud-proxy defaults PIIIsEnabled()==true but lists no detectors, so
|
||||
// without a global default it scans with nothing.
|
||||
app := mk(&config.ApplicationConfig{PIIDefaultDetectors: []string{"pf"}})
|
||||
cfg := &config.ModelConfig{Backend: "cloud-proxy"}
|
||||
enabled, dets := app.ResolvePIIPolicy(cfg)
|
||||
Expect(enabled).To(BeTrue())
|
||||
Expect(dets).To(Equal([]string{"pf"}))
|
||||
})
|
||||
|
||||
It("leaves a non-cloud model off by default (no instance usecase default-on)", func() {
|
||||
app := mk(&config.ApplicationConfig{PIIDefaultDetectors: []string{"pf"}})
|
||||
cfg := &config.ModelConfig{Backend: "llama-cpp", KnownUsecases: &chat}
|
||||
enabled, _ := app.ResolvePIIPolicy(cfg)
|
||||
Expect(enabled).To(BeFalse())
|
||||
})
|
||||
|
||||
It("prefers the model's own detectors over the global default", func() {
|
||||
app := mk(&config.ApplicationConfig{PIIDefaultDetectors: []string{"global-pf"}})
|
||||
cfg := &config.ModelConfig{Backend: "cloud-proxy"}
|
||||
cfg.PII.Detectors = []string{"own-pf"}
|
||||
enabled, dets := app.ResolvePIIPolicy(cfg)
|
||||
Expect(enabled).To(BeTrue())
|
||||
Expect(dets).To(Equal([]string{"own-pf"}))
|
||||
})
|
||||
})
|
||||
@@ -53,7 +53,6 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
caps, err := xsysinfo.CPUCapabilities()
|
||||
if err == nil {
|
||||
xlog.Debug("CPU capabilities", "capabilities", caps)
|
||||
|
||||
}
|
||||
gpus, err := xsysinfo.GPUs()
|
||||
if err == nil {
|
||||
@@ -68,18 +67,18 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
return nil, fmt.Errorf("models path cannot be empty")
|
||||
}
|
||||
|
||||
err = os.MkdirAll(options.SystemState.Model.ModelsPath, 0750)
|
||||
err = os.MkdirAll(options.SystemState.Model.ModelsPath, 0o750)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create ModelPath: %q", err)
|
||||
}
|
||||
if options.GeneratedContentDir != "" {
|
||||
err := os.MkdirAll(options.GeneratedContentDir, 0750)
|
||||
err := os.MkdirAll(options.GeneratedContentDir, 0o750)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create ImageDir: %q", err)
|
||||
}
|
||||
}
|
||||
if options.UploadDir != "" {
|
||||
err := os.MkdirAll(options.UploadDir, 0750)
|
||||
err := os.MkdirAll(options.UploadDir, 0o750)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create UploadDir: %q", err)
|
||||
}
|
||||
@@ -87,7 +86,7 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
|
||||
// Create and migrate data directory
|
||||
if options.DataPath != "" {
|
||||
if err := os.MkdirAll(options.DataPath, 0750); err != nil {
|
||||
if err := os.MkdirAll(options.DataPath, 0o750); err != nil {
|
||||
return nil, fmt.Errorf("unable to create DataPath: %q", err)
|
||||
}
|
||||
// Migrate data from DynamicConfigsDir to DataPath if needed
|
||||
@@ -192,44 +191,14 @@ func New(opts ...config.AppOption) (*Application, error) {
|
||||
xlog.Info("stats: disabled by --disable-stats")
|
||||
}
|
||||
|
||||
// Wire the regex PII filter. Default-on: a single-user box gets
|
||||
// the built-in pattern set the first time it starts, with email/
|
||||
// phone/SSN/credit-card on mask and api_key_prefix on block. If
|
||||
// the operator wants different actions, --pii-config points at a
|
||||
// YAML file that overrides per-id; --disable-pii turns it off
|
||||
// entirely.
|
||||
if !options.DisablePII {
|
||||
patterns, err := pii.LoadConfig(options.PIIConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pii config: %w", err)
|
||||
}
|
||||
application.piiRedactor = pii.NewRedactor(patterns)
|
||||
application.piiEvents = pii.NewMemoryEventStore(0)
|
||||
// Apply persisted per-pattern overrides — admins toggling
|
||||
// action/disabled via the UI and clicking "Save to disk" land
|
||||
// here on the next start. Bad ids are warned and ignored so a
|
||||
// stale entry doesn't block startup.
|
||||
for id, ov := range options.PIIPatternOverrides {
|
||||
if ov.Action != nil {
|
||||
if err := application.piiRedactor.SetAction(id, pii.Action(*ov.Action)); err != nil {
|
||||
xlog.Warn("pii: persisted override skipped", "pattern", id, "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if ov.Disabled != nil {
|
||||
if err := application.piiRedactor.SetDisabled(id, *ov.Disabled); err != nil {
|
||||
xlog.Warn("pii: persisted disable skipped", "pattern", id, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
xlog.Info("pii: filter enabled",
|
||||
"patterns", len(patterns),
|
||||
"config_path", options.PIIConfigPath,
|
||||
"persisted_overrides", len(options.PIIPatternOverrides),
|
||||
)
|
||||
} else {
|
||||
xlog.Info("pii: disabled by --disable-pii")
|
||||
}
|
||||
// Wire the PII filter subsystem. The redactor is now a stateless
|
||||
// handle — detection is driven by per-model NER detectors
|
||||
// (pii.detectors → the detector model's pii_detection policy), run
|
||||
// request-side by the chat middleware and the MITM input path. The
|
||||
// regex tier was removed; redaction is opt-in per model via
|
||||
// PIIIsEnabled(). The event store backs the /api/pii/events audit log.
|
||||
application.piiRedactor = &pii.Redactor{}
|
||||
application.piiEvents = pii.NewMemoryEventStore(0)
|
||||
|
||||
// Wire the routing decision log. Always-on when stats are enabled —
|
||||
// the per-router admin page reads this as the live activity feed
|
||||
@@ -517,7 +486,7 @@ func startWatcher(options *config.ApplicationConfig) {
|
||||
if _, err := os.Stat(options.DynamicConfigsDir); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// We try to create the directory if it does not exist and was specified
|
||||
if err := os.MkdirAll(options.DynamicConfigsDir, 0700); err != nil {
|
||||
if err := os.MkdirAll(options.DynamicConfigsDir, 0o700); err != nil {
|
||||
xlog.Error("failed creating DynamicConfigsDir", "error", err)
|
||||
}
|
||||
} else {
|
||||
@@ -764,16 +733,6 @@ func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) {
|
||||
options.MITMListen = *settings.MITMListen
|
||||
}
|
||||
|
||||
// PII pattern overrides — file is the only source; CLI flags don't
|
||||
// reach into this map. Apply unconditionally when present; the
|
||||
// redactor wiring below sees the result on first construction.
|
||||
if settings.PIIPatternOverrides != nil {
|
||||
options.PIIPatternOverrides = make(map[string]config.PIIPatternRuntimeOverride, len(*settings.PIIPatternOverrides))
|
||||
for id, ov := range *settings.PIIPatternOverrides {
|
||||
options.PIIPatternOverrides[id] = ov
|
||||
}
|
||||
}
|
||||
|
||||
// Backend upgrade flags
|
||||
if settings.AutoUpgradeBackends != nil {
|
||||
if !options.AutoUpgradeBackends {
|
||||
@@ -924,7 +883,7 @@ func loadOrGenerateHMACSecret(path string) (string, error) {
|
||||
}
|
||||
secret := hex.EncodeToString(b)
|
||||
|
||||
if err := os.WriteFile(path, []byte(secret), 0600); err != nil {
|
||||
if err := os.WriteFile(path, []byte(secret), 0o600); err != nil {
|
||||
return "", fmt.Errorf("failed to persist HMAC secret: %w", err)
|
||||
}
|
||||
|
||||
|
||||
150
core/backend/token_classify.go
Normal file
150
core/backend/token_classify.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// TokenEntity is one detected span from a token-classification (NER)
|
||||
// model. Mirrors pb.TokenClassifyEntity but keeps the proto type out of
|
||||
// consumers. Start/End are BYTE offsets into the classified text,
|
||||
// half-open (addressing text[Start:End]) — the proto contract. Group is
|
||||
// the model's entity label (e.g. "private_person", "EMAIL").
|
||||
type TokenEntity struct {
|
||||
Group string `json:"group"`
|
||||
Start int `json:"start"`
|
||||
End int `json:"end"`
|
||||
Score float32 `json:"score"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// TokenClassifyOptions controls a single TokenClassify request.
|
||||
type TokenClassifyOptions struct {
|
||||
// Threshold drops entities the backend scores below this value at
|
||||
// the source. 0 returns everything the model emits; downstream
|
||||
// callers (e.g. the PII redactor's MinScore) can still filter
|
||||
// further once they know the per-request policy.
|
||||
Threshold float32
|
||||
}
|
||||
|
||||
// TokenClassifier runs a token-classification model over text and
|
||||
// returns the detected entity spans. Implemented by NewTokenClassifier
|
||||
// over a model-loaded backend; the PII redactor's encoder/NER tier
|
||||
// consumes this via a pii.NERDetector adapter (see
|
||||
// core/services/routing/piidetector).
|
||||
type TokenClassifier interface {
|
||||
TokenClassify(ctx context.Context, text string) ([]TokenEntity, error)
|
||||
}
|
||||
|
||||
// NewTokenClassifier binds (loader, modelConfig, appConfig) into a
|
||||
// TokenClassifier. The underlying backend is resolved lazily on the
|
||||
// first call, mirroring NewScorer.
|
||||
func NewTokenClassifier(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, opts TokenClassifyOptions) TokenClassifier {
|
||||
return &modelTokenClassifier{loader: loader, modelConfig: modelConfig, appConfig: appConfig, opts: opts}
|
||||
}
|
||||
|
||||
type modelTokenClassifier struct {
|
||||
loader *model.ModelLoader
|
||||
modelConfig config.ModelConfig
|
||||
appConfig *config.ApplicationConfig
|
||||
opts TokenClassifyOptions
|
||||
}
|
||||
|
||||
func (m *modelTokenClassifier) TokenClassify(ctx context.Context, text string) ([]TokenEntity, error) {
|
||||
fn, err := ModelTokenClassify(text, m.opts, m.loader, m.modelConfig, m.appConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
// ModelTokenClassify loads the backend for modelConfig and returns a
|
||||
// closure that classifies `text`. Mirrors ModelScore: the closure is
|
||||
// bound to the loaded model so a caller can reuse it within a request
|
||||
// without re-resolving the backend.
|
||||
//
|
||||
// When tracing is enabled it records a BackendTraceTokenClassify row so the
|
||||
// detector's output — every entity's group, byte range, confidence and the
|
||||
// matched substring — shows in the Traces UI alongside the request it gated.
|
||||
// This is the technical view for debugging false positives (e.g. a phone
|
||||
// number scored as SSN); the persisted PIIEvent keeps only a hash.
|
||||
func ModelTokenClassify(text string, opts TokenClassifyOptions, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func(ctx context.Context) ([]TokenEntity, error), error) {
|
||||
modelOpts := ModelOptions(modelConfig, appConfig)
|
||||
inferenceModel, err := loader.Load(modelOpts...)
|
||||
if err != nil {
|
||||
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
|
||||
return nil, err
|
||||
}
|
||||
return func(ctx context.Context) ([]TokenEntity, error) {
|
||||
var startTime time.Time
|
||||
if appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems, appConfig.TracingMaxBodyBytes)
|
||||
startTime = time.Now()
|
||||
}
|
||||
resp, err := inferenceModel.TokenClassify(ctx, &pb.TokenClassifyRequest{
|
||||
Text: text,
|
||||
Threshold: opts.Threshold,
|
||||
})
|
||||
entities := tokenClassifyResponseToEntities(resp)
|
||||
if appConfig.EnableTracing {
|
||||
trace.RecordBackendTrace(tokenClassifyTrace(modelConfig, text, opts.Threshold, entities, startTime, err))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entities, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tokenClassifyTrace assembles the Traces-UI row for one NER call: the input
|
||||
// preview, the threshold, and every detected entity (group, byte range,
|
||||
// confidence, matched text). Split out from the closure so the Data assembly
|
||||
// is unit-testable without a live backend.
|
||||
func tokenClassifyTrace(modelConfig config.ModelConfig, text string, threshold float32, entities []TokenEntity, start time.Time, callErr error) trace.BackendTrace {
|
||||
errStr := ""
|
||||
if callErr != nil {
|
||||
errStr = callErr.Error()
|
||||
}
|
||||
return trace.BackendTrace{
|
||||
Timestamp: start,
|
||||
Duration: time.Since(start),
|
||||
Type: trace.BackendTraceTokenClassify,
|
||||
ModelName: modelConfig.Name,
|
||||
Backend: modelConfig.Backend,
|
||||
Summary: trace.TruncateString(text, 200),
|
||||
Error: errStr,
|
||||
Data: map[string]any{
|
||||
"input_chars": len(text),
|
||||
"threshold": threshold,
|
||||
"entities": entities,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// tokenClassifyResponseToEntities converts the wire-format response into
|
||||
// the value type consumed by callers. Extracted so the conversion can be
|
||||
// unit-tested without a real backend (see token_classify_test.go).
|
||||
func tokenClassifyResponseToEntities(resp *pb.TokenClassifyResponse) []TokenEntity {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]TokenEntity, 0, len(resp.Entities))
|
||||
for _, e := range resp.Entities {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
out = append(out, TokenEntity{
|
||||
Group: e.EntityGroup,
|
||||
Start: int(e.Start),
|
||||
End: int(e.End),
|
||||
Score: e.Score,
|
||||
Text: e.Text,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
61
core/backend/token_classify_test.go
Normal file
61
core/backend/token_classify_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("tokenClassifyResponseToEntities", func() {
|
||||
It("returns nil for a nil response", func() {
|
||||
Expect(tokenClassifyResponseToEntities(nil)).To(BeNil())
|
||||
})
|
||||
|
||||
It("maps proto entities to TokenEntity, skipping nil rows", func() {
|
||||
resp := &pb.TokenClassifyResponse{
|
||||
Entities: []*pb.TokenClassifyEntity{
|
||||
{EntityGroup: "private_person", Start: 3, End: 8, Score: 0.97, Text: "Alice"},
|
||||
nil,
|
||||
{EntityGroup: "EMAIL", Start: 20, End: 40, Score: 0.5, Text: "a@b.com"},
|
||||
},
|
||||
}
|
||||
Expect(tokenClassifyResponseToEntities(resp)).To(Equal([]TokenEntity{
|
||||
{Group: "private_person", Start: 3, End: 8, Score: 0.97, Text: "Alice"},
|
||||
{Group: "EMAIL", Start: 20, End: 40, Score: 0.5, Text: "a@b.com"},
|
||||
}))
|
||||
})
|
||||
|
||||
It("returns an empty (non-nil) slice for a response with no entities", func() {
|
||||
out := tokenClassifyResponseToEntities(&pb.TokenClassifyResponse{})
|
||||
Expect(out).NotTo(BeNil())
|
||||
Expect(out).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("tokenClassifyTrace", func() {
|
||||
cfg := config.ModelConfig{Name: "privacy-filter", Backend: "privacy-filter"}
|
||||
ents := []TokenEntity{{Group: "SSN", Start: 5, End: 16, Score: 0.62, Text: "123-45-6789"}}
|
||||
|
||||
It("captures model, input preview, threshold and per-entity detail", func() {
|
||||
tr := tokenClassifyTrace(cfg, "ssn is 123-45-6789", 0.5, ents, time.Now(), nil)
|
||||
Expect(tr.Type).To(Equal(trace.BackendTraceTokenClassify))
|
||||
Expect(tr.ModelName).To(Equal("privacy-filter"))
|
||||
Expect(tr.Backend).To(Equal("privacy-filter"))
|
||||
Expect(tr.Summary).To(ContainSubstring("ssn is"))
|
||||
Expect(tr.Error).To(BeEmpty())
|
||||
Expect(tr.Data["input_chars"]).To(Equal(len("ssn is 123-45-6789")))
|
||||
Expect(tr.Data["threshold"]).To(BeEquivalentTo(float32(0.5)))
|
||||
Expect(tr.Data["entities"]).To(Equal(ents))
|
||||
})
|
||||
|
||||
It("records the backend error string when the call failed", func() {
|
||||
tr := tokenClassifyTrace(cfg, "x", 0, nil, time.Now(), errors.New("boom"))
|
||||
Expect(tr.Error).To(Equal("boom"))
|
||||
})
|
||||
})
|
||||
@@ -57,25 +57,6 @@ type ApplicationConfig struct {
|
||||
// touch disk or memory.
|
||||
DisableStats bool
|
||||
|
||||
// PIIConfigPath points to an optional YAML file describing the PII
|
||||
// pattern set. When empty, the routing/pii module's DefaultPatterns()
|
||||
// (email, phone, SSN, credit card, IPv4, API key prefixes) are
|
||||
// loaded with their default actions. Each entry overrides the
|
||||
// matching default by ID:
|
||||
//
|
||||
// patterns:
|
||||
// - id: email
|
||||
// action: allow # downgrade default mask -> allow (log only)
|
||||
// - id: ssn
|
||||
// action: block # upgrade default mask -> block
|
||||
//
|
||||
// Unknown ids are rejected with a clear error at startup.
|
||||
PIIConfigPath string
|
||||
|
||||
// DisablePII turns the regex PII filter off entirely. Default
|
||||
// (false) enables it on the OpenAI chat completions route.
|
||||
DisablePII bool
|
||||
|
||||
// MITMListen is the address (host:port) the cloudproxy MITM
|
||||
// listener binds on. Empty disables the MITM proxy entirely.
|
||||
// Use case: redacting PII from Claude Code / Codex CLI traffic
|
||||
@@ -84,18 +65,20 @@ type ApplicationConfig struct {
|
||||
// LocalAI exposes at /api/middleware/proxy-ca.crt.
|
||||
MITMListen string
|
||||
|
||||
// PIIDefaultDetectors lists token-classification (NER) detector model
|
||||
// names applied to any PII-enabled model that does not name its own
|
||||
// pii.detectors. This makes cloud-proxy / MITM redaction work out of the
|
||||
// box (those default to PII-enabled but carry no detector list) and lets
|
||||
// an operator set one detector for the whole instance. Set at runtime via
|
||||
// POST /api/settings; read live by Application.ResolvePIIPolicy.
|
||||
PIIDefaultDetectors []string
|
||||
|
||||
// MITMCADir holds the persisted MITM proxy CA cert and private
|
||||
// key. The CA is generated on first start; subsequent starts
|
||||
// reload it so clients keep trusting the same root. The key
|
||||
// file is mode 0600.
|
||||
MITMCADir string
|
||||
|
||||
// PIIPatternOverrides applies persisted per-id deltas (action,
|
||||
// disabled) to the live redactor at startup. Loaded from
|
||||
// runtime_settings.json and applied right after pii.NewRedactor.
|
||||
// nil/empty leaves the YAML defaults in place.
|
||||
PIIPatternOverrides map[string]PIIPatternRuntimeOverride
|
||||
|
||||
DisableWebUI bool
|
||||
OllamaAPIRootEndpoint bool
|
||||
EnforcePredownloadScans bool
|
||||
@@ -613,6 +596,7 @@ func WithJSONStringPreload(configFile string) AppOption {
|
||||
o.PreloadJSONModels = configFile
|
||||
}
|
||||
}
|
||||
|
||||
func WithConfigFile(configFile string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.ConfigFile = configFile
|
||||
@@ -701,21 +685,6 @@ func WithDisableStats(disable bool) AppOption {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPIIConfigPath points the routing PII filter at a YAML config
|
||||
// file. CLI: --pii-config.
|
||||
func WithPIIConfigPath(path string) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.PIIConfigPath = path
|
||||
}
|
||||
}
|
||||
|
||||
// WithDisablePII turns the regex PII filter off. CLI: --disable-pii.
|
||||
func WithDisablePII(disable bool) AppOption {
|
||||
return func(o *ApplicationConfig) {
|
||||
o.DisablePII = disable
|
||||
}
|
||||
}
|
||||
|
||||
// WithMITMListen sets the address the cloudproxy MITM listener
|
||||
// binds on. Empty = disabled. CLI: --mitm-listen.
|
||||
func WithMITMListen(addr string) AppOption {
|
||||
@@ -1137,6 +1106,8 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
|
||||
mitmListen := o.MITMListen
|
||||
|
||||
piiDefaultDetectors := append([]string(nil), o.PIIDefaultDetectors...)
|
||||
|
||||
return RuntimeSettings{
|
||||
WatchdogEnabled: &watchdogEnabled,
|
||||
WatchdogIdleEnabled: &watchdogIdle,
|
||||
@@ -1191,6 +1162,7 @@ func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings {
|
||||
LogoHorizontalFile: &logoHorizontalFile,
|
||||
FaviconFile: &faviconFile,
|
||||
MITMListen: &mitmListen,
|
||||
PIIDefaultDetectors: &piiDefaultDetectors,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1424,6 +1396,10 @@ func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (req
|
||||
o.MITMListen = *settings.MITMListen
|
||||
}
|
||||
|
||||
if settings.PIIDefaultDetectors != nil {
|
||||
o.PIIDefaultDetectors = append([]string(nil), (*settings.PIIDefaultDetectors)...)
|
||||
}
|
||||
|
||||
// Note: ApiKeys requires special handling (merging with startup keys) - handled in caller
|
||||
|
||||
return requireRestart
|
||||
|
||||
@@ -8,26 +8,27 @@ import (
|
||||
// Usecase name constants — the canonical string values used in gallery entries,
|
||||
// model configs (known_usecases), and UsecaseInfoMap keys.
|
||||
const (
|
||||
UsecaseChat = "chat"
|
||||
UsecaseCompletion = "completion"
|
||||
UsecaseEdit = "edit"
|
||||
UsecaseVision = "vision"
|
||||
UsecaseEmbeddings = "embeddings"
|
||||
UsecaseTokenize = "tokenize"
|
||||
UsecaseImage = "image"
|
||||
UsecaseVideo = "video"
|
||||
UsecaseTranscript = "transcript"
|
||||
UsecaseTTS = "tts"
|
||||
UsecaseSoundGeneration = "sound_generation"
|
||||
UsecaseRerank = "rerank"
|
||||
UsecaseDetection = "detection"
|
||||
UsecaseDepth = "depth"
|
||||
UsecaseVAD = "vad"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseFaceRecognition = "face_recognition"
|
||||
UsecaseSpeakerRecognition = "speaker_recognition"
|
||||
UsecaseChat = "chat"
|
||||
UsecaseCompletion = "completion"
|
||||
UsecaseEdit = "edit"
|
||||
UsecaseVision = "vision"
|
||||
UsecaseEmbeddings = "embeddings"
|
||||
UsecaseTokenize = "tokenize"
|
||||
UsecaseImage = "image"
|
||||
UsecaseVideo = "video"
|
||||
UsecaseTranscript = "transcript"
|
||||
UsecaseTTS = "tts"
|
||||
UsecaseSoundGeneration = "sound_generation"
|
||||
UsecaseRerank = "rerank"
|
||||
UsecaseDetection = "detection"
|
||||
UsecaseDepth = "depth"
|
||||
UsecaseVAD = "vad"
|
||||
UsecaseAudioTransform = "audio_transform"
|
||||
UsecaseDiarization = "diarization"
|
||||
UsecaseRealtimeAudio = "realtime_audio"
|
||||
UsecaseFaceRecognition = "face_recognition"
|
||||
UsecaseSpeakerRecognition = "speaker_recognition"
|
||||
UsecaseTokenClassify = "token_classify"
|
||||
)
|
||||
|
||||
// GRPCMethod identifies a Backend service RPC from backend.proto.
|
||||
@@ -56,6 +57,7 @@ const (
|
||||
MethodVoiceVerify GRPCMethod = "VoiceVerify"
|
||||
MethodVoiceEmbed GRPCMethod = "VoiceEmbed"
|
||||
MethodVoiceAnalyze GRPCMethod = "VoiceAnalyze"
|
||||
MethodTokenClassify GRPCMethod = "TokenClassify"
|
||||
)
|
||||
|
||||
// UsecaseInfo describes a single known_usecase value and how it maps
|
||||
@@ -178,6 +180,11 @@ var UsecaseInfoMap = map[string]UsecaseInfo{
|
||||
GRPCMethod: MethodVoiceVerify,
|
||||
Description: "Speaker recognition — verify identity, embed and analyze voice via VoiceVerify, VoiceEmbed and VoiceAnalyze RPCs.",
|
||||
},
|
||||
UsecaseTokenClassify: {
|
||||
Flag: FLAG_TOKEN_CLASSIFY,
|
||||
GRPCMethod: MethodTokenClassify,
|
||||
Description: "Per-token classification (NER) via the TokenClassify RPC — the PII detector tier. Declared explicitly via known_usecases; never auto-guessed, since the token-classification head is not useful as general generation or embeddings.",
|
||||
},
|
||||
}
|
||||
|
||||
// BackendCapability describes which gRPC methods and usecases a backend supports.
|
||||
@@ -214,6 +221,17 @@ var BackendCapabilities = map[string]BackendCapability{
|
||||
AcceptsImages: true, // requires mmproj
|
||||
Description: "llama.cpp GGUF models — LLM inference with optional vision via mmproj",
|
||||
},
|
||||
// privacy-filter is the standalone GGML engine (backend/cpp/privacy-filter,
|
||||
// wrapping privacy-filter.cpp) for the openai-privacy-filter PII/NER token
|
||||
// classifier — the dedicated TokenClassify path that replaces the
|
||||
// patched-llama.cpp route. Never auto-guessed; declared explicitly via
|
||||
// known_usecases: [token_classify].
|
||||
"privacy-filter": {
|
||||
GRPCMethods: []GRPCMethod{MethodTokenClassify},
|
||||
PossibleUsecases: []string{UsecaseTokenClassify},
|
||||
DefaultUsecases: []string{UsecaseTokenClassify},
|
||||
Description: "privacy-filter.cpp — standalone GGML backend for openai-privacy-filter PII/NER token classification",
|
||||
},
|
||||
"vllm": {
|
||||
GRPCMethods: []GRPCMethod{MethodPredict, MethodPredictStream, MethodEmbedding},
|
||||
PossibleUsecases: []string{UsecaseChat, UsecaseCompletion, UsecaseEmbeddings, UsecaseVision},
|
||||
|
||||
@@ -19,8 +19,19 @@ const (
|
||||
defaultNGPULayers = 99999999
|
||||
)
|
||||
|
||||
func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
// reservedNonChatModel reports whether the operator reserved this model for an
|
||||
// internal primitive — the router score classifier or the PII NER
|
||||
// token_classify tier. Such a model has no chat template and must not be
|
||||
// given the generative-chat defaults the GGUF importer otherwise applies
|
||||
// (FLAG_CHAT, jinja templating): surfacing it in chat pickers defeats the
|
||||
// reservation. Operators who do want a combined model declare both usecases
|
||||
// explicitly — the combination is valid.
|
||||
func reservedNonChatModel(cfg *ModelConfig) bool {
|
||||
return cfg.KnownUsecases != nil &&
|
||||
(*cfg.KnownUsecases&(FLAG_SCORE|FLAG_TOKEN_CLASSIFY)) != 0
|
||||
}
|
||||
|
||||
func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
if defaultCtx == 0 && cfg.ContextSize == nil {
|
||||
ctxSize := f.EstimateLLaMACppRun().ContextSize
|
||||
if ctxSize > 0 {
|
||||
@@ -77,11 +88,19 @@ func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) {
|
||||
cfg.Name = f.Metadata().Name
|
||||
}
|
||||
|
||||
// Instruct to use template from llama.cpp
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.FunctionsConfig.GrammarConfig.NoGrammar = true
|
||||
cfg.Options = append(cfg.Options, "use_jinja:true")
|
||||
cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT")
|
||||
// A model the operator reserved for an internal primitive (the router
|
||||
// score classifier, or the PII NER token_classify tier) is not a chat
|
||||
// model: it carries no chat template and must not be painted with the
|
||||
// generative-chat defaults — appending FLAG_CHAT here would fold chat
|
||||
// into KnownUsecases on the next sync and surface the model in every
|
||||
// chat picker. Respect the declaration.
|
||||
if !reservedNonChatModel(cfg) {
|
||||
// Instruct to use template from llama.cpp
|
||||
cfg.TemplateConfig.UseTokenizerTemplate = true
|
||||
cfg.FunctionsConfig.GrammarConfig.NoGrammar = true
|
||||
cfg.Options = append(cfg.Options, "use_jinja:true")
|
||||
cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT")
|
||||
}
|
||||
|
||||
// Apply per-model-family inference parameter defaults (temperature, top_p, etc.)
|
||||
ApplyInferenceDefaults(cfg, f.Metadata().Name)
|
||||
|
||||
41
core/config/meta/pattern_meta_test.go
Normal file
41
core/config/meta/pattern_meta_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package meta_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/config/meta"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestMeta(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "config/meta suite")
|
||||
}
|
||||
|
||||
var _ = Describe("pattern detector field metadata", func() {
|
||||
byPath := func() map[string]meta.FieldMeta {
|
||||
md := meta.BuildForTest(reflect.TypeOf(config.ModelConfig{}), meta.DefaultRegistry())
|
||||
out := make(map[string]meta.FieldMeta, len(md.Fields))
|
||||
for _, f := range md.Fields {
|
||||
out[f.Path] = f
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
It("renders builtins as a select with the catalogue as options", func() {
|
||||
f, ok := byPath()["pii_detection.builtins"]
|
||||
Expect(ok).To(BeTrue(), "pii_detection.builtins field should exist")
|
||||
Expect(f.Component).To(Equal("pii-builtins-select"))
|
||||
Expect(f.Options).NotTo(BeEmpty())
|
||||
})
|
||||
|
||||
It("renders custom patterns with the pattern-list editor", func() {
|
||||
f, ok := byPath()["pii_detection.patterns"]
|
||||
Expect(ok).To(BeTrue(), "pii_detection.patterns field should exist")
|
||||
Expect(f.Component).To(Equal("pii-pattern-list"))
|
||||
})
|
||||
})
|
||||
@@ -1,5 +1,19 @@
|
||||
package meta
|
||||
|
||||
import "github.com/mudler/LocalAI/core/services/routing/piipattern"
|
||||
|
||||
// builtinPatternOptions turns the piipattern built-in catalogue into select
|
||||
// options for the editor's built-in-patterns checklist, keeping the catalogue
|
||||
// the single source of truth.
|
||||
func builtinPatternOptions() []FieldOption {
|
||||
cat := piipattern.BuiltinCatalogue()
|
||||
out := make([]FieldOption, 0, len(cat))
|
||||
for _, b := range cat {
|
||||
out = append(out, FieldOption{Value: b.Name, Label: b.Name + " — " + b.Description})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DefaultRegistry returns enrichment overrides for the ~30 most commonly used
|
||||
// config fields. Fields not listed here still appear with auto-generated
|
||||
// labels and type-inferred components.
|
||||
@@ -504,12 +518,60 @@ func DefaultRegistry() map[string]FieldMetaOverride {
|
||||
Component: "toggle",
|
||||
Order: 200,
|
||||
},
|
||||
"pii.patterns": {
|
||||
"pii.detectors": {
|
||||
Section: "pii",
|
||||
Label: "PII Detector Models",
|
||||
Description: "Token-classification (NER) models that scan this model's requests for PII. The detection policy (which entities, what action, min score) lives on each detector model's own PII Detection block. Multiple detectors union their hits.",
|
||||
Component: "model-multi-select",
|
||||
AutocompleteProvider: "models:token_classify",
|
||||
Order: 201,
|
||||
},
|
||||
|
||||
// --- PII detection policy (on a token_classify detector model) ---
|
||||
"pii_detection.min_score": {
|
||||
Section: "pii",
|
||||
Label: "PII Pattern Overrides",
|
||||
Description: "Override the global default action for specific patterns on this model. Patterns not listed here inherit the global action (Settings → Middleware → Filtering).",
|
||||
Label: "Detector Min Score",
|
||||
Description: "When this model is used as a PII detector, drop detections scored below this confidence before they are acted on. 0 keeps every detection.",
|
||||
Component: "slider",
|
||||
Min: f64(0),
|
||||
Max: f64(1),
|
||||
Step: f64(0.01),
|
||||
Order: 210,
|
||||
},
|
||||
"pii_detection.default_action": {
|
||||
Section: "pii",
|
||||
Label: "Detector Default Action",
|
||||
Description: "Action applied to detected entity groups with no explicit per-entity override. Defaults to mask — the safe-by-default policy for a PII filter.",
|
||||
Component: "select",
|
||||
Options: []FieldOption{
|
||||
{Value: "mask", Label: "mask (redact the span)"},
|
||||
{Value: "block", Label: "block (reject the request)"},
|
||||
{Value: "allow", Label: "allow (detect & log only)"},
|
||||
},
|
||||
Default: "mask",
|
||||
Order: 211,
|
||||
},
|
||||
"pii_detection.entity_actions": {
|
||||
Section: "pii",
|
||||
Label: "Detector Entity Actions",
|
||||
Description: "Per-entity-group action policy for this detector model (e.g. PASSWORD → block, EMAIL → mask). Groups without an entry use the default action.",
|
||||
Component: "entity-action-list",
|
||||
Order: 212,
|
||||
},
|
||||
"pii_detection.builtins": {
|
||||
Section: "pii",
|
||||
Label: "Built-in Secret Patterns",
|
||||
Description: "Built-in regex patterns for common credentials (API keys, tokens, private keys). Turning any on makes this a pattern detector — it matches high-entropy secrets the NER tier can't, in-process with no model load.",
|
||||
Component: "pii-builtins-select",
|
||||
Options: builtinPatternOptions(),
|
||||
Order: 213,
|
||||
},
|
||||
"pii_detection.patterns": {
|
||||
Section: "pii",
|
||||
Label: "Custom Secret Patterns",
|
||||
Description: "Operator-defined patterns in a restricted regex subset (e.g. \"sk-prefix-\\w+\"). Each must contain a fixed literal anchor of ≥3 chars; open-ended shapes like emails are rejected (leave those to NER). Matches report under the pattern name as the entity group.",
|
||||
Component: "pii-pattern-list",
|
||||
Order: 201,
|
||||
Order: 214,
|
||||
},
|
||||
|
||||
// --- Cloud passthrough proxy ---
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"text/template"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piipattern"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/reasoning"
|
||||
@@ -23,7 +24,6 @@ const (
|
||||
|
||||
// @Description TTS configuration
|
||||
type TTSConfig struct {
|
||||
|
||||
// Voice wav path or id
|
||||
Voice string `yaml:"voice,omitempty" json:"voice,omitempty"`
|
||||
|
||||
@@ -116,13 +116,18 @@ type ModelConfig struct {
|
||||
Options []string `yaml:"options,omitempty" json:"options,omitempty"`
|
||||
Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`
|
||||
|
||||
MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
|
||||
Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
|
||||
PII PIIConfig `yaml:"pii,omitempty" json:"pii,omitempty"`
|
||||
Router RouterConfig `yaml:"router,omitempty" json:"router,omitempty"`
|
||||
Proxy ProxyConfig `yaml:"proxy,omitempty" json:"proxy,omitempty"`
|
||||
MITM MITMModelConfig `yaml:"mitm,omitempty" json:"mitm,omitempty"`
|
||||
Limits LimitsConfig `yaml:"limits,omitempty" json:"limits,omitempty"`
|
||||
MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
|
||||
Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
|
||||
PII PIIConfig `yaml:"pii,omitempty" json:"pii,omitempty"`
|
||||
// PIIDetection is the detection policy when THIS model is used as a
|
||||
// PII detector (a token_classify model named in another model's
|
||||
// pii.detectors). Ignored on models that aren't referenced as
|
||||
// detectors.
|
||||
PIIDetection PIIDetectionConfig `yaml:"pii_detection,omitempty" json:"pii_detection,omitempty"`
|
||||
Router RouterConfig `yaml:"router,omitempty" json:"router,omitempty"`
|
||||
Proxy ProxyConfig `yaml:"proxy,omitempty" json:"proxy,omitempty"`
|
||||
MITM MITMModelConfig `yaml:"mitm,omitempty" json:"mitm,omitempty"`
|
||||
Limits LimitsConfig `yaml:"limits,omitempty" json:"limits,omitempty"`
|
||||
}
|
||||
|
||||
// @Description Admission-control limits applied per request. The
|
||||
@@ -397,18 +402,54 @@ type PIIConfig struct {
|
||||
// the YAML key is distinguishable from explicit false.
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
|
||||
// Patterns lets a model upgrade or downgrade individual pattern
|
||||
// actions (mask | block | allow) relative to the global
|
||||
// defaults loaded from --pii-config / DefaultPatterns. Pattern IDs
|
||||
// not listed inherit the global action. The regex itself stays
|
||||
// global — only the action is settable per-model.
|
||||
Patterns []PIIPatternOverride `yaml:"patterns,omitempty" json:"patterns,omitempty"`
|
||||
// Detectors lists the token-classification (NER) models whose
|
||||
// detections drive PII redaction for this model. The detection policy
|
||||
// (min score, per-entity actions, default action) lives on each named
|
||||
// detector model's own pii_detection block, not here — a consuming
|
||||
// model just opts in by listing detectors. Multiple detectors union
|
||||
// their hits; overlapping spans resolve to the strongest action.
|
||||
Detectors []string `yaml:"detectors,omitempty" json:"detectors,omitempty"`
|
||||
}
|
||||
|
||||
// @Description Per-model action override for a single PII pattern.
|
||||
type PIIPatternOverride struct {
|
||||
ID string `yaml:"id" json:"id"`
|
||||
Action string `yaml:"action" json:"action"`
|
||||
// @Description Detection policy for a token-classification (NER) model
|
||||
// used as a PII detector. Lives on the detector model's own config so the
|
||||
// model is a self-describing policy unit: consuming models reference it by
|
||||
// name (via pii.detectors) and inherit this policy with no per-consumer
|
||||
// overrides.
|
||||
type PIIDetectionConfig struct {
|
||||
// MinScore drops detections the model scores below this confidence
|
||||
// before they are acted on. 0 keeps every detection.
|
||||
MinScore float32 `yaml:"min_score,omitempty" json:"min_score,omitempty"`
|
||||
// DefaultAction (mask | block | allow) applies to detected entity
|
||||
// groups with no explicit EntityActions entry. Empty defaults to
|
||||
// "mask" — the safe-by-default policy for a PII filter.
|
||||
DefaultAction string `yaml:"default_action,omitempty" json:"default_action,omitempty"`
|
||||
// EntityActions maps an entity group the model emits (e.g. "EMAIL",
|
||||
// "PASSWORD") to an action, overriding DefaultAction for that group.
|
||||
// This is where an operator says which PII to block vs mask vs
|
||||
// allow-log.
|
||||
EntityActions map[string]string `yaml:"entity_actions,omitempty" json:"entity_actions,omitempty"`
|
||||
|
||||
// Builtins names the built-in pattern groups this (pattern) detector
|
||||
// enables, e.g. "anthropic_api_key", "github_token". Pattern detectors
|
||||
// match high-entropy structured secrets the NER tier can't; see
|
||||
// core/services/routing/piipattern.
|
||||
Builtins []string `yaml:"builtins,omitempty" json:"builtins,omitempty"`
|
||||
// Patterns lists operator-defined secret patterns in the restricted-regex
|
||||
// subset (validated at load). Each match is reported under its Name as the
|
||||
// entity group, so EntityActions/DefaultAction apply by Name.
|
||||
Patterns []PIIPattern `yaml:"patterns,omitempty" json:"patterns,omitempty"`
|
||||
}
|
||||
|
||||
// PIIPattern is one operator-defined pattern on a pattern detector model. Name
|
||||
// is the entity group reported for matches (and the EntityActions key). Match
|
||||
// is the restricted-regex source. Action optionally overrides DefaultAction for
|
||||
// this pattern. MinLen drops matches shorter than N bytes (0 = no floor).
|
||||
type PIIPattern struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Match string `yaml:"match" json:"match"`
|
||||
Action string `yaml:"action,omitempty" json:"action,omitempty"`
|
||||
MinLen int `yaml:"min_len,omitempty" json:"min_len,omitempty"`
|
||||
}
|
||||
|
||||
// PIIIsEnabled returns the resolved PII state for this model. Single
|
||||
@@ -421,27 +462,71 @@ func (c *ModelConfig) PIIIsEnabled() bool {
|
||||
return c.Backend == "cloud-proxy"
|
||||
}
|
||||
|
||||
// PIIPatternOverrides returns the per-pattern action overrides as a map
|
||||
// keyed by pattern ID. The values are the raw action strings — the pii
|
||||
// package validates and converts them.
|
||||
//
|
||||
// Returned via the documented modelPIIConfig interface in
|
||||
// core/services/routing/pii/middleware.go without taking a config
|
||||
// dependency on this package.
|
||||
func (c *ModelConfig) PIIPatternOverrides() map[string]string {
|
||||
if len(c.PII.Patterns) == 0 {
|
||||
// PIIDetectors returns the names of the token-classification models that
|
||||
// drive PII redaction for this (consuming) model. Read via the
|
||||
// ModelPIIConfig interface in core/services/routing/pii/middleware.go.
|
||||
func (c *ModelConfig) PIIDetectors() []string {
|
||||
if len(c.PII.Detectors) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(c.PII.Patterns))
|
||||
for _, p := range c.PII.Patterns {
|
||||
if p.ID == "" {
|
||||
continue
|
||||
}
|
||||
out[p.ID] = p.Action
|
||||
out := make([]string, len(c.PII.Detectors))
|
||||
copy(out, c.PII.Detectors)
|
||||
return out
|
||||
}
|
||||
|
||||
// piiCoverableUsecases lists the model usecases whose serving API has a
|
||||
// request-side PII filter wired (a piiadapter + the pii middleware). It scopes
|
||||
// the Middleware admin list (PIIFilterApplies). Grow it as adapters are added
|
||||
// for new endpoints. cloud-proxy carries no usecase flag but is always covered
|
||||
// (via the MITM / proxy chat path), so PIIFilterApplies handles it separately.
|
||||
var piiCoverableUsecases = []ModelConfigUsecase{FLAG_CHAT, FLAG_COMPLETION, FLAG_EDIT, FLAG_EMBEDDINGS}
|
||||
|
||||
// PIIFilterApplies reports whether request-side PII filtering can apply to
|
||||
// this model at all — i.e. it is reachable through a text-accepting endpoint
|
||||
// that has a PII adapter wired. Used to scope the Middleware admin view so it
|
||||
// lists only models PII could protect, not every config (VAD, STT,
|
||||
// embedding-only, image, or the token_classify detector models themselves,
|
||||
// which are the filters rather than consumers). Detector/score models return
|
||||
// false naturally: HasUsecases short-circuits to false for any usecase a
|
||||
// declared score/token_classify model did not itself declare.
|
||||
func (c *ModelConfig) PIIFilterApplies() bool {
|
||||
if c.Backend == "cloud-proxy" {
|
||||
return true
|
||||
}
|
||||
return slices.ContainsFunc(piiCoverableUsecases, c.HasUsecases)
|
||||
}
|
||||
|
||||
// PIIDetectionMinScore returns the confidence floor this model applies
|
||||
// when used as a PII detector.
|
||||
func (c *ModelConfig) PIIDetectionMinScore() float32 { return c.PIIDetection.MinScore }
|
||||
|
||||
// PIIDetectionDefaultAction returns the raw default-action string applied
|
||||
// to detected entity groups without an explicit override. The pii package
|
||||
// validates it and applies the "mask" fallback.
|
||||
func (c *ModelConfig) PIIDetectionDefaultAction() string { return c.PIIDetection.DefaultAction }
|
||||
|
||||
// PIIDetectionEntityActions returns the per-entity-group action policy as
|
||||
// a fresh map of raw action strings (validated by the pii package).
|
||||
func (c *ModelConfig) PIIDetectionEntityActions() map[string]string {
|
||||
if len(c.PIIDetection.EntityActions) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]string, len(c.PIIDetection.EntityActions))
|
||||
for k, v := range c.PIIDetection.EntityActions {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// IsPatternDetector reports whether this detector model matches secrets with
|
||||
// regex patterns (built-in and/or operator-defined) rather than a neural NER
|
||||
// model. Such a model runs entirely in-process (no backend / GGUF / VRAM); the
|
||||
// PII resolver builds an in-process pattern matcher for it instead of loading a
|
||||
// gRPC token-classifier.
|
||||
func (c *ModelConfig) IsPatternDetector() bool {
|
||||
return len(c.PIIDetection.Builtins) > 0 || len(c.PIIDetection.Patterns) > 0
|
||||
}
|
||||
|
||||
// @Description MCP configuration
|
||||
type MCPConfig struct {
|
||||
Servers string `yaml:"remote,omitempty" json:"remote,omitempty"`
|
||||
@@ -485,8 +570,10 @@ func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCP
|
||||
type MCPGenericConfig[T any] struct {
|
||||
Servers T `yaml:"mcpServers,omitempty" json:"mcpServers,omitempty"`
|
||||
}
|
||||
type MCPRemoteServers map[string]MCPRemoteServer
|
||||
type MCPSTDIOServers map[string]MCPSTDIOServer
|
||||
type (
|
||||
MCPRemoteServers map[string]MCPRemoteServer
|
||||
MCPSTDIOServers map[string]MCPSTDIOServer
|
||||
)
|
||||
|
||||
// @Description MCP remote server configuration
|
||||
type MCPRemoteServer struct {
|
||||
@@ -1221,6 +1308,8 @@ func (c *ModelConfig) Validate() (bool, error) {
|
||||
// llama_context against concurrent generation/embedding traffic
|
||||
// (see backend/cpp/llama-cpp/grpc-server.cpp on Score). Reject the
|
||||
// combination here so operators are forced to split the model.
|
||||
// (token_classify is unaffected — it runs on the standalone
|
||||
// privacy-filter backend, not llama-cpp.)
|
||||
const scoreConflicts = FLAG_CHAT | FLAG_COMPLETION | FLAG_EMBEDDINGS
|
||||
if (c.Backend == "llama-cpp" || c.Backend == "llama") &&
|
||||
c.HasUsecases(FLAG_SCORE) && c.KnownUsecases != nil &&
|
||||
@@ -1230,6 +1319,26 @@ func (c *ModelConfig) Validate() (bool, error) {
|
||||
"with chat/completion/embeddings — split into separate model configs")
|
||||
}
|
||||
|
||||
// Pattern detector: validate built-in names and that each operator-defined
|
||||
// pattern is a well-formed, anchored, bounded restricted-regex. Reject at
|
||||
// load so a bad pattern surfaces as a clear config error rather than a
|
||||
// silent no-op (or a fail-closed block) at request time.
|
||||
if c.IsPatternDetector() {
|
||||
for _, name := range c.PIIDetection.Builtins {
|
||||
if _, ok := piipattern.LookupBuiltin(name); !ok {
|
||||
return false, fmt.Errorf("pii_detection: unknown built-in pattern %q", name)
|
||||
}
|
||||
}
|
||||
for _, p := range c.PIIDetection.Patterns {
|
||||
if p.Name == "" {
|
||||
return false, fmt.Errorf("pii_detection: pattern is missing a name")
|
||||
}
|
||||
if err := piipattern.ValidatePattern(p.Match); err != nil {
|
||||
return false, fmt.Errorf("pii_detection: pattern %q: %w", p.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// router.score_normalization is consumed lazily by the score
|
||||
// classifier at first-request time; without load-time validation
|
||||
// a typo wouldn't surface until the first router request panicked
|
||||
@@ -1336,16 +1445,24 @@ const (
|
||||
// Marks a model as wired for the Score gRPC primitive (joint
|
||||
// log-prob of candidate continuations under a shared prompt). Must
|
||||
// be declared explicitly via `known_usecases: [score]` — there's
|
||||
// no heuristic for it. On the llama-cpp backend, Score bypasses
|
||||
// the slot loop and races the llama_context, so Validate() refuses
|
||||
// to load a llama-cpp config that combines FLAG_SCORE with
|
||||
// chat/completion/embeddings.
|
||||
// no heuristic for it. On llama-cpp, Score bypasses the slot loop
|
||||
// (direct llama_decode), so combining score with
|
||||
// chat/completion/embeddings in one config is rejected at validation.
|
||||
FLAG_SCORE ModelConfigUsecase = 0b10000000000000000000
|
||||
|
||||
// Marks a model as wired for the Depth gRPC primitive (per-pixel
|
||||
// metric depth + camera pose + 3D point cloud via Depth Anything 3).
|
||||
FLAG_DEPTH ModelConfigUsecase = 0b100000000000000000000
|
||||
|
||||
// Marks a model as wired for the TokenClassify gRPC primitive (the
|
||||
// openai-privacy-filter PII NER tier — per-token BIOES classification).
|
||||
// Like FLAG_SCORE it must be declared explicitly via
|
||||
// `known_usecases: [token_classify]`; there's no heuristic. Requires
|
||||
// TOKEN_CLS pooling, which is loaded via the embeddings flag. On
|
||||
// llama-cpp the classification windows ride the embedding task queue,
|
||||
// so it may combine freely with other usecases.
|
||||
FLAG_TOKEN_CLASSIFY ModelConfigUsecase = 0b1000000000000000000000
|
||||
|
||||
// Common Subsets
|
||||
FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
|
||||
)
|
||||
@@ -1404,6 +1521,7 @@ func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
|
||||
"FLAG_REALTIME_AUDIO": FLAG_REALTIME_AUDIO,
|
||||
"FLAG_SCORE": FLAG_SCORE,
|
||||
"FLAG_DEPTH": FLAG_DEPTH,
|
||||
"FLAG_TOKEN_CLASSIFY": FLAG_TOKEN_CLASSIFY,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1431,19 +1549,20 @@ func GetUsecasesFromYAML(input []string) *ModelConfigUsecase {
|
||||
// HasUsecases examines a ModelConfig and determines which endpoints have a chance of success.
|
||||
//
|
||||
// Declared known_usecases are normally additive — the guessing heuristic
|
||||
// still adds whatever it can infer from backend/templates. The one
|
||||
// exception is FLAG_SCORE: when the operator declared score, they
|
||||
// reserved the model for the router classifier. Letting GuessUsecases
|
||||
// paint chat/completion on top would surface it in chat pickers it was
|
||||
// deliberately kept out of, and (on llama-cpp) reintroduce the slot
|
||||
// contention the score/chat conflict check exists to prevent. So a
|
||||
// declared score list is authoritative.
|
||||
// still adds whatever it can infer from backend/templates. The exceptions
|
||||
// are FLAG_SCORE and FLAG_TOKEN_CLASSIFY: when the operator declared
|
||||
// either, they reserved the model for an internal direct-decode primitive
|
||||
// (the router classifier, or the PII NER tier). Letting GuessUsecases
|
||||
// paint chat/completion/embeddings on top would surface it in pickers it
|
||||
// was deliberately kept out of, and (on llama-cpp) reintroduce the slot
|
||||
// contention the conflict check exists to prevent. So a declared score or
|
||||
// token_classify list is authoritative.
|
||||
func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool {
|
||||
if c.KnownUsecases != nil {
|
||||
if (u & *c.KnownUsecases) == u {
|
||||
return true
|
||||
}
|
||||
if (*c.KnownUsecases & FLAG_SCORE) == FLAG_SCORE {
|
||||
if (*c.KnownUsecases & (FLAG_SCORE | FLAG_TOKEN_CLASSIFY)) != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -1623,6 +1742,15 @@ func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if (u & FLAG_TOKEN_CLASSIFY) == FLAG_TOKEN_CLASSIFY {
|
||||
// No heuristic: token-classification intent is a deliberate
|
||||
// operator choice (it reserves the model from generation traffic
|
||||
// on llama-cpp, and the model's TOKEN_CLS head isn't useful as
|
||||
// general embeddings), so HasUsecases(FLAG_TOKEN_CLASSIFY) is true
|
||||
// only when KnownUsecases declares it explicitly.
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
var _ = Describe("Test cases for config related functions", func() {
|
||||
@@ -72,9 +73,10 @@ parameters:
|
||||
Expect(valid).To(BeTrue())
|
||||
|
||||
// llama-cpp configs can't mix the score usecase with
|
||||
// chat/completion/embeddings — Score bypasses the slot
|
||||
// loop and would race the llama_context. The check fires
|
||||
// at load and save time; here we exercise it directly.
|
||||
// chat/completion/embeddings — Score bypasses the slot loop
|
||||
// and would race the llama_context. (token_classify is exempt:
|
||||
// it runs on the privacy-filter backend, not llama-cpp, so the
|
||||
// token_classify combinations below stay valid.)
|
||||
scoreFlag := FLAG_SCORE | FLAG_CHAT
|
||||
conflicting := ModelConfig{
|
||||
Name: "router-but-also-chat",
|
||||
@@ -96,15 +98,23 @@ parameters:
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// The constraint is llama-cpp-specific; other backends
|
||||
// may safely combine.
|
||||
scoreAndChat := FLAG_SCORE | FLAG_CHAT
|
||||
otherBackend := ModelConfig{
|
||||
Name: "vllm-router-and-chat",
|
||||
Backend: "vllm",
|
||||
KnownUsecases: &scoreAndChat,
|
||||
tcAndChat := FLAG_TOKEN_CLASSIFY | FLAG_CHAT
|
||||
tcCombined := ModelConfig{
|
||||
Name: "ner-and-chat",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &tcAndChat,
|
||||
}
|
||||
valid, err = otherBackend.Validate()
|
||||
valid, err = tcCombined.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
tcAndEmbeddings := FLAG_TOKEN_CLASSIFY | FLAG_EMBEDDINGS
|
||||
tcWithEmbeddings := ModelConfig{
|
||||
Name: "pii-ner",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &tcAndEmbeddings,
|
||||
}
|
||||
valid, err = tcWithEmbeddings.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
@@ -228,7 +238,6 @@ parameters:
|
||||
})
|
||||
})
|
||||
It("Properly handles backend usecase matching", func() {
|
||||
|
||||
a := ModelConfig{
|
||||
Name: "a",
|
||||
}
|
||||
@@ -336,17 +345,17 @@ parameters:
|
||||
// Declared `known_usecases: [score]` is authoritative — the
|
||||
// guessing heuristic must NOT add chat on top, even though the
|
||||
// inherited chatml template would otherwise satisfy the chat
|
||||
// heuristic. Score means "this model is reserved for the
|
||||
// router classifier"; surfacing it as a chat model defeats the
|
||||
// reservation and reintroduces the slot contention the load-time
|
||||
// score/chat conflict check exists to prevent.
|
||||
// heuristic. A score-only declaration means "this model is
|
||||
// reserved for the router classifier"; surfacing it as a chat
|
||||
// model defeats the reservation. (Operators who do want both
|
||||
// may declare both — the combination is supported.)
|
||||
scoreReserved := FLAG_SCORE
|
||||
j := ModelConfig{
|
||||
Name: "arch-router",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &scoreReserved,
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "inherited from chatml",
|
||||
Chat: "inherited from chatml",
|
||||
ChatMessage: "inherited from chatml",
|
||||
Completion: "inherited from chatml",
|
||||
},
|
||||
@@ -355,6 +364,27 @@ parameters:
|
||||
Expect(j.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
Expect(j.HasUsecases(FLAG_COMPLETION)).To(BeFalse())
|
||||
Expect(j.HasUsecases(FLAG_EMBEDDINGS)).To(BeFalse())
|
||||
|
||||
// Declared `known_usecases: [token_classify]` is likewise
|
||||
// authoritative — a PII NER model is reserved for the redactor's
|
||||
// NER tier and must not surface as chat or as a general embeddings
|
||||
// model, even though it loads with embeddings enabled (its
|
||||
// TOKEN_CLS head produces BIOES logits, not reusable embeddings).
|
||||
tcReserved := FLAG_TOKEN_CLASSIFY
|
||||
embTrue := true
|
||||
k := ModelConfig{
|
||||
Name: "privacy-filter",
|
||||
Backend: "llama-cpp",
|
||||
KnownUsecases: &tcReserved,
|
||||
Embeddings: &embTrue,
|
||||
TemplateConfig: TemplateConfig{
|
||||
Chat: "inherited from chatml",
|
||||
ChatMessage: "inherited from chatml",
|
||||
},
|
||||
}
|
||||
Expect(k.HasUsecases(FLAG_TOKEN_CLASSIFY)).To(BeTrue())
|
||||
Expect(k.HasUsecases(FLAG_CHAT)).To(BeFalse())
|
||||
Expect(k.HasUsecases(FLAG_EMBEDDINGS)).To(BeFalse())
|
||||
})
|
||||
It("Test Validate with invalid MCP config", func() {
|
||||
tmp, err := os.CreateTemp("", "config.yaml")
|
||||
@@ -598,3 +628,162 @@ concurrency_groups:
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("PII config accessors", func() {
|
||||
It("PIIDetectors returns a fresh copy of the consumer's detector list", func() {
|
||||
cfg := &ModelConfig{PII: PIIConfig{Detectors: []string{"a", "b"}}}
|
||||
got := cfg.PIIDetectors()
|
||||
Expect(got).To(Equal([]string{"a", "b"}))
|
||||
got[0] = "mutated"
|
||||
Expect(cfg.PII.Detectors[0]).To(Equal("a"), "accessor must not alias the underlying slice")
|
||||
})
|
||||
|
||||
It("PIIDetectors is nil when none are configured", func() {
|
||||
Expect((&ModelConfig{}).PIIDetectors()).To(BeNil())
|
||||
})
|
||||
|
||||
It("exposes the detector model's pii_detection policy", func() {
|
||||
cfg := &ModelConfig{PIIDetection: PIIDetectionConfig{
|
||||
MinScore: 0.5,
|
||||
DefaultAction: "mask",
|
||||
EntityActions: map[string]string{"PASSWORD": "block", "EMAIL": "mask"},
|
||||
}}
|
||||
Expect(cfg.PIIDetectionMinScore()).To(BeNumerically("~", 0.5, 1e-6))
|
||||
Expect(cfg.PIIDetectionDefaultAction()).To(Equal("mask"))
|
||||
ea := cfg.PIIDetectionEntityActions()
|
||||
Expect(ea).To(HaveKeyWithValue("PASSWORD", "block"))
|
||||
ea["PASSWORD"] = "mutated"
|
||||
Expect(cfg.PIIDetection.EntityActions["PASSWORD"]).To(Equal("block"), "accessor must return a fresh map")
|
||||
})
|
||||
|
||||
It("unmarshals pii.detectors and pii_detection from YAML", func() {
|
||||
var cfg ModelConfig
|
||||
raw := []byte("name: consumer\npii:\n enabled: true\n detectors: [pf]\npii_detection:\n min_score: 0.4\n default_action: mask\n entity_actions:\n PASSWORD: block\n")
|
||||
Expect(yaml.Unmarshal(raw, &cfg)).To(Succeed())
|
||||
Expect(cfg.PIIDetectors()).To(Equal([]string{"pf"}))
|
||||
Expect(cfg.PIIDetectionDefaultAction()).To(Equal("mask"))
|
||||
Expect(cfg.PIIDetectionEntityActions()).To(HaveKeyWithValue("PASSWORD", "block"))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("GGUF importer chat-default guard (reservedNonChatModel)", func() {
|
||||
mk := func(flags ModelConfigUsecase) *ModelConfig {
|
||||
return &ModelConfig{Backend: "llama-cpp", KnownUsecases: &flags}
|
||||
}
|
||||
|
||||
It("treats declared score / token_classify models as reserved (no chat defaults)", func() {
|
||||
Expect(reservedNonChatModel(mk(FLAG_SCORE))).To(BeTrue())
|
||||
Expect(reservedNonChatModel(mk(FLAG_TOKEN_CLASSIFY))).To(BeTrue())
|
||||
// embeddings declared alongside token_classify (the PII NER shape) is
|
||||
// still reserved.
|
||||
Expect(reservedNonChatModel(mk(FLAG_TOKEN_CLASSIFY | FLAG_EMBEDDINGS))).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does not reserve ordinary or undeclared models", func() {
|
||||
Expect(reservedNonChatModel(mk(FLAG_CHAT))).To(BeFalse())
|
||||
Expect(reservedNonChatModel(mk(FLAG_EMBEDDINGS))).To(BeFalse())
|
||||
Expect(reservedNonChatModel(&ModelConfig{Backend: "llama-cpp"})).To(BeFalse())
|
||||
})
|
||||
|
||||
It("keeps a token_classify GGUF config valid by withholding FLAG_CHAT", func() {
|
||||
// The privacy-filter import shape: the GGUF importer appends FLAG_CHAT
|
||||
// to a templateless model, which the next sync folds into
|
||||
// KnownUsecases. token_classify+chat is a VALID combination
|
||||
// (token_classify runs on the privacy-filter backend, not llama-cpp,
|
||||
// so the score/chat conflict check does not apply to it), but the
|
||||
// importer must still not paint a declared-reserved model as chat
|
||||
// — that would surface it in every chat picker.
|
||||
reserved := []string{"token_classify"}
|
||||
withChat := append(append([]string{}, reserved...), "FLAG_CHAT")
|
||||
|
||||
// What the importer would produce WITHOUT the guard: valid (the
|
||||
// score/chat conflict check is score-specific), just undesirable
|
||||
// defaults.
|
||||
combined := &ModelConfig{Backend: "llama-cpp", KnownUsecaseStrings: withChat}
|
||||
combined.syncKnownUsecasesFromString()
|
||||
valid, err := combined.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
// With the guard (FLAG_CHAT withheld): the declaration survives and the
|
||||
// config validates.
|
||||
good := &ModelConfig{Backend: "llama-cpp", KnownUsecaseStrings: reserved}
|
||||
good.syncKnownUsecasesFromString()
|
||||
Expect(reservedNonChatModel(good)).To(BeTrue())
|
||||
valid, err = good.Validate()
|
||||
Expect(valid).To(BeTrue())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(good.HasUsecases(FLAG_TOKEN_CLASSIFY)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("PIIFilterApplies (Middleware admin list scoping)", func() {
|
||||
withUsecases := func(backend string, flags ModelConfigUsecase) *ModelConfig {
|
||||
return &ModelConfig{Name: "m", Backend: backend, KnownUsecases: &flags}
|
||||
}
|
||||
|
||||
It("includes chat-capable models and cloud-proxy models", func() {
|
||||
Expect(withUsecases("llama-cpp", FLAG_CHAT).PIIFilterApplies()).To(BeTrue())
|
||||
// cloud-proxy is always covered (MITM / proxy chat path), regardless
|
||||
// of declared usecases.
|
||||
Expect((&ModelConfig{Name: "claude", Backend: "cloud-proxy"}).PIIFilterApplies()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("excludes the detector and score models themselves", func() {
|
||||
// token_classify detectors are the filters, not consumers; score
|
||||
// classifiers are internal primitives. Both short-circuit
|
||||
// HasUsecases(FLAG_CHAT) to false.
|
||||
Expect(withUsecases("llama-cpp", FLAG_TOKEN_CLASSIFY).PIIFilterApplies()).To(BeFalse())
|
||||
Expect(withUsecases("llama-cpp", FLAG_SCORE).PIIFilterApplies()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("includes embedding and completion models (their request text is filtered)", func() {
|
||||
// Phase 4 wired PII onto /v1/embeddings, /v1/completions and /v1/edits,
|
||||
// so those usecases are now coverable.
|
||||
emb := withUsecases("llama-cpp", FLAG_EMBEDDINGS)
|
||||
t := true
|
||||
emb.Embeddings = &t
|
||||
Expect(emb.PIIFilterApplies()).To(BeTrue())
|
||||
Expect(withUsecases("llama-cpp", FLAG_COMPLETION).PIIFilterApplies()).To(BeTrue())
|
||||
})
|
||||
|
||||
It("excludes models with no text-accepting, PII-covered endpoint", func() {
|
||||
// VAD / audio-in models carry no coverable usecase.
|
||||
Expect((&ModelConfig{Name: "vad", Backend: "silero-vad"}).PIIFilterApplies()).To(BeFalse())
|
||||
Expect(withUsecases("whisper", FLAG_TRANSCRIPT).PIIFilterApplies()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("pattern detector config", func() {
|
||||
patternCfg := func() *ModelConfig {
|
||||
c := &ModelConfig{Name: "secret-filter", Backend: "pattern"}
|
||||
c.PIIDetection.Builtins = []string{"anthropic_api_key"}
|
||||
c.PIIDetection.Patterns = []PIIPattern{{Name: "INTERNAL", Match: `tok-[A-Za-z0-9]{20,}`}}
|
||||
return c
|
||||
}
|
||||
|
||||
It("IsPatternDetector keys off builtins/patterns", func() {
|
||||
Expect(patternCfg().IsPatternDetector()).To(BeTrue())
|
||||
Expect((&ModelConfig{Name: "ner", Backend: "llama-cpp"}).IsPatternDetector()).To(BeFalse())
|
||||
})
|
||||
|
||||
It("Validate accepts a well-formed pattern detector (no model file needed)", func() {
|
||||
ok, err := patternCfg().Validate()
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(ok).To(BeTrue())
|
||||
})
|
||||
|
||||
It("Validate rejects an unknown built-in", func() {
|
||||
c := &ModelConfig{Name: "x", Backend: "pattern"}
|
||||
c.PIIDetection.Builtins = []string{"does_not_exist"}
|
||||
_, err := c.Validate()
|
||||
Expect(err).To(MatchError(ContainSubstring("unknown built-in")))
|
||||
})
|
||||
|
||||
It("Validate rejects an unanchored custom pattern", func() {
|
||||
c := &ModelConfig{Name: "x", Backend: "pattern"}
|
||||
c.PIIDetection.Patterns = []PIIPattern{{Name: "EMAILish", Match: `[\w.]+@[\w.]+\.\w+`}}
|
||||
_, err := c.Validate()
|
||||
Expect(err).To(MatchError(ContainSubstring("pattern \"EMAILish\"")))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -18,8 +18,8 @@ type RuntimeSettings struct {
|
||||
WatchdogInterval *string `json:"watchdog_interval,omitempty"` // Interval between watchdog checks (e.g., 2s, 30s)
|
||||
|
||||
// Backend management
|
||||
SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead
|
||||
MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode)
|
||||
AutoUpgradeBackends *bool `json:"auto_upgrade_backends,omitempty"` // Automatically upgrade backends when new versions are detected
|
||||
PreferDevelopmentBackends *bool `json:"prefer_development_backends,omitempty"` // Prefer development backend versions by default in UI
|
||||
// Memory Reclaimer settings (works with GPU if available, otherwise RAM)
|
||||
@@ -97,19 +97,9 @@ type RuntimeSettings struct {
|
||||
// trusted clients.
|
||||
MITMListen *string `json:"mitm_listen,omitempty"`
|
||||
|
||||
// PII pattern overrides — keyed by pattern id, applied to the live
|
||||
// redactor at startup and persisted by POST /api/pii/patterns/persist.
|
||||
// Distinguishes from --pii-config (which replaces the entire
|
||||
// pattern set) by only carrying the per-id action/enabled deltas
|
||||
// against the global default catalog.
|
||||
PIIPatternOverrides *map[string]PIIPatternRuntimeOverride `json:"pii_pattern_overrides,omitempty"`
|
||||
}
|
||||
|
||||
// PIIPatternRuntimeOverride captures the persistable deltas an admin
|
||||
// has applied to a single global PII pattern. Both fields are pointers
|
||||
// so an override that only flips Disabled doesn't have to also restate
|
||||
// Action (and vice versa).
|
||||
type PIIPatternRuntimeOverride struct {
|
||||
Action *string `json:"action,omitempty"`
|
||||
Disabled *bool `json:"disabled,omitempty"`
|
||||
// PIIDefaultDetectors are the token-classification detector models applied
|
||||
// to any PII-enabled model that names no detectors of its own (so
|
||||
// cloud-proxy/MITM redaction works without per-model config). No omitempty:
|
||||
// an empty array must round-trip so the operator can clear it from the UI.
|
||||
PIIDefaultDetectors *[]string `json:"pii_default_detectors"`
|
||||
}
|
||||
|
||||
@@ -175,6 +175,11 @@ var defaultImporters = []Importer{
|
||||
// importer. Matches only the antirez/deepseek-v4-gguf repo + filename
|
||||
// pattern, so false-positives against arbitrary GGUFs are impossible.
|
||||
&DS4Importer{},
|
||||
// PrivacyFilterImporter must precede LlamaCPPImporter too — the OpenMed
|
||||
// privacy-filter GGUFs would otherwise be claimed by the generic .gguf
|
||||
// importer. Matches only .gguf names carrying the "privacy-filter" token,
|
||||
// so arbitrary GGUFs are never claimed.
|
||||
&PrivacyFilterImporter{},
|
||||
&LlamaCPPImporter{},
|
||||
&MLXImporter{},
|
||||
&VLLMImporter{},
|
||||
|
||||
202
core/gallery/importers/privacy-filter.go
Normal file
202
core/gallery/importers/privacy-filter.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package importers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/pkg/downloader"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
"go.yaml.in/yaml/v2"
|
||||
)
|
||||
|
||||
var _ Importer = &PrivacyFilterImporter{}
|
||||
|
||||
// PrivacyFilterImporter recognises the OpenMed privacy-filter PII/NER model
|
||||
// family, served by the standalone privacy-filter.cpp ggml engine (the
|
||||
// openai-privacy-filter architecture). Detection is deliberately narrow: the
|
||||
// engine can only run a privacy-filter GGUF, so we match a .gguf whose name
|
||||
// carries the "privacy-filter" token (e.g. privacy-filter-multilingual-f16.gguf)
|
||||
// or an HF repo that ships one. That keeps us from claiming arbitrary
|
||||
// llama-style GGUFs (the importer is registered before llama-cpp) and from
|
||||
// claiming the upstream OpenMed/privacy-filter-* safetensors repos, which carry
|
||||
// no runnable GGUF. preferences.backend="privacy-filter" forces it regardless.
|
||||
type PrivacyFilterImporter struct{}
|
||||
|
||||
func (i *PrivacyFilterImporter) Name() string { return "privacy-filter" }
|
||||
|
||||
// Modality is "text": the filter operates in the text domain and there is no
|
||||
// dedicated token-classification chip in the import UI, so it groups with the
|
||||
// other text-domain backends (matching how ds4 — another single-family text
|
||||
// GGUF — is classified).
|
||||
func (i *PrivacyFilterImporter) Modality() string { return "text" }
|
||||
func (i *PrivacyFilterImporter) AutoDetects() bool { return true }
|
||||
|
||||
func (i *PrivacyFilterImporter) Match(details Details) bool {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if b, ok := preferencesMap["backend"].(string); ok && b == "privacy-filter" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Direct URL or path to a privacy-filter GGUF.
|
||||
if isPrivacyFilterGGUF(filepath.Base(details.URI)) {
|
||||
return true
|
||||
}
|
||||
|
||||
// HF repo shipping at least one privacy-filter GGUF.
|
||||
if details.HuggingFace != nil {
|
||||
for _, f := range details.HuggingFace.Files {
|
||||
if isPrivacyFilterGGUF(filepath.Base(f.Path)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: hfapi recursion bug may leave HuggingFace nil — match a repo
|
||||
// that names itself as the privacy-filter GGUF distribution (both tokens
|
||||
// present), e.g. LocalAI-io/privacy-filter-multilingual-GGUF. Requiring
|
||||
// "gguf" keeps the safetensors-only source repo out.
|
||||
if _, repo, ok := HFOwnerRepoFromURI(details.URI); ok {
|
||||
lower := strings.ToLower(repo)
|
||||
if privacyFilterName(lower) && strings.Contains(lower, "gguf") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (i *PrivacyFilterImporter) Import(details Details) (gallery.ModelConfig, error) {
|
||||
preferences, err := details.Preferences.MarshalJSON()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
preferencesMap := make(map[string]any)
|
||||
if len(preferences) > 0 {
|
||||
if err := json.Unmarshal(preferences, &preferencesMap); err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
}
|
||||
|
||||
name, ok := preferencesMap["name"].(string)
|
||||
if !ok {
|
||||
name = filepath.Base(details.URI)
|
||||
}
|
||||
|
||||
description, ok := preferencesMap["description"].(string)
|
||||
if !ok {
|
||||
description = "Imported from " + details.URI
|
||||
}
|
||||
|
||||
// The token classifier's accuracy is parity-sensitive, so prefer the
|
||||
// highest-precision weights first (f16 is what the gallery ships today),
|
||||
// then fall back down the quant ladder; the last file wins if none match.
|
||||
preferredQuants, _ := preferencesMap["quantizations"].(string)
|
||||
quants := []string{"f16", "q8_0", "q6_k", "q5_k", "q4_k"}
|
||||
if preferredQuants != "" {
|
||||
quants = strings.Split(preferredQuants, ",")
|
||||
}
|
||||
|
||||
cfg := gallery.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
}
|
||||
|
||||
trueV := true
|
||||
modelConfig := config.ModelConfig{
|
||||
Name: name,
|
||||
Description: description,
|
||||
Backend: "privacy-filter",
|
||||
// embeddings:true mirrors the gallery entry — the privacy-filter
|
||||
// backend loads in embedding mode to expose per-token logits.
|
||||
Embeddings: &trueV,
|
||||
// token_classify reserves the model for the PII NER tier; another
|
||||
// model opts into redaction by listing this one under pii.detectors.
|
||||
KnownUsecaseStrings: []string{"token_classify"},
|
||||
}
|
||||
|
||||
uri := downloader.URI(details.URI)
|
||||
directGGUF := isPrivacyFilterGGUF(filepath.Base(details.URI))
|
||||
switch {
|
||||
case uri.LooksLikeURL() && directGGUF:
|
||||
// Direct file URL (e.g. .../resolve/main/privacy-filter-multilingual-f16.gguf).
|
||||
// The exact file is known, no quant pick.
|
||||
fileName, err := uri.FilenameFromUrl()
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
target := filepath.Join("privacy-filter", "models", name, fileName)
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: details.URI,
|
||||
Filename: target,
|
||||
})
|
||||
modelConfig.PredictionOptions = schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: target},
|
||||
}
|
||||
case details.HuggingFace != nil:
|
||||
// HF repo: collect every privacy-filter GGUF, pick the preferred quant,
|
||||
// and nest under privacy-filter/models/<name>/ so a multi-quant repo
|
||||
// doesn't collide on disk.
|
||||
var ggufFiles []hfapi.ModelFile
|
||||
for _, f := range details.HuggingFace.Files {
|
||||
if isPrivacyFilterGGUF(filepath.Base(f.Path)) {
|
||||
ggufFiles = append(ggufFiles, f)
|
||||
}
|
||||
}
|
||||
if chosen, ok := pickPreferredGGMLFile(ggufFiles, quants); ok {
|
||||
target := filepath.Join("privacy-filter", "models", name, filepath.Base(chosen.Path))
|
||||
cfg.Files = append(cfg.Files, gallery.File{
|
||||
URI: chosen.URL,
|
||||
Filename: target,
|
||||
SHA256: chosen.SHA256,
|
||||
})
|
||||
modelConfig.PredictionOptions = schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: target},
|
||||
}
|
||||
}
|
||||
default:
|
||||
// Bare URI with no HF metadata (pref-only path): point at the basename
|
||||
// so users can tweak the YAML after import.
|
||||
modelConfig.PredictionOptions = schema.PredictionOptions{
|
||||
BasicModelRequest: schema.BasicModelRequest{Model: filepath.Base(details.URI)},
|
||||
}
|
||||
}
|
||||
|
||||
data, err := yaml.Marshal(modelConfig)
|
||||
if err != nil {
|
||||
return gallery.ModelConfig{}, err
|
||||
}
|
||||
cfg.ConfigFile = string(data)
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// privacyFilterName reports whether a lower-cased string carries the
|
||||
// privacy-filter token in either separator form.
|
||||
func privacyFilterName(lower string) bool {
|
||||
return strings.Contains(lower, "privacy-filter") || strings.Contains(lower, "privacy_filter")
|
||||
}
|
||||
|
||||
// isPrivacyFilterGGUF reports whether name is a privacy-filter GGUF: a .gguf
|
||||
// file whose name carries the privacy-filter token. The .gguf check is
|
||||
// case-insensitive.
|
||||
func isPrivacyFilterGGUF(name string) bool {
|
||||
lower := strings.ToLower(name)
|
||||
if !strings.HasSuffix(lower, ".gguf") {
|
||||
return false
|
||||
}
|
||||
return privacyFilterName(lower)
|
||||
}
|
||||
104
core/gallery/importers/privacy-filter_test.go
Normal file
104
core/gallery/importers/privacy-filter_test.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package importers_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mudler/LocalAI/core/gallery/importers"
|
||||
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// privacyFilterDetails builds Details carrying a synthetic HF file list so
|
||||
// detection can be exercised without hitting the network.
|
||||
func privacyFilterDetails(uri string, prefs string, files ...hfapi.ModelFile) importers.Details {
|
||||
return importers.Details{
|
||||
URI: uri,
|
||||
Preferences: json.RawMessage(prefs),
|
||||
HuggingFace: &hfapi.ModelDetails{Files: files},
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("PrivacyFilterImporter", func() {
|
||||
imp := &importers.PrivacyFilterImporter{}
|
||||
|
||||
Context("Importer interface metadata", func() {
|
||||
It("exposes name/modality/autodetect", func() {
|
||||
Expect(imp.Name()).To(Equal("privacy-filter"))
|
||||
Expect(imp.Modality()).To(Equal("text"))
|
||||
Expect(imp.AutoDetects()).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
Context("detection (Match)", func() {
|
||||
It("matches an HF repo shipping a privacy-filter GGUF", func() {
|
||||
d := privacyFilterDetails("huggingface://LocalAI-io/privacy-filter-multilingual-GGUF", "",
|
||||
hfapi.ModelFile{Path: "privacy-filter-multilingual-f16.gguf", URL: "https://hf/f16"})
|
||||
Expect(imp.Match(d)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches a direct URL to a privacy-filter GGUF", func() {
|
||||
d := privacyFilterDetails("https://hf/resolve/main/privacy-filter-multilingual-f16.gguf", "")
|
||||
Expect(imp.Match(d)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("matches the GGUF distribution repo by name when HF metadata is absent", func() {
|
||||
d := importers.Details{URI: "huggingface://LocalAI-io/privacy-filter-multilingual-GGUF", Preferences: json.RawMessage("")}
|
||||
Expect(imp.Match(d)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("honours preferences.backend=privacy-filter for arbitrary URIs", func() {
|
||||
d := privacyFilterDetails("huggingface://some/unrelated-repo", `{"backend":"privacy-filter"}`)
|
||||
Expect(imp.Match(d)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("does NOT claim a generic llama-style GGUF", func() {
|
||||
d := privacyFilterDetails("huggingface://TheBloke/Llama-2-7B-GGUF", "",
|
||||
hfapi.ModelFile{Path: "llama-2-7b.Q4_K_M.gguf", URL: "https://hf/llama"})
|
||||
Expect(imp.Match(d)).To(BeFalse())
|
||||
})
|
||||
|
||||
It("does NOT claim the upstream safetensors source repo (no GGUF)", func() {
|
||||
d := privacyFilterDetails("huggingface://OpenMed/privacy-filter-multilingual", "",
|
||||
hfapi.ModelFile{Path: "model.safetensors", URL: "https://hf/st"},
|
||||
hfapi.ModelFile{Path: "config.json", URL: "https://hf/cfg"})
|
||||
Expect(imp.Match(d)).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
Context("import (Import)", func() {
|
||||
It("emits a privacy-filter token_classify config from an HF GGUF repo", func() {
|
||||
d := privacyFilterDetails("huggingface://LocalAI-io/privacy-filter-multilingual-GGUF", `{"name":"pii"}`,
|
||||
hfapi.ModelFile{Path: "privacy-filter-multilingual-f16.gguf", URL: "https://hf/f16", SHA256: "abc"})
|
||||
cfg, err := imp.Import(d)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.ConfigFile).To(ContainSubstring("backend: privacy-filter"), fmt.Sprintf("%+v", cfg))
|
||||
Expect(cfg.ConfigFile).To(ContainSubstring("token_classify"))
|
||||
Expect(cfg.ConfigFile).To(ContainSubstring("embeddings: true"))
|
||||
Expect(cfg.Files).To(HaveLen(1))
|
||||
Expect(cfg.Files[0].URI).To(Equal("https://hf/f16"))
|
||||
Expect(cfg.Files[0].SHA256).To(Equal("abc"))
|
||||
Expect(cfg.Files[0].Filename).To(ContainSubstring("privacy-filter/models/pii/privacy-filter-multilingual-f16.gguf"))
|
||||
})
|
||||
|
||||
It("prefers the highest-precision quant (f16) from a multi-quant repo", func() {
|
||||
d := privacyFilterDetails("huggingface://LocalAI-io/privacy-filter-multilingual-GGUF", "",
|
||||
hfapi.ModelFile{Path: "privacy-filter-multilingual-q4_k.gguf", URL: "https://hf/q4k"},
|
||||
hfapi.ModelFile{Path: "privacy-filter-multilingual-f16.gguf", URL: "https://hf/f16"})
|
||||
cfg, err := imp.Import(d)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Files).To(HaveLen(1))
|
||||
Expect(cfg.Files[0].URI).To(Equal("https://hf/f16"), "f16 should win over q4_k")
|
||||
})
|
||||
|
||||
It("uses the exact file for a direct GGUF URL", func() {
|
||||
d := privacyFilterDetails("https://hf/resolve/main/privacy-filter-multilingual-f16.gguf", "")
|
||||
cfg, err := imp.Import(d)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cfg.Files).To(HaveLen(1))
|
||||
Expect(cfg.Files[0].Filename).To(ContainSubstring("privacy-filter/models/"))
|
||||
Expect(cfg.Files[0].Filename).To(ContainSubstring("privacy-filter-multilingual-f16.gguf"))
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -123,6 +123,10 @@ var RouteFeatureRegistry = []RouteFeature{
|
||||
{"GET", "/api/fine-tuning/jobs/:id/download", FeatureFineTuning},
|
||||
{"POST", "/api/fine-tuning/datasets", FeatureFineTuning},
|
||||
|
||||
// PII analyze/redact service (the events log stays admin-gated in-handler)
|
||||
{"POST", "/api/pii/analyze", FeaturePIIFilter},
|
||||
{"POST", "/api/pii/redact", FeaturePIIFilter},
|
||||
|
||||
// Quantization
|
||||
{"POST", "/api/quantization/jobs", FeatureQuantization},
|
||||
{"GET", "/api/quantization/jobs", FeatureQuantization},
|
||||
@@ -181,5 +185,6 @@ func APIFeatureMetas() []FeatureMeta {
|
||||
{FeatureFaceRecognition, "Face Recognition", true},
|
||||
{FeatureVoiceRecognition, "Voice Recognition", true},
|
||||
{FeatureAudioTransform, "Audio Transform", true},
|
||||
{FeaturePIIFilter, "PII Analyze / Redact", true},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,6 +56,10 @@ const (
|
||||
FeatureFaceRecognition = "face_recognition"
|
||||
FeatureVoiceRecognition = "voice_recognition"
|
||||
FeatureAudioTransform = "audio_transform"
|
||||
// FeaturePIIFilter gates the synchronous PII analyze/redact service
|
||||
// (POST /api/pii/{analyze,redact}). Default ON like the other API
|
||||
// features; the admin-only events log is gated separately in-handler.
|
||||
FeaturePIIFilter = "pii_filter"
|
||||
)
|
||||
|
||||
// AgentFeatures lists agent-related features (default OFF).
|
||||
@@ -71,6 +75,7 @@ var APIFeatures = []string{
|
||||
FeatureVAD, FeatureDetection, FeatureVideo, FeatureEmbeddings, FeatureSound,
|
||||
FeatureRealtime, FeatureRerank, FeatureTokenize, FeatureMCP, FeatureStores,
|
||||
FeatureFaceRecognition, FeatureVoiceRecognition, FeatureAudioTransform,
|
||||
FeaturePIIFilter,
|
||||
}
|
||||
|
||||
// AllFeatures lists all known features (used by UI and validation).
|
||||
|
||||
@@ -10,13 +10,11 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
|
||||
openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -30,7 +28,7 @@ import (
|
||||
// @Param request body schema.AnthropicRequest true "query params"
|
||||
// @Success 200 {object} schema.AnthropicResponse "Response"
|
||||
// @Router /v1/messages [post]
|
||||
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc {
|
||||
func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
id := uuid.New().String()
|
||||
|
||||
@@ -53,7 +51,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||
// Cloud-proxy bail. Same shape as the OpenAI chat endpoint —
|
||||
// forwards via the cloud-proxy gRPC backend.
|
||||
if cfg.IsCloudProxyBackendPassthrough() {
|
||||
return forwardCloudProxyAnthropicViaBackend(c, cfg, input, piiRedactor, piiEvents, ml, appConfig)
|
||||
return forwardCloudProxyAnthropicViaBackend(c, cfg, input, ml, appConfig)
|
||||
}
|
||||
|
||||
// Convert Anthropic messages to OpenAI format for internal processing
|
||||
@@ -141,7 +139,7 @@ func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evalu
|
||||
xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput)
|
||||
|
||||
if input.Stream {
|
||||
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator, piiRedactor, piiEvents)
|
||||
return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||
}
|
||||
|
||||
return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpExecutor, evaluator)
|
||||
@@ -330,36 +328,13 @@ func handleAnthropicNonStream(c echo.Context, id string, input *schema.Anthropic
|
||||
return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached")
|
||||
}
|
||||
|
||||
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator, piiRedactor *pii.Redactor, piiEvents pii.EventStore) error {
|
||||
func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpExecutor mcpTools.ToolExecutor, evaluator *templates.Evaluator) error {
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
|
||||
// Per-stream PII filter — same gating as the OpenAI chat path. The
|
||||
// filter is wire-format-agnostic; we feed it the text portion of
|
||||
// each text_delta and emit only what's safe to send. The filter
|
||||
// holds back a tail of size MaxPatternLength-1 so a pattern split
|
||||
// across chunk boundaries still gets masked. When PII is disabled
|
||||
// for this model the filter is nil and emits flow unchanged.
|
||||
var streamPIIFilter *pii.StreamFilter
|
||||
if piiRedactor != nil && cfg.PIIIsEnabled() {
|
||||
correlationID := c.Request().Header.Get("x-request-id")
|
||||
userID := ""
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
userID = u.ID
|
||||
}
|
||||
var overrides map[string]pii.Action
|
||||
if raw := cfg.PIIPatternOverrides(); len(raw) > 0 {
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionAllow:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
}
|
||||
streamPIIFilter = pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID)
|
||||
}
|
||||
// Response/output PII redaction is out of scope for now — redaction
|
||||
// runs request-side only (the NER middleware).
|
||||
|
||||
// Send message_start event
|
||||
messageStart := schema.AnthropicStreamEvent{
|
||||
@@ -440,7 +415,6 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
|
||||
if len(toolCalls) > toolCallsEmitted {
|
||||
if !inToolCall && currentBlockIndex == 0 {
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
@@ -481,20 +455,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
|
||||
if !inToolCall && token != "" {
|
||||
out := token
|
||||
if streamPIIFilter != nil {
|
||||
out = streamPIIFilter.Push(token)
|
||||
}
|
||||
if out != "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: out,
|
||||
},
|
||||
})
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: token,
|
||||
},
|
||||
})
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -532,20 +500,14 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
// didn't already stream it (autoparser clears raw text, so
|
||||
// accumulatedContent will be empty in that case).
|
||||
if deltaContent != "" && !inToolCall && accumulatedContent == "" {
|
||||
out := deltaContent
|
||||
if streamPIIFilter != nil {
|
||||
out = streamPIIFilter.Push(deltaContent)
|
||||
}
|
||||
if out != "" {
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: out,
|
||||
},
|
||||
})
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: intPtr(0),
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: deltaContent,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Emit tool_use blocks from ChatDeltas
|
||||
@@ -553,7 +515,6 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
collectedToolCalls = deltaToolCalls
|
||||
|
||||
if !inToolCall && currentBlockIndex == 0 {
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
@@ -657,9 +618,7 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
if !shouldUseFn && cfg.FunctionsConfig.AutomaticToolParsingFallback && accumulatedContent != "" && toolCallsEmitted == 0 {
|
||||
parsed := functions.ParseFunctionCall(accumulatedContent, cfg.FunctionsConfig)
|
||||
if len(parsed) > 0 {
|
||||
// Close the text content block (after flushing any
|
||||
// residual the streaming PII filter held back).
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(currentBlockIndex))
|
||||
// Close the text content block.
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(currentBlockIndex),
|
||||
@@ -699,12 +658,8 @@ func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicReq
|
||||
}
|
||||
}
|
||||
|
||||
// No MCP tools to execute, close stream. drainStreamPIIToText
|
||||
// flushes any residual the streaming PII filter held back as
|
||||
// part of its trailing pattern-window before we close the
|
||||
// text content block.
|
||||
// No MCP tools to execute, close the text content block.
|
||||
if !inToolCall {
|
||||
drainStreamPIIToText(c, streamPIIFilter, intPtr(0))
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_stop",
|
||||
Index: intPtr(0),
|
||||
@@ -752,30 +707,6 @@ func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool {
|
||||
|
||||
func intPtr(i int) *int { return &i }
|
||||
|
||||
// drainStreamPIIToText flushes any residual the streaming PII filter
|
||||
// has been holding back as part of its trailing pattern-window, and
|
||||
// emits it as one final text_delta into the named block before the
|
||||
// caller closes that block. Drain is idempotent: calling it twice on
|
||||
// the same filter returns "" the second time. Safe to call with a nil
|
||||
// filter (no-op).
|
||||
func drainStreamPIIToText(c echo.Context, sf *pii.StreamFilter, index *int) {
|
||||
if sf == nil {
|
||||
return
|
||||
}
|
||||
residual := sf.Drain()
|
||||
if residual == "" {
|
||||
return
|
||||
}
|
||||
sendAnthropicSSE(c, schema.AnthropicStreamEvent{
|
||||
Type: "content_block_delta",
|
||||
Index: index,
|
||||
Delta: &schema.AnthropicStreamDelta{
|
||||
Type: "text_delta",
|
||||
Text: residual,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) {
|
||||
data, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
@@ -973,17 +904,14 @@ func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConf
|
||||
}
|
||||
|
||||
// forwardCloudProxyAnthropicViaBackend marshals the Anthropic request,
|
||||
// constructs the streaming PII filter (when applicable), and hands the
|
||||
// body off to the cloud-proxy gRPC backend. Model swap + upstream auth
|
||||
// headers are applied inside the backend; the filter is built here
|
||||
// because the auth/correlation context only exists in the echo handler.
|
||||
func forwardCloudProxyAnthropicViaBackend(c echo.Context, cfg *config.ModelConfig, input *schema.AnthropicRequest, piiRedactor *pii.Redactor, piiEvents pii.EventStore, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
// and hands the body off to the cloud-proxy gRPC backend. Model swap +
|
||||
// upstream auth headers are applied inside the backend. Request-side PII
|
||||
// redaction already ran in the middleware; the response is forwarded
|
||||
// unmodified.
|
||||
func forwardCloudProxyAnthropicViaBackend(c echo.Context, cfg *config.ModelConfig, input *schema.AnthropicRequest, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
body, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return sendAnthropicError(c, 400, "invalid_request_error", "cloudproxy: marshal request: "+err.Error())
|
||||
}
|
||||
|
||||
correlationID := c.Request().Header.Get("x-request-id")
|
||||
streamFilter := cloudproxy.BuildStreamFilter(c, cfg, input.Stream, piiRedactor, piiEvents, correlationID)
|
||||
return cloudproxy.ForwardViaBackend(c, cfg, body, streamFilter, ml, appConfig)
|
||||
return cloudproxy.ForwardViaBackend(c, cfg, body, ml, appConfig)
|
||||
}
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// drainStreamPIIToText is called from four sites in messages.go and is
|
||||
// the load-bearing primitive for "the streaming filter has buffered
|
||||
// some bytes that the request just ended on; flush them as a final
|
||||
// text_delta event before closing the content block". A regression
|
||||
// here would silently truncate the last few bytes of an assistant
|
||||
// response on every PII-enabled stream — invisible without coverage.
|
||||
|
||||
// newTestFilter compiles the default patterns and returns a filter
|
||||
// that holds back its trailing pattern-window; pushing a short string
|
||||
// (shorter than holdLen) keeps the bytes inside Drain.
|
||||
func newTestFilter() *pii.StreamFilter {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
red := pii.NewRedactor(patterns)
|
||||
return pii.NewStreamFilter(red, nil, nil, "", "")
|
||||
}
|
||||
|
||||
// newTestContext builds a recording echo context — the recorder
|
||||
// captures the SSE bytes drainStreamPIIToText writes.
|
||||
func newTestContext() (echo.Context, *httptest.ResponseRecorder) {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader("{}"))
|
||||
rec := httptest.NewRecorder()
|
||||
return echo.New().NewContext(req, rec), rec
|
||||
}
|
||||
|
||||
var _ = Describe("drainStreamPIIToText", func() {
|
||||
It("is a no-op when the filter is nil", func() {
|
||||
c, rec := newTestContext()
|
||||
drainStreamPIIToText(c, nil, intPtr(0))
|
||||
Expect(rec.Body.Len()).To(Equal(0), "nil filter wrote %d bytes: %q", rec.Body.Len(), rec.Body.String())
|
||||
})
|
||||
|
||||
It("emits nothing when the drain is empty", func() {
|
||||
// A filter with nothing buffered should not emit a phantom event;
|
||||
// otherwise every non-PII response would close with an empty
|
||||
// text_delta that pollutes downstream parsers.
|
||||
sf := newTestFilter()
|
||||
c, rec := newTestContext()
|
||||
drainStreamPIIToText(c, sf, intPtr(0))
|
||||
Expect(rec.Body.Len()).To(Equal(0), "empty drain wrote %d bytes: %q", rec.Body.Len(), rec.Body.String())
|
||||
})
|
||||
|
||||
It("flushes residual buffered bytes as a text_delta event", func() {
|
||||
sf := newTestFilter()
|
||||
// Push less than holdLen so all bytes are retained until Drain.
|
||||
// "tail" is short enough that no pattern is plausible.
|
||||
out := sf.Push("tail")
|
||||
Expect(out).To(Equal(""), "Push of short text emitted %q; want all bytes held", out)
|
||||
|
||||
c, rec := newTestContext()
|
||||
drainStreamPIIToText(c, sf, intPtr(2))
|
||||
|
||||
body := rec.Body.String()
|
||||
// Wire format: "event: content_block_delta\ndata: {…}\n\n"
|
||||
Expect(body).To(ContainSubstring("event: content_block_delta"))
|
||||
Expect(body).To(ContainSubstring(`"type":"content_block_delta"`))
|
||||
Expect(body).To(ContainSubstring(`"index":2`))
|
||||
Expect(body).To(ContainSubstring(`"text":"tail"`))
|
||||
Expect(body).To(ContainSubstring(`"type":"text_delta"`))
|
||||
Expect(strings.HasSuffix(body, "\n\n")).To(BeTrue(), "SSE event missing trailing blank line: %q", body)
|
||||
})
|
||||
|
||||
It("is idempotent across consecutive drains", func() {
|
||||
// Two consecutive Drains: the filter returns "" the second time,
|
||||
// so the second drainStreamPIIToText must emit nothing. The
|
||||
// production path in messages.go has at least four call sites
|
||||
// that may overlap (currentBlockIndex==0 emergency path + the
|
||||
// unconditional drain near the end of the stream); without
|
||||
// idempotence we'd duplicate the residual on the wire.
|
||||
sf := newTestFilter()
|
||||
sf.Push("tail")
|
||||
|
||||
c1, rec1 := newTestContext()
|
||||
drainStreamPIIToText(c1, sf, intPtr(0))
|
||||
first := rec1.Body.Len()
|
||||
Expect(first).NotTo(Equal(0), "first drain emitted nothing")
|
||||
|
||||
c2, rec2 := newTestContext()
|
||||
drainStreamPIIToText(c2, sf, intPtr(0))
|
||||
Expect(rec2.Body.Len()).To(Equal(0), "second drain wrote %d bytes; want idempotent no-op: %q", rec2.Body.Len(), rec2.Body.String())
|
||||
})
|
||||
|
||||
It("masks redacted residual instead of leaking it", func() {
|
||||
// The held tail must travel through the redactor on Drain. If
|
||||
// the bytes happen to form a complete pattern at end-of-stream,
|
||||
// the residual emit must contain the mask placeholder, not the
|
||||
// raw value.
|
||||
sf := newTestFilter()
|
||||
// "alice@example.com" is 17 bytes. holdLen for default patterns
|
||||
// is well above 17, so this stays buffered until Drain, which
|
||||
// then redacts it.
|
||||
out := sf.Push("alice@example.com")
|
||||
Expect(out).To(Equal(""), "Push emitted bytes early: %q", out)
|
||||
|
||||
c, rec := newTestContext()
|
||||
drainStreamPIIToText(c, sf, intPtr(0))
|
||||
body := rec.Body.String()
|
||||
Expect(body).NotTo(ContainSubstring("alice@example.com"), "raw email leaked in residual emit: %q", body)
|
||||
Expect(body).To(ContainSubstring("[REDACTED:email]"), "residual emit missing mask placeholder: %q", body)
|
||||
})
|
||||
})
|
||||
@@ -100,15 +100,15 @@ var instructionDefs = []instructionDef{
|
||||
},
|
||||
{
|
||||
Name: "pii-filtering",
|
||||
Description: "Inspect and tune the regex PII filter applied to chat requests",
|
||||
Description: "Inspect the NER-based PII filter applied to chat requests",
|
||||
Tags: []string{"pii"},
|
||||
Intro: "GET /api/pii/patterns lists the active pattern set with each one's action (mask, block, allow). GET /api/pii/events returns recent redaction events filtered by correlation_id / user_id / pattern_id (admin or local-user only). POST /api/pii/test dry-runs the redactor against an admin-supplied string. POST /api/pii/decide is the programmatic decision oracle for external routers: send `{text}`, receive `{findings, suggested_action, redacted_preview}` without LocalAI mutating, recording, or acting on the call — caller composes the action with its own policy. Default patterns: email, phone, SSN, credit card (Luhn), IPv4, common API key prefixes (sk-, pk-, ghp_, github_pat_). PII is per-model: by default it is OFF for non-proxy backends and ON for backends starting with proxy-* (cloud passthroughs). Opt in with `pii: { enabled: true }` in a model's YAML; use `pii: { patterns: [{id, action}] }` to upgrade or downgrade individual actions for that model. Override global default actions via --pii-config pii.yaml; --disable-pii turns the filter off entirely.",
|
||||
Intro: "PII redaction is NER-based and request-side. A consuming model opts in with `pii: { enabled: true, detectors: [<model>] }` where each detector is a token-classification (token_classify) model. The detection policy lives on the detector model itself in a `pii_detection:` block: `{ min_score, default_action (mask|block|allow), entity_actions: { GROUP: action } }`. Multiple detectors union their hits; overlapping spans resolve to the strongest action (block > mask > allow). PII defaults OFF for non-proxy backends and ON for proxy-* (cloud passthroughs). Besides the inline path, two synchronous service endpoints expose the same engine without an inference request: POST /api/pii/analyze returns the detected entity spans (entity_type, source ner|pattern, start/end, score, action) without mutating the text, and POST /api/pii/redact applies the policy — returning redacted_text, or 400 (type pii_blocked) with the offending entities when a block action fires. Both take `{ text, detectors:[<model>...] }` (or `model` to inherit a consuming model's detectors), require the pii_filter feature (any authenticated user), and record audit events with an `origin` of pii_analyze / pii_redact. GET /api/pii/events returns recent redaction events filtered by correlation_id / user_id / pattern_id / origin (middleware|proxy|pii_analyze|pii_redact); events carry `<source>:<GROUP>` ids — e.g. `ner:EMAIL` for the neural detector, `pattern:ANTHROPIC_KEY` for the regex pattern tier — and an 8-char hash prefix, never the matched value (admin or local-user only). The legacy regex pattern tier and its endpoints (/api/pii/patterns, /test, /decide) were removed.",
|
||||
},
|
||||
{
|
||||
Name: "middleware-admin",
|
||||
Description: "Inspect and configure the routing-module middleware (PII filter and routing)",
|
||||
Tags: []string{"middleware", "pii", "router"},
|
||||
Intro: "GET /api/middleware/status is the single round-trip the /app/middleware admin page reads to render the current state: active PII patterns and their actions, every model's resolved enabled/override state, recent event count, and the active routing models with their classifier configurations. Admin-only (the synthetic local user is admin in no-auth mode). PUT /api/pii/patterns/:id changes a pattern's action in-process — TRANSIENT, lost on restart. To persist, edit --pii-config YAML. GET /api/router/decisions returns the routing decision log filtered by correlation_id / user_id / router_model. The same surface is exposed as MCP tools (`get_middleware_status`, `set_pii_pattern_action`, `get_router_decisions`) for agent-driven configuration.",
|
||||
Intro: "GET /api/middleware/status is the single round-trip the /app/middleware admin page reads to render the current state: every model's resolved PII enabled state and the NER detector models it references, recent event count, and the active routing models with their classifier configurations. Admin-only (the synthetic local user is admin in no-auth mode). PII detection policy is edited on each detector model's `pii_detection:` block via the model-config tools/UI — there is no global pattern set to mutate. GET /api/router/decisions returns the routing decision log filtered by correlation_id / user_id / router_model. The same surface is exposed as MCP tools (`get_middleware_status`, `get_pii_events`, `get_router_decisions`) for agent-driven inspection.",
|
||||
},
|
||||
{
|
||||
Name: "intelligent-routing",
|
||||
|
||||
@@ -25,6 +25,10 @@ var knownPrefOnlyBackends = []schema.KnownBackend{
|
||||
// Text LLM
|
||||
// ds4: antirez/ds4 - single-model DeepSeek V4 Flash engine; auto-detected via DS4Importer
|
||||
{Name: "ds4", Modality: "text", AutoDetect: false, Description: "antirez/ds4 DeepSeek V4 Flash engine (auto-detected; pref-only fallback)"},
|
||||
// privacy-filter is now auto-detected via PrivacyFilterImporter (see
|
||||
// core/gallery/importers/privacy-filter.go); the importer registry entry
|
||||
// supersedes any pref-only line here, which the /backends/known merge would
|
||||
// dedupe away.
|
||||
{Name: "sglang", Modality: "text", AutoDetect: false, Description: "SGLang runtime (preference-only)"},
|
||||
{Name: "tinygrad", Modality: "text", AutoDetect: false, Description: "tinygrad runtime (preference-only)"},
|
||||
{Name: "trl", Modality: "text", AutoDetect: false, Description: "Transformers Reinforcement Learning (preference-only)"},
|
||||
|
||||
@@ -88,7 +88,20 @@ var _ = Describe("Backend Endpoints", func() {
|
||||
}
|
||||
Expect(names).To(ContainElements(
|
||||
"llama-cpp", "mlx", "vllm", "transformers", "diffusers",
|
||||
"privacy-filter",
|
||||
))
|
||||
|
||||
// privacy-filter is auto-detected via PrivacyFilterImporter, so it
|
||||
// surfaces from the importer registry (AutoDetect=true) rather than
|
||||
// the curated pref-only slice.
|
||||
byName := map[string]schema.KnownBackend{}
|
||||
for _, b := range payload {
|
||||
byName[b.Name] = b
|
||||
}
|
||||
pf, ok := byName["privacy-filter"]
|
||||
Expect(ok).To(BeTrue(), "privacy-filter must be present")
|
||||
Expect(pf.AutoDetect).To(BeTrue(), "privacy-filter is auto-detected via its importer")
|
||||
Expect(pf.Modality).To(Equal("text"))
|
||||
})
|
||||
|
||||
It("includes drop-in llama-cpp replacements with AutoDetect=false", func() {
|
||||
|
||||
@@ -126,6 +126,8 @@ func AutocompleteEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, a
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)
|
||||
case "score": // router classifier usecase (FLAG_SCORE); not in UsecaseInfoMap
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_SCORE)
|
||||
case config.UsecaseTokenClassify: // PII NER detector usecase (FLAG_TOKEN_CLASSIFY)
|
||||
filterFn = config.BuildUsecaseFilterFn(config.FLAG_TOKEN_CLASSIFY)
|
||||
default:
|
||||
filterFn = config.NoFilterFn
|
||||
}
|
||||
|
||||
@@ -65,7 +65,7 @@ func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
// the per-model PII config and is kept for backward compatibility.
|
||||
// The request-side middleware on the main chat route handles
|
||||
// filtering for the standard /v1/chat/completions path.
|
||||
chatHandler := openai.ChatEndpoint(cl, ml, evaluator, appConfig, natsClient, nil, nil, nil)
|
||||
chatHandler := openai.ChatEndpoint(cl, ml, evaluator, appConfig, natsClient, nil)
|
||||
|
||||
return func(c echo.Context) error {
|
||||
input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest)
|
||||
|
||||
248
core/http/endpoints/localai/pii.go
Normal file
248
core/http/endpoints/localai/pii.go
Normal file
@@ -0,0 +1,248 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// ErrNoDetectors is returned by RunPIIScan when neither an explicit detector
|
||||
// list nor a model's effective PII policy resolve to anything to scan with —
|
||||
// including a model that has PII disabled, or one that is enabled but names
|
||||
// no detectors while no instance-wide default is set. The handler maps it to
|
||||
// 400: the truthful answer is "the middleware would scan nothing", and
|
||||
// surfacing that loudly beats implying a clean scan happened.
|
||||
var ErrNoDetectors = errors.New("no PII detectors specified")
|
||||
|
||||
// ErrUnknownDetector is returned when a named detector model cannot be
|
||||
// resolved. Wrapped (errors.Is) so the handler can map it to 400 — a bad
|
||||
// detector name is a client error, distinct from a detector that resolved but
|
||||
// failed at scan time (mapped to 502, fail-closed).
|
||||
var ErrUnknownDetector = errors.New("unknown PII detector")
|
||||
|
||||
// RunPIIScan resolves the requested detectors and runs the shared NER/pattern
|
||||
// redaction pipeline over text. It is the engine behind both /api/pii/analyze
|
||||
// and /api/pii/redact, kept free of echo so the resolution + scan logic is
|
||||
// unit-testable with a fake resolver.
|
||||
//
|
||||
// Detector selection mirrors the inline chat middleware (middleware.go):
|
||||
// explicit names take precedence; otherwise the consuming model's effective
|
||||
// policy is resolved through policy (Application.ResolvePIIPolicy — the
|
||||
// model's own pii.detectors, else the instance-wide PIIDefaultDetectors, and
|
||||
// nothing when the model has PII disabled), so the model path answers "what
|
||||
// would the middleware do with this text?" with the same inputs the
|
||||
// middleware uses. A nil policy falls back to the model's raw pii.detectors
|
||||
// (unit tests). Unknown names fail closed (ErrUnknownDetector) rather than
|
||||
// silently scanning with fewer detectors than asked for.
|
||||
func RunPIIScan(ctx context.Context, resolver pii.NERDetectorResolver, cl *config.ModelConfigLoader, policy pii.PolicyResolver, names []string, model, text string) (pii.Result, error) {
|
||||
if len(names) == 0 && model != "" && cl != nil {
|
||||
if cfg, ok := cl.GetModelConfig(model); ok {
|
||||
if policy != nil {
|
||||
if enabled, detectors := policy(&cfg); enabled {
|
||||
names = detectors
|
||||
}
|
||||
} else {
|
||||
names = cfg.PIIDetectors()
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return pii.Result{}, ErrNoDetectors
|
||||
}
|
||||
|
||||
cfgs := make([]pii.NERConfig, 0, len(names))
|
||||
for _, name := range names {
|
||||
nc, ok := resolver(name)
|
||||
if !ok {
|
||||
return pii.Result{}, fmt.Errorf("%w: %q", ErrUnknownDetector, name)
|
||||
}
|
||||
cfgs = append(cfgs, nc)
|
||||
}
|
||||
return pii.RedactNER(ctx, text, cfgs)
|
||||
}
|
||||
|
||||
// piiEntities maps redaction spans to API entities. Each span's Pattern is the
|
||||
// synthetic "<source>:<GROUP>" id (e.g. "ner:EMAIL"); it is split back into
|
||||
// the entity type and its source tier. hash_prefix is included only when
|
||||
// revealHash is set (admin + reveal) — the raw matched value is never exposed.
|
||||
func piiEntities(spans []pii.Span, revealHash bool) []schema.PIIEntity {
|
||||
out := make([]schema.PIIEntity, 0, len(spans))
|
||||
for _, s := range spans {
|
||||
source, group := splitPatternID(s.Pattern)
|
||||
e := schema.PIIEntity{
|
||||
EntityType: group,
|
||||
Source: source,
|
||||
Start: s.Start,
|
||||
End: s.End,
|
||||
Score: s.Score,
|
||||
Action: string(s.Action),
|
||||
}
|
||||
if revealHash {
|
||||
e.HashPrefix = s.HashPrefix
|
||||
}
|
||||
out = append(out, e)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// splitPatternID splits "ner:EMAIL" into ("ner", "EMAIL"). A value with no
|
||||
// colon is returned as (group, "") inverted to ("", value) so the group is
|
||||
// never lost.
|
||||
func splitPatternID(patternID string) (source, group string) {
|
||||
if i := strings.IndexByte(patternID, ':'); i >= 0 {
|
||||
return patternID[:i], patternID[i+1:]
|
||||
}
|
||||
return "", patternID
|
||||
}
|
||||
|
||||
// recordPIIEvents persists one audit event per span, tagged with the calling
|
||||
// API as its Origin so /api/pii/events can be filtered to this surface. Mirrors
|
||||
// the per-span recording the chat middleware does. Best-effort: a store error
|
||||
// is logged by the store layer, not surfaced to the caller.
|
||||
func recordPIIEvents(store pii.EventStore, spans []pii.Span, origin pii.Origin, correlationID, userID string) {
|
||||
if store == nil {
|
||||
return
|
||||
}
|
||||
for _, s := range spans {
|
||||
_ = store.Record(context.Background(), pii.PIIEvent{
|
||||
ID: pii.NewEventID(),
|
||||
Kind: pii.KindPII,
|
||||
Origin: origin,
|
||||
CorrelationID: correlationID,
|
||||
UserID: userID,
|
||||
Direction: pii.DirectionIn,
|
||||
PatternID: s.Pattern,
|
||||
ByteOffset: s.Start,
|
||||
Length: s.End - s.Start,
|
||||
HashPrefix: s.HashPrefix,
|
||||
Action: s.Action,
|
||||
Score: s.Score,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// piiScanError maps a RunPIIScan error to an HTTP response. Selection/naming
|
||||
// errors are client errors (400); a detector that resolved but failed at scan
|
||||
// time is a fail-closed dependency error (502) — the text is never returned
|
||||
// unredacted.
|
||||
func piiScanError(c echo.Context, err error) error {
|
||||
if errors.Is(err, ErrNoDetectors) || errors.Is(err, ErrUnknownDetector) {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{
|
||||
"error": map[string]string{"message": err.Error(), "type": "invalid_request"},
|
||||
})
|
||||
}
|
||||
return c.JSON(http.StatusBadGateway, map[string]any{
|
||||
"error": map[string]string{"message": err.Error(), "type": "pii_detector_error"},
|
||||
})
|
||||
}
|
||||
|
||||
// piiViewer resolves the request's user (the authenticated user, or the
|
||||
// synthetic local admin in single-user mode) so the handlers can attribute
|
||||
// events and gate the admin-only hash reveal.
|
||||
func piiViewer(c echo.Context, app *application.Application) *auth.User {
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
return u
|
||||
}
|
||||
return app.FallbackUser()
|
||||
}
|
||||
|
||||
// PIIAnalyzeEndpoint scans text and returns the detected PII entities without
|
||||
// mutating it. Always 200 (detection, not enforcement); Blocked reports
|
||||
// whether the redact endpoint would reject the same text.
|
||||
// @Summary Detect PII entities in a string (no mutation).
|
||||
// @Description Runs the configured PII detectors (NER and/or pattern tiers) over the supplied text and returns the matched entity spans with the policy action that would fire. Detection only — the text is not modified and no block is enforced. Select detectors explicitly via `detectors`, or pass a consuming `model` to use its effective policy: the model's own `pii.detectors`, else the instance-wide `pii_default_detectors`. A model with PII disabled, or enabled with nothing to scan with, is a 400. The raw matched value is never returned; admins may set `reveal:true` for the audit hash prefix.
|
||||
// @Tags pii
|
||||
// @Param request body schema.PIIAnalyzeRequest true "text + detector selection"
|
||||
// @Success 200 {object} schema.PIIAnalyzeResponse "Detected entities"
|
||||
// @Router /api/pii/analyze [post]
|
||||
func PIIAnalyzeEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.PIIAnalyzeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{
|
||||
"error": map[string]string{"message": "invalid request body", "type": "invalid_request"},
|
||||
})
|
||||
}
|
||||
viewer := piiViewer(c, app)
|
||||
if viewer == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"})
|
||||
}
|
||||
|
||||
correlationID := pii.NewEventID()
|
||||
res, err := RunPIIScan(c.Request().Context(), app.PIINERResolver(), app.ModelConfigLoader(), app.PIIPolicyResolver(), req.Detectors, req.Model, req.Text)
|
||||
if err != nil {
|
||||
return piiScanError(c, err)
|
||||
}
|
||||
|
||||
recordPIIEvents(app.PIIEvents(), res.Spans, pii.OriginAnalyzeAPI, correlationID, viewer.ID)
|
||||
revealHash := req.Reveal && viewer.Role == auth.RoleAdmin
|
||||
return c.JSON(http.StatusOK, schema.PIIAnalyzeResponse{
|
||||
Entities: piiEntities(res.Spans, revealHash),
|
||||
Blocked: res.Blocked,
|
||||
CorrelationID: correlationID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// PIIRedactEndpoint scans text and applies the configured mask/block/allow
|
||||
// policy. Returns the redacted text (200), or 400 with type "pii_blocked" and
|
||||
// the offending entities when a block action fires — never a redacted body in
|
||||
// that case. Mirrors the inline middleware's block contract.
|
||||
// @Summary Redact PII in a string by applying the configured policy.
|
||||
// @Description Runs the configured PII detectors over the text and applies each detector model's policy: masked spans are replaced with `[REDACTED:<id>]`, allow spans pass through, and a single block action causes a 400 (type `pii_blocked`) carrying the offending entities — the text is never returned in that case. Select detectors via `detectors`, or a consuming `model`'s effective policy (its own `pii.detectors`, else the instance-wide `pii_default_detectors`; PII must be enabled on the model). Records audit events (origin `pii_redact`) visible at /api/pii/events.
|
||||
// @Tags pii
|
||||
// @Param request body schema.PIIAnalyzeRequest true "text + detector selection"
|
||||
// @Success 200 {object} schema.PIIRedactResponse "Redacted text + entities"
|
||||
// @Router /api/pii/redact [post]
|
||||
func PIIRedactEndpoint(app *application.Application) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.PIIAnalyzeRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{
|
||||
"error": map[string]string{"message": "invalid request body", "type": "invalid_request"},
|
||||
})
|
||||
}
|
||||
viewer := piiViewer(c, app)
|
||||
if viewer == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"})
|
||||
}
|
||||
|
||||
correlationID := pii.NewEventID()
|
||||
res, err := RunPIIScan(c.Request().Context(), app.PIINERResolver(), app.ModelConfigLoader(), app.PIIPolicyResolver(), req.Detectors, req.Model, req.Text)
|
||||
if err != nil {
|
||||
return piiScanError(c, err)
|
||||
}
|
||||
|
||||
recordPIIEvents(app.PIIEvents(), res.Spans, pii.OriginRedactAPI, correlationID, viewer.ID)
|
||||
revealHash := req.Reveal && viewer.Role == auth.RoleAdmin
|
||||
entities := piiEntities(res.Spans, revealHash)
|
||||
|
||||
if res.Blocked {
|
||||
// Fail closed: a block action returns no redacted text, only the
|
||||
// reason and the offending entities — identical to the middleware.
|
||||
return c.JSON(http.StatusBadRequest, map[string]any{
|
||||
"error": map[string]string{"message": "text blocked by content policy (sensitive data detected)", "type": "pii_blocked"},
|
||||
"entities": entities,
|
||||
"correlation_id": correlationID,
|
||||
})
|
||||
}
|
||||
return c.JSON(http.StatusOK, schema.PIIRedactResponse{
|
||||
RedactedText: res.Redacted,
|
||||
Entities: entities,
|
||||
Blocked: false,
|
||||
Masked: res.Masked,
|
||||
CorrelationID: correlationID,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
package localai
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// PIIDecideEndpoint exposes the PII redactor as a decision oracle:
|
||||
// scan the supplied text and return findings + the strongest action
|
||||
// the configured pattern set would take, without rewriting the
|
||||
// caller's request or recording an audit event.
|
||||
//
|
||||
// External routers (e.g. the localai-org/platform router) call this
|
||||
// before dispatching to learn whether to mask the prompt in place,
|
||||
// block the request, or pass it through. LocalAI's in-band PII
|
||||
// middleware is the alternative path for direct-to-LocalAI clients —
|
||||
// same Redactor, different framing.
|
||||
//
|
||||
// Takes the *pii.Redactor directly rather than the whole
|
||||
// *application.Application so the handler stays unit-testable with a
|
||||
// freshly-constructed redactor (mirrors the pattern in
|
||||
// router_decide.go). The route-registration site is responsible for
|
||||
// stubbing this endpoint when --disable-pii is set so callers get a
|
||||
// 503 signalling "admin opted out" rather than a misleading allow.
|
||||
//
|
||||
// @Summary Scan text for PII and return findings + suggested action (decision oracle)
|
||||
// @Tags pii
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body schema.PIIDecideRequest true "decide params"
|
||||
// @Success 200 {object} schema.PIIDecideResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Router /api/pii/decide [post]
|
||||
func PIIDecideEndpoint(redactor *pii.Redactor) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var req schema.PIIDecideRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "invalid request body: "+err.Error())
|
||||
}
|
||||
if req.Text == "" {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "text is required")
|
||||
}
|
||||
|
||||
res := redactor.Redact(req.Text)
|
||||
findings := make([]schema.PIIFinding, len(res.Spans))
|
||||
for i, s := range res.Spans {
|
||||
findings[i] = schema.PIIFinding{
|
||||
Start: s.Start,
|
||||
End: s.End,
|
||||
Pattern: s.Pattern,
|
||||
HashPrefix: s.HashPrefix,
|
||||
}
|
||||
}
|
||||
return c.JSON(http.StatusOK, schema.PIIDecideResponse{
|
||||
Findings: findings,
|
||||
SuggestedAction: suggestedAction(res),
|
||||
RedactedPreview: res.Redacted,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// suggestedAction collapses the Redactor's Result flags onto a single
|
||||
// wire-format action using the in-band ordering (block > mask >
|
||||
// allow). "allow" covers both "nothing matched" and "matched but every
|
||||
// span resolved to the allow action" — in both cases the caller may
|
||||
// dispatch unchanged, with the Findings list reporting what was seen.
|
||||
func suggestedAction(res pii.Result) string {
|
||||
switch {
|
||||
case res.Blocked:
|
||||
return string(pii.ActionBlock)
|
||||
case res.Masked:
|
||||
return string(pii.ActionMask)
|
||||
default:
|
||||
return string(pii.ActionAllow)
|
||||
}
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// PIIDecideEndpoint exposes the redactor as a decision oracle. These
|
||||
// specs pin the validation surface and the suggested_action mapping
|
||||
// across the three actions (allow/mask/block). The redactor itself is
|
||||
// covered in core/services/routing/pii/redactor_test.go.
|
||||
|
||||
var _ = Describe("PIIDecideEndpoint", func() {
|
||||
var redactor *pii.Redactor
|
||||
|
||||
BeforeEach(func() {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
redactor = pii.NewRedactor(patterns)
|
||||
})
|
||||
|
||||
It("rejects requests with no text field", func() {
|
||||
rec, _ := invokePIIDecide(redactor, `{}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("text is required"))
|
||||
})
|
||||
|
||||
It("rejects malformed JSON", func() {
|
||||
rec, _ := invokePIIDecide(redactor, `not json`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
})
|
||||
|
||||
It("returns allow for clean text", func() {
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"hello world"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.SuggestedAction).To(Equal("allow"))
|
||||
Expect(body.Findings).To(BeEmpty())
|
||||
Expect(body.RedactedPreview).To(Equal("hello world"))
|
||||
})
|
||||
|
||||
It("returns mask for text containing email (default action)", func() {
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"reach me at alice@example.com please"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.SuggestedAction).To(Equal("mask"))
|
||||
Expect(body.Findings).To(HaveLen(1))
|
||||
Expect(body.Findings[0].Pattern).To(Equal("email"))
|
||||
Expect(body.Findings[0].HashPrefix).NotTo(BeEmpty())
|
||||
Expect(body.RedactedPreview).To(ContainSubstring("[REDACTED:email]"))
|
||||
Expect(body.RedactedPreview).NotTo(ContainSubstring("alice@example.com"))
|
||||
})
|
||||
|
||||
It("returns block when an api_key_prefix is present (block beats mask)", func() {
|
||||
// api_key_prefix defaults to ActionBlock per DefaultPatterns.
|
||||
// Mix in an email so we also confirm the block-action wins
|
||||
// over the mask-action via actionRank.
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"my key is sk-1234567890abcdefghij and email alice@example.com"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.SuggestedAction).To(Equal("block"))
|
||||
Expect(len(body.Findings)).To(BeNumerically(">=", 1))
|
||||
})
|
||||
|
||||
It("returns allow when a matched pattern's action is allow", func() {
|
||||
// Downgrade the email pattern to allow for this test —
|
||||
// exercises the allow branch of suggestedAction: a match is
|
||||
// found, but the strongest action is allow so the suggestion
|
||||
// is "allow" and the text is left intact.
|
||||
Expect(redactor.SetAction("email", pii.ActionAllow)).To(Succeed())
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"contact alice@example.com"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.SuggestedAction).To(Equal("allow"))
|
||||
Expect(body.Findings).To(HaveLen(1), "allow still reports the finding")
|
||||
// allow leaves the original text intact.
|
||||
Expect(body.RedactedPreview).To(ContainSubstring("alice@example.com"))
|
||||
})
|
||||
|
||||
It("never leaks the matched value via HashPrefix", func() {
|
||||
rec, body := invokePIIDecide(redactor, `{"text":"alice@example.com"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Findings).To(HaveLen(1))
|
||||
// HashPrefix is 8 hex chars of sha256 — definitely not the
|
||||
// matched value, but stable so admins can correlate leaks.
|
||||
Expect(body.Findings[0].HashPrefix).To(HaveLen(8))
|
||||
Expect(body.Findings[0].HashPrefix).NotTo(ContainSubstring("alice"))
|
||||
})
|
||||
})
|
||||
|
||||
func invokePIIDecide(redactor *pii.Redactor, body string) (*httptest.ResponseRecorder, schema.PIIDecideResponse) {
|
||||
e := echo.New()
|
||||
e.POST("/api/pii/decide", localai.PIIDecideEndpoint(redactor))
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/pii/decide", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
var parsed schema.PIIDecideResponse
|
||||
if rec.Code == http.StatusOK {
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &parsed)).To(Succeed())
|
||||
}
|
||||
return rec, parsed
|
||||
}
|
||||
258
core/http/endpoints/localai/pii_test.go
Normal file
258
core/http/endpoints/localai/pii_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package localai_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
. "github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/pkg/system"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// stubDetector is a fixed NER detector for the resolver-level unit tests.
|
||||
type stubDetector struct {
|
||||
ents []pii.NEREntity
|
||||
err error
|
||||
}
|
||||
|
||||
func (s stubDetector) Detect(_ context.Context, _ string) ([]pii.NEREntity, error) {
|
||||
return s.ents, s.err
|
||||
}
|
||||
|
||||
var _ = Describe("RunPIIScan (resolver + scan core)", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
resolver := func(name string) (pii.NERConfig, bool) {
|
||||
if name != "det" {
|
||||
return pii.NERConfig{}, false
|
||||
}
|
||||
return pii.NERConfig{
|
||||
Detector: stubDetector{ents: []pii.NEREntity{{Group: "EMAIL", Start: 0, End: 5, Score: 0.9}}},
|
||||
EntityActions: map[string]pii.Action{"EMAIL": pii.ActionMask},
|
||||
Source: pii.SourceNER,
|
||||
}, true
|
||||
}
|
||||
|
||||
It("resolves named detectors and returns their spans", func() {
|
||||
res, err := RunPIIScan(ctx, resolver, nil, nil, []string{"det"}, "", "jane@acme.io")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(res.Spans).To(HaveLen(1))
|
||||
Expect(res.Spans[0].Pattern).To(Equal("ner:EMAIL"))
|
||||
Expect(res.Masked).To(BeTrue())
|
||||
})
|
||||
|
||||
It("fails closed with ErrUnknownDetector for an unresolvable name", func() {
|
||||
_, err := RunPIIScan(ctx, resolver, nil, nil, []string{"nope"}, "", "x")
|
||||
Expect(errors.Is(err, ErrUnknownDetector)).To(BeTrue())
|
||||
})
|
||||
|
||||
It("returns ErrNoDetectors when nothing is selected", func() {
|
||||
_, err := RunPIIScan(ctx, resolver, nil, nil, nil, "", "x")
|
||||
Expect(errors.Is(err, ErrNoDetectors)).To(BeTrue())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("PII analyze/redact endpoints", func() {
|
||||
var (
|
||||
app *application.Application
|
||||
e *echo.Echo
|
||||
tmp string
|
||||
cancel context.CancelFunc
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
tmp, err = os.MkdirTemp("", "pii-api-test-*")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
var ctx context.Context
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
|
||||
modelsDir := filepath.Join(tmp, "models")
|
||||
Expect(os.MkdirAll(modelsDir, 0o755)).To(Succeed())
|
||||
|
||||
st, err := system.GetSystemState(
|
||||
system.WithModelPath(modelsDir),
|
||||
system.WithBackendPath(filepath.Join(tmp, "backends")),
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
app, err = application.New(config.WithContext(ctx), config.WithSystemState(st))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// A pattern detector with two deterministic patterns: one blocks, one
|
||||
// masks. No backend is loaded — the pattern tier runs in-process.
|
||||
detYAML := `name: secret-filter
|
||||
backend: pattern
|
||||
pii_detection:
|
||||
default_action: mask
|
||||
patterns:
|
||||
- name: SECRET
|
||||
match: "sk-test-[A-Za-z0-9]+"
|
||||
action: block
|
||||
- name: TOKEN
|
||||
match: "tok-[A-Za-z0-9]+"
|
||||
action: mask
|
||||
`
|
||||
// A consuming model that opts into the detector, for the model-fallback path.
|
||||
consumerYAML := `name: chatmodel
|
||||
pii:
|
||||
enabled: true
|
||||
detectors: [secret-filter]
|
||||
`
|
||||
// PII-enabled but names no detectors: scanned only when the
|
||||
// instance-wide default detectors are set, else a 400.
|
||||
defaultsYAML := `name: defaultsmodel
|
||||
pii:
|
||||
enabled: true
|
||||
`
|
||||
// Lists detectors but never enables PII — the middleware ignores it,
|
||||
// so the model path must too.
|
||||
disabledYAML := `name: disabledmodel
|
||||
pii:
|
||||
detectors: [secret-filter]
|
||||
`
|
||||
detPath := filepath.Join(modelsDir, "secret-filter.yaml")
|
||||
consumerPath := filepath.Join(modelsDir, "chatmodel.yaml")
|
||||
defaultsPath := filepath.Join(modelsDir, "defaultsmodel.yaml")
|
||||
disabledPath := filepath.Join(modelsDir, "disabledmodel.yaml")
|
||||
Expect(os.WriteFile(detPath, []byte(detYAML), 0o644)).To(Succeed())
|
||||
Expect(os.WriteFile(consumerPath, []byte(consumerYAML), 0o644)).To(Succeed())
|
||||
Expect(os.WriteFile(defaultsPath, []byte(defaultsYAML), 0o644)).To(Succeed())
|
||||
Expect(os.WriteFile(disabledPath, []byte(disabledYAML), 0o644)).To(Succeed())
|
||||
Expect(app.ModelConfigLoader().ReadModelConfig(detPath)).To(Succeed())
|
||||
Expect(app.ModelConfigLoader().ReadModelConfig(consumerPath)).To(Succeed())
|
||||
Expect(app.ModelConfigLoader().ReadModelConfig(defaultsPath)).To(Succeed())
|
||||
Expect(app.ModelConfigLoader().ReadModelConfig(disabledPath)).To(Succeed())
|
||||
|
||||
e = echo.New()
|
||||
e.POST("/api/pii/analyze", PIIAnalyzeEndpoint(app))
|
||||
e.POST("/api/pii/redact", PIIRedactEndpoint(app))
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
cancel()
|
||||
Expect(os.RemoveAll(tmp)).To(Succeed())
|
||||
})
|
||||
|
||||
post := func(path, body string) *httptest.ResponseRecorder {
|
||||
req := httptest.NewRequest(http.MethodPost, path, bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
e.ServeHTTP(rec, req)
|
||||
return rec
|
||||
}
|
||||
|
||||
It("analyze reports a block-class entity without mutating text (200)", func() {
|
||||
rec := post("/api/pii/analyze", `{"text":"my key sk-test-abc123 ok","detectors":["secret-filter"]}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp struct {
|
||||
Entities []struct {
|
||||
EntityType string `json:"entity_type"`
|
||||
Source string `json:"source"`
|
||||
Action string `json:"action"`
|
||||
} `json:"entities"`
|
||||
Blocked bool `json:"blocked"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp.Blocked).To(BeTrue())
|
||||
Expect(resp.Entities).To(HaveLen(1))
|
||||
Expect(resp.Entities[0].EntityType).To(Equal("SECRET"))
|
||||
Expect(resp.Entities[0].Source).To(Equal("pattern"))
|
||||
Expect(resp.Entities[0].Action).To(Equal("block"))
|
||||
})
|
||||
|
||||
It("redact masks a mask-class match and returns redacted text (200)", func() {
|
||||
rec := post("/api/pii/redact", `{"text":"here is tok-xyz789 done","detectors":["secret-filter"]}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
|
||||
var resp struct {
|
||||
RedactedText string `json:"redacted_text"`
|
||||
Masked bool `json:"masked"`
|
||||
Blocked bool `json:"blocked"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp.Masked).To(BeTrue())
|
||||
Expect(resp.Blocked).To(BeFalse())
|
||||
Expect(resp.RedactedText).To(ContainSubstring("[REDACTED:pattern:TOKEN]"))
|
||||
Expect(resp.RedactedText).ToNot(ContainSubstring("tok-xyz789"))
|
||||
})
|
||||
|
||||
It("redact returns 400 pii_blocked for a block-class match", func() {
|
||||
rec := post("/api/pii/redact", `{"text":"key sk-test-abc123","detectors":["secret-filter"]}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("pii_blocked"))
|
||||
// The raw secret must never appear in the block response.
|
||||
Expect(rec.Body.String()).ToNot(ContainSubstring("sk-test-abc123"))
|
||||
})
|
||||
|
||||
It("400s when no detector is selected", func() {
|
||||
rec := post("/api/pii/redact", `{"text":"sk-test-abc123"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("invalid_request"))
|
||||
})
|
||||
|
||||
It("resolves detectors from a consuming model via the model field", func() {
|
||||
rec := post("/api/pii/analyze", `{"text":"tok-aaa111","model":"chatmodel"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
var resp struct {
|
||||
Entities []struct {
|
||||
EntityType string `json:"entity_type"`
|
||||
} `json:"entities"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp.Entities).To(HaveLen(1))
|
||||
Expect(resp.Entities[0].EntityType).To(Equal("TOKEN"))
|
||||
})
|
||||
|
||||
It("400s for a PII-enabled model with no detectors and no instance default", func() {
|
||||
rec := post("/api/pii/analyze", `{"text":"tok-aaa111","model":"defaultsmodel"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("invalid_request"))
|
||||
})
|
||||
|
||||
It("falls back to the instance-wide default detectors for an enabled model", func() {
|
||||
defaults := []string{"secret-filter"}
|
||||
app.ApplicationConfig().ApplyRuntimeSettings(&config.RuntimeSettings{PIIDefaultDetectors: &defaults})
|
||||
|
||||
rec := post("/api/pii/analyze", `{"text":"tok-aaa111","model":"defaultsmodel"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusOK))
|
||||
var resp struct {
|
||||
Entities []struct {
|
||||
EntityType string `json:"entity_type"`
|
||||
} `json:"entities"`
|
||||
}
|
||||
Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed())
|
||||
Expect(resp.Entities).To(HaveLen(1))
|
||||
Expect(resp.Entities[0].EntityType).To(Equal("TOKEN"))
|
||||
})
|
||||
|
||||
It("400s for a model that lists detectors but has PII disabled, like the middleware", func() {
|
||||
rec := post("/api/pii/analyze", `{"text":"tok-aaa111","model":"disabledmodel"}`)
|
||||
Expect(rec.Code).To(Equal(http.StatusBadRequest))
|
||||
Expect(rec.Body.String()).To(ContainSubstring("invalid_request"))
|
||||
})
|
||||
|
||||
It("records redact-API events with origin pii_redact", func() {
|
||||
_ = post("/api/pii/redact", `{"text":"here is tok-xyz789 done","detectors":["secret-filter"]}`)
|
||||
events, err := app.PIIEvents().List(context.Background(), pii.ListQuery{Origin: pii.OriginRedactAPI})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(len(events)).To(BeNumerically(">=", 1))
|
||||
Expect(events[0].PatternID).To(Equal("pattern:TOKEN"))
|
||||
// Regression: API-recorded events must carry a real timestamp, not the
|
||||
// zero value (the handler, unlike the middleware, originally omitted it).
|
||||
Expect(events[0].CreatedAt.IsZero()).To(BeFalse())
|
||||
})
|
||||
})
|
||||
@@ -22,25 +22,31 @@ type stubClient struct{}
|
||||
func (stubClient) GallerySearch(_ context.Context, _ localaitools.GallerySearchQuery) ([]gallery.Metadata, error) {
|
||||
return []gallery.Metadata{{Name: "stub", Gallery: config.Gallery{Name: "stub-gallery"}}}, nil
|
||||
}
|
||||
|
||||
func (stubClient) ListInstalledModels(_ context.Context, _ localaitools.Capability) ([]localaitools.InstalledModel, error) {
|
||||
return []localaitools.InstalledModel{{Name: "stub"}}, nil
|
||||
}
|
||||
|
||||
func (stubClient) ListGalleries(_ context.Context) ([]config.Gallery, error) {
|
||||
return []config.Gallery{{Name: "stub-gallery", URL: "http://example"}}, nil
|
||||
}
|
||||
|
||||
func (stubClient) GetJobStatus(_ context.Context, _ string) (*localaitools.JobStatus, error) {
|
||||
return &localaitools.JobStatus{ID: "stub", Processed: true}, nil
|
||||
}
|
||||
|
||||
func (stubClient) GetModelConfig(_ context.Context, _ string) (*localaitools.ModelConfigView, error) {
|
||||
return &localaitools.ModelConfigView{Name: "stub"}, nil
|
||||
}
|
||||
|
||||
func (stubClient) InstallModel(_ context.Context, _ localaitools.InstallModelRequest) (string, error) {
|
||||
return "stub-job", nil
|
||||
}
|
||||
|
||||
func (stubClient) ImportModelURI(_ context.Context, _ localaitools.ImportModelURIRequest) (*localaitools.ImportModelURIResponse, error) {
|
||||
return &localaitools.ImportModelURIResponse{JobID: "stub-import"}, nil
|
||||
}
|
||||
func (stubClient) DeleteModel(_ context.Context, _ string) error { return nil }
|
||||
func (stubClient) DeleteModel(_ context.Context, _ string) error { return nil }
|
||||
func (stubClient) EditModelConfig(_ context.Context, _ string, _ map[string]any) error {
|
||||
return nil
|
||||
}
|
||||
@@ -48,57 +54,61 @@ func (stubClient) ReloadModels(_ context.Context) error { return nil }
|
||||
func (stubClient) ListBackends(_ context.Context) ([]localaitools.Backend, error) {
|
||||
return []localaitools.Backend{{Name: "stub-backend", Installed: true}}, nil
|
||||
}
|
||||
|
||||
func (stubClient) ListKnownBackends(_ context.Context) ([]schema.KnownBackend, error) {
|
||||
return []schema.KnownBackend{}, nil
|
||||
}
|
||||
|
||||
func (stubClient) InstallBackend(_ context.Context, _ localaitools.InstallBackendRequest) (string, error) {
|
||||
return "stub-backend-job", nil
|
||||
}
|
||||
|
||||
func (stubClient) UpgradeBackend(_ context.Context, _ string) (string, error) {
|
||||
return "stub-upgrade-job", nil
|
||||
}
|
||||
|
||||
func (stubClient) SystemInfo(_ context.Context) (*localaitools.SystemInfo, error) {
|
||||
return &localaitools.SystemInfo{Version: "stub"}, nil
|
||||
}
|
||||
|
||||
func (stubClient) ListNodes(_ context.Context) ([]localaitools.Node, error) {
|
||||
return []localaitools.Node{}, nil
|
||||
}
|
||||
|
||||
func (stubClient) VRAMEstimate(_ context.Context, _ localaitools.VRAMEstimateRequest) (*vram.EstimateResult, error) {
|
||||
return &vram.EstimateResult{SizeDisplay: "stub"}, nil
|
||||
}
|
||||
func (stubClient) ToggleModelState(_ context.Context, _ string, _ modeladmin.Action) error { return nil }
|
||||
func (stubClient) ToggleModelPinned(_ context.Context, _ string, _ modeladmin.Action) error { return nil }
|
||||
func (stubClient) ToggleModelState(_ context.Context, _ string, _ modeladmin.Action) error {
|
||||
return nil
|
||||
}
|
||||
func (stubClient) ToggleModelPinned(_ context.Context, _ string, _ modeladmin.Action) error {
|
||||
return nil
|
||||
}
|
||||
func (stubClient) GetBranding(_ context.Context) (*localaitools.Branding, error) {
|
||||
return &localaitools.Branding{InstanceName: "LocalAI"}, nil
|
||||
}
|
||||
|
||||
func (stubClient) SetBranding(_ context.Context, _ localaitools.SetBrandingRequest) (*localaitools.Branding, error) {
|
||||
return &localaitools.Branding{InstanceName: "LocalAI"}, nil
|
||||
}
|
||||
|
||||
func (stubClient) GetUsageStats(_ context.Context, _ localaitools.UsageStatsQuery) (*localaitools.UsageStats, error) {
|
||||
return &localaitools.UsageStats{Viewer: localaitools.UsageViewer{ID: "stub", Name: "stub"}, Period: "month"}, nil
|
||||
}
|
||||
func (stubClient) ListPIIPatterns(_ context.Context) ([]localaitools.PIIPattern, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (stubClient) GetPIIEvents(_ context.Context, _ localaitools.PIIEventsQuery) ([]localaitools.PIIEvent, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (stubClient) TestPIIRedaction(_ context.Context, req localaitools.PIIRedactTestRequest) (*localaitools.PIIRedactTestResult, error) {
|
||||
return &localaitools.PIIRedactTestResult{Redacted: req.Text}, nil
|
||||
}
|
||||
func (stubClient) SetPIIPatternAction(_ context.Context, _ localaitools.PIIPatternActionUpdate) error {
|
||||
return nil
|
||||
}
|
||||
func (stubClient) PersistPIIPatterns(_ context.Context) error { return nil }
|
||||
|
||||
func (stubClient) GetMiddlewareStatus(_ context.Context) (*localaitools.MiddlewareStatus, error) {
|
||||
return &localaitools.MiddlewareStatus{
|
||||
PII: localaitools.MiddlewarePIIStatus{
|
||||
EnabledGlobally: true,
|
||||
Patterns: []localaitools.PIIPattern{},
|
||||
Models: []localaitools.MiddlewarePIIModel{},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (stubClient) GetRouterDecisions(_ context.Context, _ localaitools.RouterDecisionsQuery) ([]localaitools.RouterDecision, error) {
|
||||
return []localaitools.RouterDecision{}, nil
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
reason "github.com/mudler/LocalAI/pkg/reasoning"
|
||||
|
||||
@@ -130,7 +129,7 @@ func applyAutoparserOverride(
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/chat/completions [post]
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc {
|
||||
func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig, natsClient mcpTools.MCPNATSClient, assistantHolder *mcpTools.LocalAIAssistantHolder) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
var textContentToReturn string
|
||||
id := uuid.New().String()
|
||||
@@ -152,11 +151,11 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
|
||||
// Cloud-proxy bail. Bypasses the local pipeline (templating,
|
||||
// MCP injection, gRPC backend) and forwards via the cloud-
|
||||
// proxy backend, which does the outbound HTTP. The streaming
|
||||
// PII filter still runs because its input is per-token text
|
||||
// extracted from the wire envelope, not the envelope itself.
|
||||
// proxy backend, which does the outbound HTTP. Request-side PII
|
||||
// redaction already ran in the middleware; the response is
|
||||
// forwarded unmodified.
|
||||
if config.IsCloudProxyBackendPassthrough() {
|
||||
return forwardCloudProxyOpenAIViaBackend(c, config, input, piiRedactor, piiEvents, ml, startupOptions)
|
||||
return forwardCloudProxyOpenAIViaBackend(c, config, input, ml, startupOptions)
|
||||
}
|
||||
|
||||
funcs := input.Functions
|
||||
@@ -327,7 +326,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
"message": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The message to reply the user with",
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -393,14 +393,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().Header().Set("X-Correlation-ID", id)
|
||||
|
||||
// Per-stream PII filter: when the resolved model has PII
|
||||
// enabled, wrap the response content so values spanning
|
||||
// chunk boundaries still get masked. Shared with the
|
||||
// cloud-proxy bail below via cloudproxy.BuildStreamFilter
|
||||
// so both paths apply the same per-model gate and override
|
||||
// rules.
|
||||
streamPIIFilter := cloudproxy.BuildStreamFilter(c, config, true, piiRedactor, piiEvents, id)
|
||||
|
||||
mcpStreamMaxIterations := 10
|
||||
if config.Agent.MaxIterations > 0 {
|
||||
mcpStreamMaxIterations = config.Agent.MaxIterations
|
||||
@@ -476,30 +468,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
if (hasMCPToolsStream || config.FunctionsConfig.AutomaticToolParsingFallback) && haveContent {
|
||||
collectedContent += rawContent
|
||||
}
|
||||
// Stream-side PII filter: feed the content delta
|
||||
// through the buffered-emit filter. The filter
|
||||
// holds back a tail to handle pattern boundaries
|
||||
// across chunks, so a Push may legitimately
|
||||
// return "" — drop the chunk in that case rather
|
||||
// than emitting an empty Delta to the wire.
|
||||
if streamPIIFilter != nil && haveContent {
|
||||
filtered := streamPIIFilter.Push(rawContent)
|
||||
if filtered == "" {
|
||||
// Fully buffered — skip this chunk's
|
||||
// content. Still emit non-content chunks
|
||||
// (role, tool_calls). When this delta is
|
||||
// content-only and we buffer it, drop the
|
||||
// whole event to avoid a vestigial
|
||||
// {"delta":{}} on the wire.
|
||||
if ev.Choices[0].Delta.Role == "" && len(ev.Choices[0].Delta.ToolCalls) == 0 && ev.Choices[0].Delta.Reasoning == nil {
|
||||
continue
|
||||
}
|
||||
// Mixed delta — strip content, keep the rest.
|
||||
ev.Choices[0].Delta.Content = nil
|
||||
} else {
|
||||
ev.Choices[0].Delta.Content = filtered
|
||||
}
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
xlog.Debug("Failed to marshal response", "error", err)
|
||||
@@ -644,31 +612,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
}
|
||||
|
||||
// Drain the per-stream PII filter before the stop chunk
|
||||
// so any text held back by the buffered-emit invariant
|
||||
// reaches the client as a regular content delta. We
|
||||
// emit it as a chunk WITHOUT a finish_reason so the
|
||||
// next "stop" chunk still terminates the stream.
|
||||
if streamPIIFilter != nil {
|
||||
residual := streamPIIFilter.Drain()
|
||||
if residual != "" {
|
||||
drainResp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model,
|
||||
Choices: []schema.Choice{{
|
||||
Delta: &schema.Message{Content: residual},
|
||||
Index: 0,
|
||||
}},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
if drainBytes, err := json.Marshal(drainResp); err == nil {
|
||||
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", drainBytes)
|
||||
c.Response().Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No MCP tools to execute, send final stop message
|
||||
finishReason := FinishReasonStop
|
||||
if toolsCalled && len(input.Tools) > 0 {
|
||||
@@ -689,7 +632,8 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
FinishReason: &finishReason,
|
||||
Index: 0,
|
||||
Delta: &schema.Message{},
|
||||
}},
|
||||
},
|
||||
},
|
||||
Object: "chat.completion.chunk",
|
||||
}
|
||||
respData, _ := json.Marshal(resp)
|
||||
@@ -1075,7 +1019,6 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator
|
||||
}
|
||||
|
||||
func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) {
|
||||
|
||||
if len(funcResults) == 0 && result != "" {
|
||||
xlog.Debug("nothing function results but we had a message from the LLM")
|
||||
|
||||
@@ -1111,19 +1054,16 @@ func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCall
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// forwardCloudProxyOpenAIViaBackend marshals the OpenAI request,
|
||||
// constructs the streaming PII filter (when this model has PII
|
||||
// enabled), and hands off to the cloud-proxy gRPC backend which does
|
||||
// the outbound HTTP. The chat endpoint owns the body+filter
|
||||
// construction because it's the only place the request lands as a
|
||||
// parsed *schema.OpenAIRequest.
|
||||
func forwardCloudProxyOpenAIViaBackend(c echo.Context, cfg *config.ModelConfig, input *schema.OpenAIRequest, piiRedactor *pii.Redactor, piiEvents pii.EventStore, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
// forwardCloudProxyOpenAIViaBackend marshals the OpenAI request and
|
||||
// hands off to the cloud-proxy gRPC backend which does the outbound
|
||||
// HTTP. The chat endpoint owns the body construction because it's the
|
||||
// only place the request lands as a parsed *schema.OpenAIRequest.
|
||||
// Request-side PII redaction already ran in the middleware; the
|
||||
// response is forwarded unmodified.
|
||||
func forwardCloudProxyOpenAIViaBackend(c echo.Context, cfg *config.ModelConfig, input *schema.OpenAIRequest, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error {
|
||||
body, err := json.Marshal(input)
|
||||
if err != nil {
|
||||
return echo.NewHTTPError(http.StatusBadRequest, "cloudproxy: marshal request: "+err.Error())
|
||||
}
|
||||
|
||||
correlationID := c.Response().Header().Get("X-Correlation-ID")
|
||||
streamFilter := cloudproxy.BuildStreamFilter(c, cfg, input.Stream, piiRedactor, piiEvents, correlationID)
|
||||
return cloudproxy.ForwardViaBackend(c, cfg, body, streamFilter, ml, appConfig)
|
||||
return cloudproxy.ForwardViaBackend(c, cfg, body, ml, appConfig)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,10 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/templates"
|
||||
"github.com/mudler/LocalAI/pkg/functions"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
@@ -27,7 +25,7 @@ import (
|
||||
// @Param request body schema.OpenAIRequest true "query params"
|
||||
// @Success 200 {object} schema.OpenAIResponse "Response"
|
||||
// @Router /v1/completions [post]
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig, piiRedactor *pii.Redactor, piiEvents pii.EventStore) echo.HandlerFunc {
|
||||
func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
|
||||
process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error {
|
||||
tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool {
|
||||
created := int(time.Now().Unix())
|
||||
@@ -70,7 +68,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
}
|
||||
|
||||
return func(c echo.Context) error {
|
||||
|
||||
created := int(time.Now().Unix())
|
||||
|
||||
// Handle Correlation
|
||||
@@ -113,31 +110,8 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
|
||||
}
|
||||
|
||||
// Per-stream PII filter — same gating as chat. /v1/completions
|
||||
// has no chat-message structure, so request-side PII isn't
|
||||
// wired here, but the response-side filter still catches PII
|
||||
// trained into the model. Filter is nil when this model has
|
||||
// PII disabled.
|
||||
var streamPIIFilter *pii.StreamFilter
|
||||
if piiRedactor != nil && config.PIIIsEnabled() {
|
||||
correlationID := id
|
||||
userID := ""
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
userID = u.ID
|
||||
}
|
||||
var overrides map[string]pii.Action
|
||||
if raw := config.PIIPatternOverrides(); len(raw) > 0 {
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionAllow:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
}
|
||||
streamPIIFilter = pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID)
|
||||
}
|
||||
|
||||
// Response/output PII redaction is out of scope for now —
|
||||
// redaction runs request-side via the NER middleware only.
|
||||
predInput := config.PromptStrings[0]
|
||||
|
||||
templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{
|
||||
@@ -179,19 +153,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
// OpenAI streaming spec: intermediate chunks must NOT
|
||||
// carry a `usage` field. Strip the tracking copy now.
|
||||
ev.Usage = nil
|
||||
// Run the per-chunk text through the streaming PII
|
||||
// filter. The filter holds back a tail to handle
|
||||
// pattern boundaries, so a Push may legitimately
|
||||
// return "" — drop the chunk's text rather than
|
||||
// emitting a 0-token delta. Choice.Text is the only
|
||||
// content surface in /v1/completions chunks.
|
||||
if streamPIIFilter != nil && ev.Choices[0].Text != "" {
|
||||
filtered := streamPIIFilter.Push(ev.Choices[0].Text)
|
||||
if filtered == "" {
|
||||
continue
|
||||
}
|
||||
ev.Choices[0].Text = filtered
|
||||
}
|
||||
respData, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
xlog.Debug("Failed to marshal response", "error", err)
|
||||
@@ -237,25 +198,6 @@ func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, eva
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any residual the streaming PII filter held back as
|
||||
// part of its trailing pattern-window. Emit it as one final
|
||||
// text-bearing chunk before the synthetic stop chunk so the
|
||||
// completion body remains a contiguous text stream.
|
||||
if streamPIIFilter != nil {
|
||||
if residual := streamPIIFilter.Drain(); residual != "" {
|
||||
residualResp := schema.OpenAIResponse{
|
||||
ID: id,
|
||||
Created: created,
|
||||
Model: input.Model,
|
||||
Choices: []schema.Choice{{Index: 0, Text: residual}},
|
||||
Object: "text_completion",
|
||||
}
|
||||
if data, err := json.Marshal(residualResp); err == nil {
|
||||
_, _ = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stopReason := FinishReasonStop
|
||||
resp := &schema.OpenAIResponse{
|
||||
ID: id,
|
||||
|
||||
@@ -391,18 +391,12 @@ func buildClassifier(cfg *config.ModelConfig, deps ClassifierDeps) (router.Class
|
||||
}
|
||||
|
||||
// assertClassifierDeclaresScore refuses to build the score classifier
|
||||
// unless classifier_model's config declares FLAG_SCORE. The actual
|
||||
// usecase-conflict check (score + chat/completion/embeddings on
|
||||
// llama-cpp) lives in ModelConfig.Validate() and fires at config load
|
||||
// and save time — by the time we get here, any model that reached the
|
||||
// loader is already conflict-free. This check just refuses to bind a
|
||||
// model that never declared itself for Score in the first place; that
|
||||
// model could be a misconfigured chat model the operator pointed at
|
||||
// by accident, and without FLAG_SCORE the validator never saw it.
|
||||
// unless classifier_model's config declares FLAG_SCORE. This check only
|
||||
// refuses to bind a model that never declared itself for Score in the
|
||||
// first place; that model could be a misconfigured chat model the
|
||||
// operator pointed at by accident.
|
||||
//
|
||||
// When lookup is nil (test wiring) the check is skipped and we fall
|
||||
// back to the C++ backend's runtime tripwire as the last line of
|
||||
// defence.
|
||||
// When lookup is nil (test wiring) the check is skipped.
|
||||
func assertClassifierDeclaresScore(classifierModel string, lookup ModelConfigLookup) error {
|
||||
if lookup == nil {
|
||||
return nil
|
||||
@@ -416,8 +410,8 @@ func assertClassifierDeclaresScore(classifierModel string, lookup ModelConfigLoo
|
||||
if !cfg.HasUsecases(config.FLAG_SCORE) {
|
||||
return fmt.Errorf(
|
||||
"router classifier score: classifier_model %q does not declare the "+
|
||||
"score usecase. Add `known_usecases: [score]` to its config so "+
|
||||
"the loader can reject conflicting usecase combinations",
|
||||
"score usecase. Add `known_usecases: [score]` (alongside any other "+
|
||||
"usecases the model serves) to its config",
|
||||
classifierModel)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -1,24 +1,29 @@
|
||||
import { test, expect } from '@playwright/test'
|
||||
|
||||
// Mocked fixture covering the three things the page renders:
|
||||
// - PII pattern catalogue (action badges, action-change buttons)
|
||||
// - Per-model resolved PII state (one with default off, one with proxy default on, one with explicit YAML)
|
||||
// Mocked fixture covering the things the page renders:
|
||||
// - Per-model resolved PII state + the NER detectors each references
|
||||
// (one with default off, one with proxy default on, one explicit YAML)
|
||||
// - Recent events feed (the page must NEVER show the redacted content)
|
||||
const MOCK_STATUS = {
|
||||
pii: {
|
||||
enabled_globally: true,
|
||||
default_enabled_for_backends: ['cloud-proxy'],
|
||||
patterns: [
|
||||
{ id: 'email', description: 'Email addresses', action: 'mask', max_match_length: 254 },
|
||||
{ id: 'ssn', description: 'US Social Security Numbers', action: 'mask', max_match_length: 11 },
|
||||
{ id: 'api_key_prefix', description: 'API key prefixes', action: 'block', max_match_length: 200 },
|
||||
],
|
||||
models: [
|
||||
{ name: 'qwen-7b', backend: 'llama-cpp', enabled: false, explicit: false, default_for_backend: false, overrides: null },
|
||||
{ name: 'claude-sonnet', backend: 'cloud-proxy', enabled: true, explicit: false, default_for_backend: true, overrides: null },
|
||||
{ name: 'claude-strict', backend: 'cloud-proxy', enabled: true, explicit: true, default_for_backend: true, overrides: { ssn: 'block' } },
|
||||
{ name: 'qwen-7b', backend: 'llama-cpp', enabled: false, explicit: false, default_for_backend: false, detectors: null },
|
||||
{ name: 'claude-sonnet', backend: 'cloud-proxy', enabled: true, explicit: false, default_for_backend: true, detectors: null },
|
||||
{ name: 'claude-strict', backend: 'cloud-proxy', enabled: true, explicit: true, default_for_backend: true, detectors: ['privacy-filter-multilingual'] },
|
||||
],
|
||||
recent_event_count: 2,
|
||||
// Instance-wide default detector set (managed by the Detector models
|
||||
// table's per-row Default toggle).
|
||||
default_detectors: ['global-ner-default'],
|
||||
// The token_classify "filter" models themselves: one NER, one in-process
|
||||
// pattern matcher, plus an orphan default that names a model not loaded.
|
||||
detector_models: [
|
||||
{ name: 'privacy-filter-multilingual', backend: 'llama-cpp', type: 'ner', default: false },
|
||||
{ name: 'secret-filter', backend: 'pattern', type: 'pattern', default: false },
|
||||
{ name: 'global-ner-default', backend: '', type: 'unknown', default: true, missing: true },
|
||||
],
|
||||
},
|
||||
router: {
|
||||
configured: true,
|
||||
@@ -114,23 +119,104 @@ test.describe('Middleware page — admin in no-auth mode', () => {
|
||||
await page.route('**/api/router/decisions?**', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify(MOCK_DECISIONS) })
|
||||
)
|
||||
// The Default PII policy detector picker is capability-filtered to
|
||||
// token_classify via /api/models/capabilities.
|
||||
await page.route('**/api/models/capabilities', (route) =>
|
||||
route.fulfill({
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify({ models: [{ id: 'privacy-filter-multilingual', capabilities: ['FLAG_TOKEN_CLASSIFY'], backend: 'llama-cpp' }] }),
|
||||
})
|
||||
)
|
||||
await page.route('**/api/settings', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify({ success: true }) })
|
||||
)
|
||||
// The per-model PII toggle PATCHes the model config (pii.enabled).
|
||||
await page.route('**/api/models/config-json/**', (route) =>
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify({ success: true }) })
|
||||
)
|
||||
})
|
||||
|
||||
test('Filtering tab renders pattern catalogue and per-model state', async ({ page }) => {
|
||||
test('Filtering tab renders per-model state and referenced detectors', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
|
||||
// Pattern table — at least one pattern id visible.
|
||||
await expect(page.getByText('email').first()).toBeVisible()
|
||||
await expect(page.getByText('api_key_prefix').first()).toBeVisible()
|
||||
|
||||
// Per-model state — each model's name is visible.
|
||||
await expect(page.getByText('qwen-7b').first()).toBeVisible()
|
||||
await expect(page.getByText('claude-strict').first()).toBeVisible()
|
||||
|
||||
// The detector a model references is shown in its row.
|
||||
await expect(page.getByText('privacy-filter-multilingual').first()).toBeVisible()
|
||||
|
||||
// Default-policy banner names the backends with PII on by default.
|
||||
await expect(page.getByText(/cloud-proxy/).first()).toBeVisible()
|
||||
})
|
||||
|
||||
test('Filtering tab lists detector models with type badges and a default toggle', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
|
||||
// The Detector models card renders every token_classify filter model.
|
||||
await expect(page.getByText('Detector models')).toBeVisible()
|
||||
const nerRow = page.locator('tr').filter({ hasText: 'privacy-filter-multilingual' }).first()
|
||||
await expect(nerRow).toContainText(/NER/i)
|
||||
const patternRow = page.locator('tr').filter({ hasText: 'secret-filter' }).first()
|
||||
await expect(patternRow).toContainText(/pattern/i)
|
||||
|
||||
// The NER detector is not (yet) a default — its toggle is unchecked.
|
||||
// (The underlying checkbox is 0×0 by design, so we click the label wrapper.)
|
||||
const nerToggle = nerRow.locator('label.toggle')
|
||||
await expect(nerToggle.locator('input[type="checkbox"]')).not.toBeChecked()
|
||||
|
||||
// Toggling it on persists the new default set via POST /api/settings.
|
||||
const saved = page.waitForRequest(req =>
|
||||
req.url().includes('/api/settings') && req.method() === 'POST')
|
||||
await nerToggle.click()
|
||||
const req = await saved
|
||||
const body = JSON.parse(req.postData() || '{}')
|
||||
expect(body.pii_default_detectors).toContain('privacy-filter-multilingual')
|
||||
})
|
||||
|
||||
test('Filtering tab surfaces an orphan default detector that is not loaded', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
|
||||
// global-ner-default names a model that is not loaded, but it is in the
|
||||
// default set — it must still appear (toggled on) so admins can remove it.
|
||||
const orphanRow = page.locator('tr').filter({ hasText: 'global-ner-default' }).first()
|
||||
await expect(orphanRow).toContainText(/not loaded/i)
|
||||
await expect(orphanRow.locator('label.toggle input[type="checkbox"]')).toBeChecked()
|
||||
})
|
||||
|
||||
test('Filtering tab flags an enabled model with no detector as a no-op', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
|
||||
// claude-sonnet is enabled by the cloud-proxy backend default but lists
|
||||
// no detectors and there is no instance default detector — it scans
|
||||
// nothing, so the row must warn rather than read as protected.
|
||||
const noopRow = page.locator('tr').filter({ hasText: 'claude-sonnet' }).first()
|
||||
await expect(noopRow).toContainText(/no-op/i)
|
||||
|
||||
// claude-strict has an explicit detector — it must NOT be flagged.
|
||||
const okRow = page.locator('tr').filter({ hasText: 'claude-strict' }).first()
|
||||
await expect(okRow).not.toContainText(/no-op/i)
|
||||
})
|
||||
|
||||
test('Filtering tab PII column toggles a model\'s pii.enabled via PATCH', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
|
||||
// qwen-7b is OFF (enabled:false) — its PII toggle reads unchecked.
|
||||
const row = page.locator('tr').filter({ hasText: 'qwen-7b' }).first()
|
||||
const toggle = row.locator('label.toggle')
|
||||
await expect(toggle.locator('input[type="checkbox"]')).not.toBeChecked()
|
||||
|
||||
// Toggling on PATCHes the model config with an explicit pii.enabled:true,
|
||||
// scoped to that model (no other field is sent — the server deep-merges).
|
||||
const patched = page.waitForRequest(req =>
|
||||
req.url().includes('/api/models/config-json/') && req.method() === 'PATCH')
|
||||
await toggle.click()
|
||||
const req = await patched
|
||||
expect(decodeURIComponent(req.url())).toContain('qwen-7b')
|
||||
const body = JSON.parse(req.postData() || '{}')
|
||||
expect(body.pii.enabled).toBe(true)
|
||||
})
|
||||
|
||||
test('Routing tab renders configured routers and recent decisions', async ({ page }) => {
|
||||
await page.goto('/app/middleware')
|
||||
await page.getByRole('button', { name: /Routing/i }).click()
|
||||
@@ -265,25 +351,6 @@ test.describe('Middleware page — admin in no-auth mode', () => {
|
||||
await expect(page.getByText(/^proxy traffic$/i).first()).toBeVisible()
|
||||
})
|
||||
|
||||
test('PUT /api/pii/patterns/:id fires when an action button is clicked', async ({ page }) => {
|
||||
let putHit = null
|
||||
await page.route('**/api/pii/patterns/email', (route) => {
|
||||
if (route.request().method() === 'PUT') {
|
||||
putHit = JSON.parse(route.request().postData() || '{}')
|
||||
route.fulfill({ contentType: 'application/json', body: JSON.stringify({ id: 'email', action: putHit.action, persisted: false }) })
|
||||
} else {
|
||||
route.continue()
|
||||
}
|
||||
})
|
||||
|
||||
await page.goto('/app/middleware')
|
||||
// Click the email row's "block" button (currently mask, so block is
|
||||
// enabled). Use a precise locator that matches the inner button.
|
||||
const emailRow = page.locator('tr').filter({ hasText: 'email' }).first()
|
||||
await emailRow.getByRole('button', { name: 'block' }).click()
|
||||
|
||||
await expect.poll(() => putHit).toEqual({ action: 'block' })
|
||||
})
|
||||
})
|
||||
|
||||
test.describe('Middleware page — non-admin under auth-on', () => {
|
||||
|
||||
@@ -12,6 +12,9 @@ const MOCK_METADATA = {
|
||||
{ path: 'cuda', yaml_key: 'cuda', go_type: 'bool', ui_type: 'bool', section: 'general', label: 'CUDA', description: 'Enable CUDA GPU acceleration', component: 'toggle', order: 30 },
|
||||
{ path: 'parameters.temperature', yaml_key: 'temperature', go_type: '*float64', ui_type: 'float', section: 'parameters', label: 'Temperature', description: 'Sampling temperature', component: 'slider', min: 0, max: 2, step: 0.1, order: 0 },
|
||||
{ path: 'parameters.top_p', yaml_key: 'top_p', go_type: '*float64', ui_type: 'float', section: 'parameters', label: 'Top P', description: 'Nucleus sampling threshold', component: 'slider', min: 0, max: 1, step: 0.05, order: 10 },
|
||||
{ path: 'pii_detection.builtins', yaml_key: 'builtins', go_type: '[]string', ui_type: '[]string', section: 'general', label: 'Built-in Secret Patterns', description: 'Built-in credential patterns', component: 'pii-builtins-select', options: [{ value: 'anthropic_api_key', label: 'anthropic_api_key — Anthropic API key' }, { value: 'github_token', label: 'github_token — GitHub token' }], order: 213 },
|
||||
{ path: 'pii_detection.patterns', yaml_key: 'patterns', go_type: '[]config.PIIPattern', ui_type: 'object', section: 'general', label: 'Custom Secret Patterns', description: 'Operator-defined restricted-regex patterns', component: 'pii-pattern-list', order: 214 },
|
||||
{ path: 'pii_detection.entity_actions', yaml_key: 'entity_actions', go_type: 'map[string]string', ui_type: 'map', section: 'general', label: 'Detector Entity Actions', description: 'Per-entity-group action policy', component: 'entity-action-list', order: 212 },
|
||||
],
|
||||
}
|
||||
|
||||
@@ -258,4 +261,72 @@ test.describe('Model Editor - Interactive Tab', () => {
|
||||
await expect(page.locator('nav').first()).toBeVisible()
|
||||
})
|
||||
|
||||
test('built-in secret patterns render as a checklist from field options', async ({ page }) => {
|
||||
const searchInput = page.locator('input[placeholder="Search fields to add..."]')
|
||||
await searchInput.fill('Built-in Secret Patterns')
|
||||
const dropdown = searchInput.locator('..').locator('..')
|
||||
await dropdown.locator('div', { hasText: 'Built-in Secret Patterns' }).first().click()
|
||||
|
||||
// One checkbox per catalogue option; toggling one enables Save.
|
||||
const anthropic = page.locator('label', { hasText: 'Anthropic API key' }).locator('input[type="checkbox"]')
|
||||
await expect(anthropic).toHaveCount(1)
|
||||
await anthropic.check()
|
||||
await expect(anthropic).toBeChecked()
|
||||
})
|
||||
|
||||
test('custom secret patterns render the pattern-list editor', async ({ page }) => {
|
||||
const searchInput = page.locator('input[placeholder="Search fields to add..."]')
|
||||
await searchInput.fill('Custom Secret Patterns')
|
||||
const dropdown = searchInput.locator('..').locator('..')
|
||||
await dropdown.locator('div', { hasText: 'Custom Secret Patterns' }).first().click()
|
||||
|
||||
// Empty state + an Add button; adding a row shows the name + match inputs.
|
||||
const addBtn = page.locator('button', { hasText: 'Add pattern' })
|
||||
await expect(addBtn).toBeVisible()
|
||||
await addBtn.click()
|
||||
await expect(page.locator('input[placeholder^="Name (group)"]')).toBeVisible()
|
||||
await expect(page.locator('input[placeholder^="match,"]')).toBeVisible()
|
||||
})
|
||||
|
||||
// Regression: a map-typed field (entity_actions) present in the loaded YAML
|
||||
// must render WITH its values. flattenConfig used to recurse into the map,
|
||||
// scattering it across pii_detection.entity_actions.<GROUP> paths that match
|
||||
// no registered field, so the editor showed neither the field nor the
|
||||
// per-entity policy (e.g. SSN -> block) the operator had configured.
|
||||
test('entity_actions map field present in YAML renders with its values', async ({ page }) => {
|
||||
// Override the edit endpoint for this test: YAML that carries a populated
|
||||
// entity_actions map alongside a scalar sibling (default_action).
|
||||
await page.route('**/api/models/edit/ner-model', (route) => {
|
||||
route.fulfill({
|
||||
contentType: 'application/json',
|
||||
body: JSON.stringify({
|
||||
name: 'ner-model',
|
||||
config: [
|
||||
'name: ner-model',
|
||||
'backend: llama-cpp',
|
||||
'pii_detection:',
|
||||
' default_action: mask',
|
||||
' entity_actions:',
|
||||
' SSN: block',
|
||||
' EMAIL: mask',
|
||||
'',
|
||||
].join('\n'),
|
||||
}),
|
||||
})
|
||||
})
|
||||
|
||||
await page.goto('/app/model-editor/ner-model')
|
||||
|
||||
// The entity-action-list editor is rendered (field label visible)…
|
||||
await expect(page.getByText('Detector Entity Actions').first()).toBeVisible()
|
||||
// …and bound to the existing map: one row per configured group, in order.
|
||||
const groupInputs = page.locator('input[aria-label="Entity group"]')
|
||||
await expect(groupInputs).toHaveCount(2)
|
||||
await expect(groupInputs.nth(0)).toHaveValue('SSN')
|
||||
await expect(groupInputs.nth(1)).toHaveValue('EMAIL')
|
||||
// The action select shows the bound action label (block), proving the map
|
||||
// values bound, not just an empty editor.
|
||||
await expect(page.getByText(/block —/i).first()).toBeVisible()
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
@@ -178,7 +178,7 @@ test.describe("Models Gallery - Backend Features", () => {
|
||||
});
|
||||
|
||||
const BACKEND_USECASES_MOCK = {
|
||||
"llama-cpp": ["chat", "embeddings", "vision"],
|
||||
"llama-cpp": ["chat", "embeddings", "vision", "token_classify"],
|
||||
whisper: ["transcript"],
|
||||
stablediffusion: ["image"],
|
||||
};
|
||||
@@ -285,13 +285,15 @@ test.describe("Models Gallery - Multi-select Filters", () => {
|
||||
await expect(sttBtn).toBeDisabled();
|
||||
await expect(imageBtn).toBeDisabled();
|
||||
|
||||
// Chat, Embeddings, Vision should remain enabled
|
||||
// Chat, Embeddings, Vision, NER should remain enabled
|
||||
const chatBtn = page.locator(".filter-btn", { hasText: "Chat" });
|
||||
const embBtn = page.locator(".filter-btn", { hasText: "Embeddings" });
|
||||
const visBtn = page.locator(".filter-btn", { hasText: "Vision" });
|
||||
const nerBtn = page.locator(".filter-btn", { hasText: "NER" });
|
||||
await expect(chatBtn).toBeEnabled();
|
||||
await expect(embBtn).toBeEnabled();
|
||||
await expect(visBtn).toBeEnabled();
|
||||
await expect(nerBtn).toBeEnabled();
|
||||
});
|
||||
|
||||
test("backend clears incompatible filters", async ({ page }) => {
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
"rerank": "Rerank",
|
||||
"detection": "Detection",
|
||||
"vad": "VAD",
|
||||
"ner": "NER",
|
||||
"fitsGpu": "Fits in GPU",
|
||||
"allBackends": "All Backends",
|
||||
"searchBackends": "Search backends..."
|
||||
|
||||
@@ -6,7 +6,9 @@ import SearchableModelSelect from './SearchableModelSelect'
|
||||
import AutocompleteInput from './AutocompleteInput'
|
||||
import CodeEditor from './CodeEditor'
|
||||
import StructuredCodeEditor from './StructuredCodeEditor'
|
||||
import PIIPatternListEditor from './PIIPatternListEditor'
|
||||
import EntityActionListEditor from './EntityActionListEditor'
|
||||
import PatternListEditor from './PatternListEditor'
|
||||
import ModelMultiSelect from './ModelMultiSelect'
|
||||
import RouterCandidatesEditor from './RouterCandidatesEditor'
|
||||
import RouterPoliciesEditor from './RouterPoliciesEditor'
|
||||
|
||||
@@ -17,6 +19,7 @@ const PROVIDER_TO_CAPABILITY = {
|
||||
'models:transcript': 'FLAG_TRANSCRIPT',
|
||||
'models:vad': 'FLAG_VAD',
|
||||
'models:score': 'FLAG_SCORE',
|
||||
'models:token_classify': 'FLAG_TOKEN_CLASSIFY',
|
||||
}
|
||||
|
||||
function coerceValue(raw, uiType) {
|
||||
@@ -395,10 +398,10 @@ export default function ConfigFieldRenderer({ field, value, onChange, onRemove,
|
||||
)
|
||||
}
|
||||
|
||||
// PII pattern list — per-model action overrides for named patterns.
|
||||
// The pattern catalog is loaded from /api/pii/patterns at render time
|
||||
// so new built-in patterns surface automatically.
|
||||
if (component === 'pii-pattern-list') {
|
||||
// PII detectors — a capability-filtered multi-select of token_classify
|
||||
// models (the consuming model's pii.detectors list).
|
||||
if (component === 'model-multi-select') {
|
||||
const cap = PROVIDER_TO_CAPABILITY[field.autocomplete_provider] || undefined
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-sm) 0', borderBottom: '1px solid var(--color-border-subtle)' }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 4 }}>
|
||||
@@ -407,7 +410,62 @@ export default function ConfigFieldRenderer({ field, value, onChange, onRemove,
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
<PIIPatternListEditor value={value} onChange={handleChange} />
|
||||
<ModelMultiSelect value={value} onChange={handleChange} capability={cap} placeholder={field.placeholder} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// PII detection entity-action map — a detector model's
|
||||
// pii_detection.entity_actions (entity group -> mask|block|allow).
|
||||
if (component === 'entity-action-list') {
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-sm) 0', borderBottom: '1px solid var(--color-border-subtle)' }}>
|
||||
<div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 4 }}>
|
||||
<div>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 500 }}><FieldLabel field={field} /></div>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{description}</div>
|
||||
</div>
|
||||
</div>
|
||||
<EntityActionListEditor value={value} onChange={handleChange} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// PII built-in secret patterns — a checklist of named built-in patterns
|
||||
// (pii_detection.builtins). value is an array of selected names.
|
||||
if (component === 'pii-builtins-select') {
|
||||
const selected = Array.isArray(value) ? value : []
|
||||
const toggle = (name) => {
|
||||
handleChange(selected.includes(name) ? selected.filter(n => n !== name) : [...selected, name])
|
||||
}
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-sm) 0', borderBottom: '1px solid var(--color-border-subtle)' }}>
|
||||
<div style={{ marginBottom: 4 }}>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 500 }}><FieldLabel field={field} /></div>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{description}</div>
|
||||
</div>
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 4 }}>
|
||||
{(field.options || []).map(opt => (
|
||||
<label key={opt.value} style={{ display: 'flex', alignItems: 'center', gap: 8, fontSize: '0.8125rem', cursor: 'pointer' }}>
|
||||
<input type="checkbox" checked={selected.includes(opt.value)} onChange={() => toggle(opt.value)} />
|
||||
{opt.label || opt.value}
|
||||
</label>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// PII custom secret patterns — operator-defined restricted-regex rules
|
||||
// (pii_detection.patterns). value is an array of {name, match, action, min_len}.
|
||||
if (component === 'pii-pattern-list') {
|
||||
return (
|
||||
<div style={{ padding: 'var(--spacing-sm) 0', borderBottom: '1px solid var(--color-border-subtle)' }}>
|
||||
<div style={{ marginBottom: 4 }}>
|
||||
<div style={{ fontSize: '0.875rem', fontWeight: 500 }}><FieldLabel field={field} /></div>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)', marginTop: 2 }}>{description}</div>
|
||||
</div>
|
||||
<PatternListEditor value={value} onChange={handleChange} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
98
core/http/react-ui/src/components/EntityActionListEditor.jsx
Normal file
98
core/http/react-ui/src/components/EntityActionListEditor.jsx
Normal file
@@ -0,0 +1,98 @@
|
||||
import { useMemo } from 'react'
|
||||
import SearchableSelect from './SearchableSelect'
|
||||
|
||||
// Editor for a detector model's pii_detection.entity_actions map:
|
||||
// entity-group name -> action. The value is an object {GROUP: action};
|
||||
// this component renders one row per entry and emits a fresh object on
|
||||
// every change. Entity-group names are model-defined (the privacy-filter
|
||||
// family emits uppercase names with no separators), so the group field is
|
||||
// free text with a datalist of common high-value categories for
|
||||
// convenience — any string the model emits is valid.
|
||||
|
||||
const ACTION_OPTIONS = [
|
||||
{ value: 'mask', label: 'mask — replace with [REDACTED:ner:GROUP]' },
|
||||
{ value: 'block', label: 'block — reject the request (HTTP 400)' },
|
||||
{ value: 'allow', label: 'allow — detect & log, leave text unchanged' },
|
||||
]
|
||||
|
||||
// Common categories surfaced as datalist hints. Not exhaustive and not
|
||||
// authoritative — the model's own label set is the source of truth.
|
||||
const COMMON_GROUPS = [
|
||||
'PASSWORD', 'PIN', 'CVV', 'CREDITCARD', 'IBAN', 'BIC', 'BANKACCOUNT', 'SSN',
|
||||
'BITCOINADDRESS', 'ETHEREUMADDRESS', 'LITECOINADDRESS',
|
||||
'EMAIL', 'PHONE', 'URL', 'IPADDRESS', 'MACADDRESS',
|
||||
'FIRSTNAME', 'LASTNAME', 'MIDDLENAME', 'USERNAME', 'DATEOFBIRTH',
|
||||
'STREET', 'CITY', 'STATE', 'ZIPCODE', 'GPSCOORDINATES',
|
||||
]
|
||||
|
||||
export default function EntityActionListEditor({ value, onChange }) {
|
||||
// value is an object map; preserve insertion order via Object.entries.
|
||||
const entries = useMemo(
|
||||
() => (value && typeof value === 'object' && !Array.isArray(value) ? Object.entries(value) : []),
|
||||
[value]
|
||||
)
|
||||
|
||||
const datalistId = 'pii-entity-groups'
|
||||
|
||||
const update = (index, key, action) => {
|
||||
const next = entries.map((e, i) => (i === index ? [key, action] : e))
|
||||
onChange(Object.fromEntries(next.filter(([k]) => k !== '')))
|
||||
}
|
||||
|
||||
const remove = (index) => {
|
||||
onChange(Object.fromEntries(entries.filter((_, i) => i !== index)))
|
||||
}
|
||||
|
||||
const add = () => {
|
||||
// New rows default to mask; an empty key is tolerated transiently and
|
||||
// filtered out on the next edit / when serialised.
|
||||
onChange(Object.fromEntries([...entries, ['', 'mask']]))
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 6, width: '100%' }}>
|
||||
<datalist id={datalistId}>
|
||||
{COMMON_GROUPS.map(g => <option key={g} value={g} />)}
|
||||
</datalist>
|
||||
|
||||
{entries.length === 0 && (
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>
|
||||
No per-entity actions — every detected group uses the default action. Add a row to
|
||||
block or allow-log a specific entity group (e.g. <code>PASSWORD</code> → block).
|
||||
</div>
|
||||
)}
|
||||
|
||||
{entries.map(([group, action], i) => (
|
||||
<div key={i} style={{ display: 'flex', gap: 6, alignItems: 'center', flexWrap: 'wrap' }}>
|
||||
<input
|
||||
className="input"
|
||||
list={datalistId}
|
||||
value={group}
|
||||
placeholder="Entity group (e.g. PASSWORD)"
|
||||
onChange={e => update(i, e.target.value, action)}
|
||||
style={{ flex: '1 1 220px', minWidth: 180, fontSize: '0.8125rem' }}
|
||||
aria-label="Entity group"
|
||||
/>
|
||||
<SearchableSelect
|
||||
value={action || 'mask'}
|
||||
onChange={v => update(i, group, v)}
|
||||
options={ACTION_OPTIONS}
|
||||
placeholder="Action..."
|
||||
style={{ flex: '1 1 240px', minWidth: 220 }}
|
||||
/>
|
||||
<button type="button" className="btn btn-secondary btn-sm"
|
||||
onClick={() => remove(i)}
|
||||
style={{ padding: '2px 8px', fontSize: '0.75rem' }}
|
||||
aria-label="Remove entity action">
|
||||
<i className="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<button type="button" className="btn btn-secondary btn-sm" onClick={add}
|
||||
style={{ alignSelf: 'flex-start', fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-plus" /> Add entity action
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
62
core/http/react-ui/src/components/ModelMultiSelect.jsx
Normal file
62
core/http/react-ui/src/components/ModelMultiSelect.jsx
Normal file
@@ -0,0 +1,62 @@
|
||||
import SearchableModelSelect from './SearchableModelSelect'
|
||||
|
||||
// Editor for a list of model names (value is []string). Selected models render
|
||||
// as compact removable chips; a single capability-filtered, commit-only picker
|
||||
// adds new ones. Used for pii.detectors / the instance-wide default detector,
|
||||
// where every entry must be a token_classify model. Already-selected models are
|
||||
// guarded against so each appears at most once.
|
||||
//
|
||||
// The picker is commit-only on purpose: typing a partial query must never be
|
||||
// treated as a chosen model (otherwise each keystroke would add a bogus entry),
|
||||
// and selecting one input box per detector wastes vertical space.
|
||||
export default function ModelMultiSelect({ value, onChange, capability, placeholder }) {
|
||||
const items = Array.isArray(value) ? value : []
|
||||
|
||||
const remove = (index) => onChange(items.filter((_, i) => i !== index))
|
||||
const add = (v) => {
|
||||
if (!v || items.includes(v)) return
|
||||
onChange([...items, v])
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 6, width: '100%' }}>
|
||||
{items.length === 0 ? (
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>
|
||||
No detectors — PII is enabled but nothing scans requests. Add a token-classification
|
||||
(NER) model below; its <code>pii_detection</code> block supplies the policy.
|
||||
</div>
|
||||
) : (
|
||||
<div style={{ display: 'flex', flexWrap: 'wrap', gap: 6 }}>
|
||||
{items.map((name, i) => (
|
||||
<span key={i} style={{
|
||||
display: 'inline-flex', alignItems: 'center', gap: 6,
|
||||
padding: '2px 4px 2px 10px', fontSize: '0.8125rem',
|
||||
fontFamily: 'var(--font-mono)', background: 'var(--color-bg-tertiary)',
|
||||
borderRadius: 'var(--radius-md)',
|
||||
}}>
|
||||
{name}
|
||||
<button type="button" className="btn btn-secondary btn-sm"
|
||||
onClick={() => remove(i)}
|
||||
style={{ padding: '0 6px', fontSize: '0.75rem', lineHeight: 1.6 }}
|
||||
aria-label={`Remove ${name}`}>
|
||||
<i className="fas fa-times" />
|
||||
</button>
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Size by width only. The container is a flex column, so a flex-basis
|
||||
here would set the wrapper's HEIGHT — which the dropdown anchors to
|
||||
(top: 100%), opening it far below the input. */}
|
||||
<SearchableModelSelect
|
||||
value=""
|
||||
onChange={add}
|
||||
commitOnly
|
||||
capability={capability}
|
||||
placeholder={placeholder || '+ Add detector model...'}
|
||||
style={{ width: '100%', maxWidth: 360 }}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -1,120 +0,0 @@
|
||||
import { useState, useEffect, useMemo } from 'react'
|
||||
import { apiUrl } from '../utils/basePath'
|
||||
import SearchableSelect from './SearchableSelect'
|
||||
|
||||
const ACTION_OPTIONS = [
|
||||
{ value: 'mask', label: 'Mask — replace with a [REDACTED:id] placeholder' },
|
||||
{ value: 'block', label: 'Block — reject the request (request side) / mask in stream' },
|
||||
{ value: 'allow', label: 'Allow — detect & log, leave text unchanged' },
|
||||
]
|
||||
|
||||
export default function PIIPatternListEditor({ value, onChange }) {
|
||||
const items = Array.isArray(value) ? value : []
|
||||
|
||||
const [catalog, setCatalog] = useState([])
|
||||
const [loadError, setLoadError] = useState(null)
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false
|
||||
fetch(apiUrl('/api/pii/patterns'))
|
||||
.then(r => r.ok ? r.json() : Promise.reject(new Error(`HTTP ${r.status}`)))
|
||||
.then(data => { if (!cancelled) setCatalog(data?.patterns || []) })
|
||||
.catch(err => { if (!cancelled) setLoadError(err.message) })
|
||||
return () => { cancelled = true }
|
||||
}, [])
|
||||
|
||||
const idOptions = useMemo(() =>
|
||||
catalog.map(p => ({
|
||||
value: p.id,
|
||||
label: p.description ? `${p.id} — ${p.description}` : p.id,
|
||||
})),
|
||||
[catalog]
|
||||
)
|
||||
|
||||
// Patterns already chosen — exclude from the "add row" select so each
|
||||
// pattern only appears once per model.
|
||||
const usedIDs = new Set(items.map(it => it?.id).filter(Boolean))
|
||||
const availableForAdd = idOptions.filter(o => !usedIDs.has(o.value))
|
||||
|
||||
const update = (index, key, val) => {
|
||||
const next = items.map((it, i) =>
|
||||
i === index ? { ...it, [key]: val } : it
|
||||
)
|
||||
onChange(next)
|
||||
}
|
||||
|
||||
const remove = (index) => {
|
||||
onChange(items.filter((_, i) => i !== index))
|
||||
}
|
||||
|
||||
const add = (id) => {
|
||||
const cat = catalog.find(c => c.id === id)
|
||||
onChange([...items, { id, action: cat?.action || 'mask' }])
|
||||
}
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 6, width: '100%' }}>
|
||||
{loadError && (
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-error)' }}>
|
||||
Could not load pattern catalog: {loadError}. You can still type IDs manually.
|
||||
</div>
|
||||
)}
|
||||
|
||||
{items.length === 0 && (
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>
|
||||
No overrides — every pattern uses its global default action. Add a row below to
|
||||
tighten or relax the action for a specific pattern on this model.
|
||||
</div>
|
||||
)}
|
||||
|
||||
{items.map((row, i) => {
|
||||
const cat = catalog.find(c => c.id === row?.id)
|
||||
const idLabel = cat?.description ? `${row.id} — ${cat.description}` : (row?.id || '')
|
||||
// Show the chosen id even if the catalog hasn't loaded yet (or
|
||||
// the YAML references an unknown pattern), so users can edit
|
||||
// without losing context.
|
||||
const idItems = [
|
||||
...(row?.id && !idOptions.some(o => o.value === row.id)
|
||||
? [{ value: row.id, label: idLabel }]
|
||||
: []),
|
||||
...idOptions.filter(o => o.value === row?.id || !usedIDs.has(o.value)),
|
||||
]
|
||||
return (
|
||||
<div key={i} style={{ display: 'flex', gap: 6, alignItems: 'center', flexWrap: 'wrap' }}>
|
||||
<SearchableSelect
|
||||
value={row?.id || ''}
|
||||
onChange={v => update(i, 'id', v)}
|
||||
options={idItems}
|
||||
placeholder="Pattern..."
|
||||
style={{ flex: '1 1 220px', minWidth: 200 }}
|
||||
/>
|
||||
<SearchableSelect
|
||||
value={row?.action || 'mask'}
|
||||
onChange={v => update(i, 'action', v)}
|
||||
options={ACTION_OPTIONS}
|
||||
placeholder="Action..."
|
||||
style={{ flex: '1 1 240px', minWidth: 220 }}
|
||||
/>
|
||||
<button type="button" className="btn btn-secondary btn-sm"
|
||||
onClick={() => remove(i)}
|
||||
style={{ padding: '2px 8px', fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
|
||||
{availableForAdd.length > 0 && (
|
||||
<div style={{ display: 'flex', gap: 6, alignItems: 'center' }}>
|
||||
<SearchableSelect
|
||||
value=""
|
||||
onChange={v => v && add(v)}
|
||||
options={availableForAdd}
|
||||
placeholder="+ Add pattern override..."
|
||||
style={{ flex: '1 1 220px', minWidth: 200 }}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
96
core/http/react-ui/src/components/PatternListEditor.jsx
Normal file
96
core/http/react-ui/src/components/PatternListEditor.jsx
Normal file
@@ -0,0 +1,96 @@
|
||||
import { useMemo } from 'react'
|
||||
import SearchableSelect from './SearchableSelect'
|
||||
|
||||
// Editor for a pattern detector's pii_detection.patterns: a list of
|
||||
// operator-defined secret patterns. Value is an array of
|
||||
// { name, match, action?, min_len? }; this renders one row per pattern and
|
||||
// emits a fresh array on every change. Patterns use a restricted regex subset
|
||||
// validated server-side at save (an invalid pattern surfaces as the save
|
||||
// error), so no regex engine is shipped to the client.
|
||||
|
||||
const ACTION_OPTIONS = [
|
||||
{ value: '', label: 'default (use Default Action)' },
|
||||
{ value: 'mask', label: 'mask — replace the span' },
|
||||
{ value: 'block', label: 'block — reject the request' },
|
||||
{ value: 'allow', label: 'allow — detect & log only' },
|
||||
]
|
||||
|
||||
function emptyPattern() {
|
||||
return { name: '', match: '', action: '', min_len: 0 }
|
||||
}
|
||||
|
||||
export default function PatternListEditor({ value, onChange }) {
|
||||
const rows = useMemo(() => (Array.isArray(value) ? value : []), [value])
|
||||
|
||||
const update = (index, patch) => {
|
||||
onChange(rows.map((r, i) => (i === index ? { ...r, ...patch } : r)))
|
||||
}
|
||||
const remove = (index) => onChange(rows.filter((_, i) => i !== index))
|
||||
const add = () => onChange([...rows, emptyPattern()])
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', flexDirection: 'column', gap: 8, width: '100%' }}>
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>
|
||||
Restricted regex: literals, <code>[…]</code> classes, <code>\w \d \s</code>, <code>?*+{'{m,n}'}</code>, anchors.
|
||||
Each pattern must contain a fixed literal run of ≥3 characters (e.g. <code>sk-prefix-</code>);
|
||||
<code>.</code> and capturing groups are not allowed. Matches report under the pattern name.
|
||||
</div>
|
||||
|
||||
{rows.length === 0 && (
|
||||
<div style={{ fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>
|
||||
No custom patterns. Enable built-ins above, or add a pattern for an internal credential
|
||||
format (e.g. <code>tok-[A-Za-z0-9]{'{32,64}'}</code>).
|
||||
</div>
|
||||
)}
|
||||
|
||||
{rows.map((r, i) => (
|
||||
<div key={i} style={{ display: 'flex', gap: 6, alignItems: 'center', flexWrap: 'wrap' }}>
|
||||
<input
|
||||
className="input"
|
||||
value={r.name || ''}
|
||||
placeholder="Name (group), e.g. INTERNAL_TOKEN"
|
||||
onChange={e => update(i, { name: e.target.value })}
|
||||
style={{ flex: '1 1 180px', minWidth: 150, fontSize: '0.8125rem' }}
|
||||
aria-label="Pattern name"
|
||||
/>
|
||||
<input
|
||||
className="input input-mono"
|
||||
value={r.match || ''}
|
||||
placeholder="match, e.g. tok-[A-Za-z0-9]{32,64}"
|
||||
onChange={e => update(i, { match: e.target.value })}
|
||||
style={{ flex: '2 1 240px', minWidth: 200, fontSize: '0.8125rem', fontFamily: 'var(--font-mono)' }}
|
||||
aria-label="Pattern match"
|
||||
/>
|
||||
<SearchableSelect
|
||||
value={r.action || ''}
|
||||
onChange={v => update(i, { action: v })}
|
||||
options={ACTION_OPTIONS}
|
||||
placeholder="Action..."
|
||||
style={{ flex: '1 1 200px', minWidth: 180 }}
|
||||
/>
|
||||
<input
|
||||
className="input"
|
||||
type="number"
|
||||
min={0}
|
||||
value={r.min_len || 0}
|
||||
title="Minimum match length (0 = no floor)"
|
||||
onChange={e => update(i, { min_len: parseInt(e.target.value, 10) || 0 })}
|
||||
style={{ width: 80, fontSize: '0.8125rem' }}
|
||||
aria-label="Minimum length"
|
||||
/>
|
||||
<button type="button" className="btn btn-secondary btn-sm"
|
||||
onClick={() => remove(i)}
|
||||
style={{ padding: '2px 8px', fontSize: '0.75rem' }}
|
||||
aria-label="Remove pattern">
|
||||
<i className="fas fa-times" />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
|
||||
<button type="button" className="btn btn-secondary btn-sm" onClick={add}
|
||||
style={{ alignSelf: 'flex-start', fontSize: '0.75rem' }}>
|
||||
<i className="fas fa-plus" /> Add pattern
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -1,7 +1,13 @@
|
||||
import { useState, useEffect, useRef, useCallback } from 'react'
|
||||
import { useModels } from '../hooks/useModels'
|
||||
|
||||
export default function SearchableModelSelect({ value, onChange, capability, placeholder = 'Type or select a model...', style }) {
|
||||
// commitOnly: when true, onChange fires only on an explicit commit (selecting an
|
||||
// item, or Enter) — never on each keystroke. Use it where each onChange is a
|
||||
// final selection (e.g. the ModelMultiSelect "add" picker), so a partial typed
|
||||
// query isn't treated as a chosen value. After a commit the field is cleared,
|
||||
// matching the add-and-clear flow. Default false keeps the as-you-type
|
||||
// behaviour single-value editors rely on.
|
||||
export default function SearchableModelSelect({ value, onChange, capability, placeholder = 'Type or select a model...', style, commitOnly = false }) {
|
||||
const { models, loading } = useModels(capability)
|
||||
const [query, setQuery] = useState('')
|
||||
const [open, setOpen] = useState(false)
|
||||
@@ -33,11 +39,13 @@ export default function SearchableModelSelect({ value, onChange, capability, pla
|
||||
: -1
|
||||
|
||||
const commit = useCallback((val) => {
|
||||
setQuery(val)
|
||||
// In commitOnly mode the field is an "add" box — clear it after a pick so
|
||||
// the next selection starts fresh; otherwise reflect the chosen value.
|
||||
setQuery(commitOnly ? '' : val)
|
||||
onChange(val)
|
||||
setOpen(false)
|
||||
setFocusIndex(-1)
|
||||
}, [onChange])
|
||||
}, [onChange, commitOnly])
|
||||
|
||||
const handleKeyDown = (e) => {
|
||||
if (!open && (e.key === 'ArrowDown' || e.key === 'ArrowUp')) {
|
||||
@@ -133,8 +141,10 @@ export default function SearchableModelSelect({ value, onChange, capability, pla
|
||||
setQuery(e.target.value)
|
||||
setOpen(true)
|
||||
setFocusIndex(-1)
|
||||
// Commit on every keystroke so the parent always has current value
|
||||
onChange(e.target.value)
|
||||
// Single-value editors want the parent updated as you type; an
|
||||
// "add" picker (commitOnly) must wait for an explicit commit so a
|
||||
// partial query is never mistaken for a chosen model.
|
||||
if (!commitOnly) onChange(e.target.value)
|
||||
}}
|
||||
onFocus={() => setOpen(true)}
|
||||
onKeyDown={handleKeyDown}
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import { useState, useEffect } from 'react'
|
||||
import { modelsApi } from '../utils/api'
|
||||
|
||||
// Stable empty references so consumers that memoize on `sections`/`fields`
|
||||
// (e.g. ModelEditor's leafPaths) don't see a new array every render while
|
||||
// the metadata request is still in flight — which would thrash their effects.
|
||||
const EMPTY = []
|
||||
|
||||
export function useConfigMetadata() {
|
||||
const [metadata, setMetadata] = useState(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
@@ -14,8 +19,8 @@ export function useConfigMetadata() {
|
||||
}, [])
|
||||
|
||||
return {
|
||||
sections: metadata?.sections || [],
|
||||
fields: metadata?.fields || [],
|
||||
sections: metadata?.sections || EMPTY,
|
||||
fields: metadata?.fields || EMPTY,
|
||||
loading,
|
||||
error,
|
||||
}
|
||||
|
||||
@@ -2,12 +2,13 @@ import { useState, useEffect, useCallback, useRef, useMemo, Fragment } from 'rea
|
||||
import { useOutletContext, Link, useNavigate, useLocation, useSearchParams } from 'react-router-dom'
|
||||
import { apiUrl } from '../utils/basePath'
|
||||
import { fromState } from '../utils/editorNav'
|
||||
import { settingsApi } from '../utils/api'
|
||||
import { settingsApi, modelsApi } from '../utils/api'
|
||||
import LoadingSpinner from '../components/LoadingSpinner'
|
||||
import Toggle from '../components/Toggle'
|
||||
|
||||
// Middleware admin page. Three tabs:
|
||||
// - Filtering: PII pattern catalogue + per-model resolved state +
|
||||
// pattern-action editor (PUT /api/pii/patterns/:id, transient).
|
||||
// - Filtering: per-model resolved PII state + per-model detector list
|
||||
// (detection policy lives on each detector model's pii_detection block).
|
||||
// - Routing: placeholder until subsystem 2 lands. Renders the note
|
||||
// from /api/router/status so admins see "not yet implemented" rather
|
||||
// than an empty page.
|
||||
@@ -27,8 +28,6 @@ const TABS = [
|
||||
{ id: 'events', label: 'Events', icon: 'fa-list-ul' },
|
||||
]
|
||||
|
||||
const ACTIONS = ['mask', 'block', 'allow']
|
||||
|
||||
function actionBadge(action) {
|
||||
const colors = {
|
||||
mask: 'var(--color-primary)',
|
||||
@@ -82,8 +81,6 @@ export default function Middleware() {
|
||||
const [searchParams, setSearchParams] = useSearchParams()
|
||||
const initialTab = searchParams.get('tab') || localStorage.getItem('middleware-tab') || 'filtering'
|
||||
const [activeTab, setActiveTab] = useState(TABS.some(t => t.id === initialTab) ? initialTab : 'filtering')
|
||||
const [pendingPattern, setPendingPattern] = useState(null) // id while a PUT is in flight
|
||||
|
||||
const selectTab = (id) => {
|
||||
setActiveTab(id)
|
||||
localStorage.setItem('middleware-tab', id)
|
||||
@@ -130,51 +127,6 @@ export default function Middleware() {
|
||||
return () => clearInterval(refreshRef.current)
|
||||
}, [fetchAll])
|
||||
|
||||
const mutatePattern = async (patternID, body, successMsg) => {
|
||||
setPendingPattern(patternID)
|
||||
try {
|
||||
const res = await fetch(apiUrl(`/api/pii/patterns/${encodeURIComponent(patternID)}`), {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(body),
|
||||
})
|
||||
if (!res.ok) {
|
||||
const data = await res.json().catch(() => ({}))
|
||||
throw new Error(data.error || `HTTP ${res.status}`)
|
||||
}
|
||||
addToast(successMsg, 'success')
|
||||
await fetchAll()
|
||||
} catch (err) {
|
||||
addToast(`Failed to update pattern: ${err.message}`, 'error')
|
||||
} finally {
|
||||
setPendingPattern(null)
|
||||
}
|
||||
}
|
||||
|
||||
const setPatternAction = (patternID, action) =>
|
||||
mutatePattern(patternID, { action }, `Pattern ${patternID}: action ${action} (transient — click "Save to disk" to persist)`)
|
||||
|
||||
const setPatternDisabled = (patternID, disabled) =>
|
||||
mutatePattern(patternID, { disabled }, `Pattern ${patternID}: ${disabled ? 'disabled' : 'enabled'} (transient — click "Save to disk" to persist)`)
|
||||
|
||||
const [persisting, setPersisting] = useState(false)
|
||||
const persistPatterns = async () => {
|
||||
setPersisting(true)
|
||||
try {
|
||||
const res = await fetch(apiUrl('/api/pii/patterns/persist'), { method: 'POST' })
|
||||
if (!res.ok) {
|
||||
const data = await res.json().catch(() => ({}))
|
||||
throw new Error(data.error || `HTTP ${res.status}`)
|
||||
}
|
||||
const data = await res.json().catch(() => ({}))
|
||||
addToast(`Saved ${data.override_count ?? 0} pattern override(s) to runtime_settings.json`, 'success')
|
||||
} catch (err) {
|
||||
addToast(`Failed to persist: ${err.message}`, 'error')
|
||||
} finally {
|
||||
setPersisting(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="page page--wide">
|
||||
<div className="page-header" style={{ marginBottom: 'var(--spacing-sm)' }}>
|
||||
@@ -207,14 +159,7 @@ export default function Middleware() {
|
||||
<LoadingSpinner size="lg" />
|
||||
</div>
|
||||
) : activeTab === 'filtering' ? (
|
||||
<FilteringTab
|
||||
status={status}
|
||||
pendingPattern={pendingPattern}
|
||||
onSetAction={setPatternAction}
|
||||
onSetDisabled={setPatternDisabled}
|
||||
onPersist={persistPatterns}
|
||||
persisting={persisting}
|
||||
/>
|
||||
<FilteringTab status={status} addToast={addToast} onChanged={fetchAll} />
|
||||
) : activeTab === 'routing' ? (
|
||||
<RoutingTab status={status} decisions={decisions} />
|
||||
) : activeTab === 'proxy' ? (
|
||||
@@ -226,24 +171,33 @@ export default function Middleware() {
|
||||
)
|
||||
}
|
||||
|
||||
function FilteringTab({ status, pendingPattern, onSetAction, onSetDisabled, onPersist, persisting }) {
|
||||
function FilteringTab({ status, addToast, onChanged }) {
|
||||
const location = useLocation()
|
||||
// Rows mid-save, so just that model's toggle disables while the PATCH
|
||||
// round-trips (and the 5s background poll re-syncs the resolved state).
|
||||
const [piiBusy, setPiiBusy] = useState(() => new Set())
|
||||
|
||||
// Toggling the PII column writes an explicit pii.enabled to the model YAML
|
||||
// via PATCH /api/models/config-json/:name (a deep-merge that preserves
|
||||
// pii.detectors and every other field). This makes the resolved state
|
||||
// explicit: a cloud-proxy model shown ON by backend default becomes
|
||||
// pii.enabled:true; toggling it OFF writes pii.enabled:false.
|
||||
const togglePII = async (name, on) => {
|
||||
setPiiBusy(prev => new Set(prev).add(name))
|
||||
try {
|
||||
await modelsApi.patchConfig(name, { pii: { enabled: on } })
|
||||
addToast?.(on ? `PII filtering enabled for ${name}` : `PII filtering disabled for ${name}`, 'success')
|
||||
onChanged?.()
|
||||
} catch (err) {
|
||||
addToast?.(`Failed to update ${name}: ${err.message}`, 'error')
|
||||
} finally {
|
||||
setPiiBusy(prev => { const n = new Set(prev); n.delete(name); return n })
|
||||
}
|
||||
}
|
||||
|
||||
if (!status?.pii) return null
|
||||
const pii = status.pii
|
||||
|
||||
if (!pii.enabled_globally) {
|
||||
return (
|
||||
<div className="empty-state">
|
||||
<div className="empty-state-icon"><i className="fas fa-shield-slash" /></div>
|
||||
<h2 className="empty-state-title">PII filtering disabled</h2>
|
||||
<p className="empty-state-text">
|
||||
The PII filter is disabled by <code>{pii.reason || '--disable-pii'}</code>.
|
||||
Restart without that flag to enable it.
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Default rule banner */}
|
||||
@@ -251,90 +205,23 @@ function FilteringTab({ status, pendingPattern, onSetAction, onSetDisabled, onPe
|
||||
<div style={{ display: 'flex', alignItems: 'flex-start', gap: 'var(--spacing-sm)' }}>
|
||||
<i className="fas fa-info-circle" style={{ color: 'var(--color-text-muted)', marginTop: 2 }} />
|
||||
<div>
|
||||
<div style={{ fontWeight: 600, marginBottom: 4 }}>Default policy</div>
|
||||
<div style={{ fontWeight: 600, marginBottom: 4 }}>NER-based PII redaction</div>
|
||||
<div style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)' }}>
|
||||
PII redaction is per-model and OFF by default. Backends matching <code>{(pii.default_enabled_for_backends || []).join(', ')}</code> default to ON (cloud passthroughs). Override per model with <code>pii: {'{'} enabled: true {'}'}</code> in the model YAML.
|
||||
Redaction is per-model and runs request-side. It is OFF by default; backends matching <code>{(pii.default_enabled_for_backends || []).join(', ')}</code> default to ON (cloud passthroughs). A model opts in with <code>pii: {'{'} enabled: true, detectors: […] {'}'}</code>; each detector is a <code>token_classify</code> model whose <code>pii_detection</code> block defines the policy (which entities, what action, min score). Edit a detector model to change its policy.
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Patterns table */}
|
||||
<div className="card" style={{ padding: 'var(--spacing-md)', marginBottom: 'var(--spacing-md)' }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', marginBottom: 'var(--spacing-sm)' }}>
|
||||
<span style={{ fontSize: '0.875rem', fontWeight: 600 }}>Active patterns</span>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 'var(--spacing-sm)' }}>
|
||||
<span style={{ fontSize: '0.6875rem', color: 'var(--color-text-muted)' }}>
|
||||
Toggle / action edits are transient — click Save to disk to persist.
|
||||
</span>
|
||||
<button
|
||||
className="btn btn-secondary btn-sm"
|
||||
onClick={onPersist}
|
||||
disabled={persisting}
|
||||
style={{ fontSize: '0.75rem' }}
|
||||
>
|
||||
<i className={`fas ${persisting ? 'fa-spinner fa-spin' : 'fa-save'}`} /> Save to disk
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th style={{ width: 80 }}>Enabled</th>
|
||||
<th style={{ width: 140 }}>Pattern</th>
|
||||
<th>Description</th>
|
||||
<th style={{ width: 110 }}>Action</th>
|
||||
<th style={{ width: 250 }}>Change</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{pii.patterns.map(p => {
|
||||
const enabled = !p.disabled
|
||||
const muted = p.disabled
|
||||
return (
|
||||
<tr key={p.id} style={muted ? { opacity: 0.55 } : undefined}>
|
||||
<td>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={enabled}
|
||||
disabled={pendingPattern === p.id}
|
||||
onChange={e => onSetDisabled(p.id, !e.target.checked)}
|
||||
style={{ cursor: 'pointer' }}
|
||||
aria-label={`Enable ${p.id} pattern`}
|
||||
/>
|
||||
</td>
|
||||
<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.8125rem', fontWeight: 600 }}>{p.id}</td>
|
||||
<td style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)' }}>{p.description}</td>
|
||||
<td>{actionBadge(p.action)}</td>
|
||||
<td>
|
||||
<div style={{ display: 'flex', gap: 4 }}>
|
||||
{ACTIONS.map(a => (
|
||||
<button
|
||||
key={a}
|
||||
className={`btn btn-sm ${p.action === a ? 'btn-primary' : 'btn-secondary'}`}
|
||||
onClick={() => onSetAction(p.id, a)}
|
||||
disabled={pendingPattern === p.id || p.action === a || p.disabled}
|
||||
style={{ fontSize: '0.6875rem', padding: '2px 8px' }}
|
||||
>
|
||||
{a}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
</td>
|
||||
</tr>
|
||||
)})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
{/* Detector models + instance-wide default policy (per-row toggle) */}
|
||||
<DetectorModels pii={pii} addToast={addToast} onChanged={onChanged} />
|
||||
|
||||
{/* Per-model resolved state */}
|
||||
<div className="card" style={{ padding: 'var(--spacing-md)' }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', marginBottom: 'var(--spacing-sm)' }}>
|
||||
<span style={{ fontSize: '0.875rem', fontWeight: 600 }}>Per-model state</span>
|
||||
<span style={{ fontSize: '0.6875rem', color: 'var(--color-text-muted)' }}>
|
||||
Edit the model YAML to change these.
|
||||
Toggle PII inline; edit a row for detectors and policy.
|
||||
</span>
|
||||
</div>
|
||||
<div className="table-container">
|
||||
@@ -343,9 +230,9 @@ function FilteringTab({ status, pendingPattern, onSetAction, onSetDisabled, onPe
|
||||
<tr>
|
||||
<th>Model</th>
|
||||
<th style={{ width: 120 }}>Backend</th>
|
||||
<th style={{ width: 80 }}>PII</th>
|
||||
<th style={{ width: 120 }}>PII</th>
|
||||
<th style={{ width: 110 }}>Source</th>
|
||||
<th>Pattern overrides</th>
|
||||
<th>Detectors</th>
|
||||
<th style={{ width: 80 }}>Edit</th>
|
||||
</tr>
|
||||
</thead>
|
||||
@@ -354,13 +241,29 @@ function FilteringTab({ status, pendingPattern, onSetAction, onSetDisabled, onPe
|
||||
<tr key={m.name}>
|
||||
<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.8125rem' }}>{m.name}</td>
|
||||
<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>{m.backend || '—'}</td>
|
||||
<td>{enabledBadge(m.enabled)}</td>
|
||||
<td>
|
||||
<span style={{ display: 'inline-flex', alignItems: 'center', gap: 6 }}>
|
||||
<Toggle
|
||||
checked={!!m.enabled}
|
||||
disabled={piiBusy.has(m.name)}
|
||||
onChange={(v) => togglePII(m.name, v)}
|
||||
/>
|
||||
{m.enabled && (!m.detectors || m.detectors.length === 0) && (
|
||||
<span
|
||||
title="Enabled but no detector resolved — nothing is scanned. Toggle a detector's Default on above, or add pii.detectors to the model."
|
||||
style={{ fontSize: '0.6875rem', fontWeight: 600, color: 'var(--color-warning)', whiteSpace: 'nowrap', cursor: 'help' }}
|
||||
>
|
||||
<i className="fas fa-triangle-exclamation" style={{ marginRight: 3 }} />no-op
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
</td>
|
||||
<td style={{ fontSize: '0.6875rem', color: 'var(--color-text-muted)' }}>
|
||||
{m.explicit ? 'YAML' : (m.default_for_backend ? 'backend default' : 'default off')}
|
||||
</td>
|
||||
<td style={{ fontSize: '0.75rem', fontFamily: 'var(--font-mono)' }}>
|
||||
{m.overrides && Object.keys(m.overrides).length > 0
|
||||
? Object.entries(m.overrides).map(([k, v]) => `${k}=${v}`).join(', ')
|
||||
{m.detectors && m.detectors.length > 0
|
||||
? <>{m.detectors.join(', ')}{m.detectors_from_default && <span style={{ color: 'var(--color-text-muted)', fontFamily: 'var(--font-sans)' }}> (default)</span>}</>
|
||||
: <span style={{ color: 'var(--color-text-muted)' }}>—</span>}
|
||||
</td>
|
||||
<td>
|
||||
@@ -391,6 +294,147 @@ function FilteringTab({ status, pendingPattern, onSetAction, onSetDisabled, onPe
|
||||
)
|
||||
}
|
||||
|
||||
// detectorTypeBadge labels a detector model by how it matches: a neural NER
|
||||
// token-classifier vs an in-process restricted-regex pattern matcher. `unknown`
|
||||
// is a default that names a model no longer loaded.
|
||||
function detectorTypeBadge(type) {
|
||||
const map = {
|
||||
ner: { label: 'NER', color: 'var(--color-primary)' },
|
||||
pattern: { label: 'pattern', color: 'var(--color-data-2, var(--color-warning))' },
|
||||
unknown: { label: 'not loaded', color: 'var(--color-text-muted)' },
|
||||
}
|
||||
const t = map[type] || map.unknown
|
||||
return (
|
||||
<span style={{
|
||||
display: 'inline-block',
|
||||
padding: '2px 8px',
|
||||
fontSize: '0.6875rem',
|
||||
fontWeight: 600,
|
||||
borderRadius: 'var(--radius-sm)',
|
||||
background: t.color,
|
||||
color: 'white',
|
||||
fontFamily: 'var(--font-mono)',
|
||||
textTransform: 'uppercase',
|
||||
}}>
|
||||
{t.label}
|
||||
</span>
|
||||
)
|
||||
}
|
||||
|
||||
// DetectorModels lists the token_classify "filter" models (NER + in-process
|
||||
// pattern matchers) and, via a per-row toggle, manages the instance-wide
|
||||
// default detector set (RuntimeSettings.pii_default_detectors, saved via POST
|
||||
// /api/settings). A detector toggled on is applied to any PII-enabled model
|
||||
// that names none of its own — chiefly cloud-proxy / MITM models, which are
|
||||
// PII-enabled by default but carry no detector list. Per-model `pii.detectors`
|
||||
// always overrides. This replaces the old model-multiselect chooser: the table
|
||||
// shows every available detector, so admins toggle defaults instead of retyping
|
||||
// names, and link straight to each detector's config to edit its policy.
|
||||
function DetectorModels({ pii, addToast, onChanged }) {
|
||||
const navigate = useNavigate()
|
||||
const location = useLocation()
|
||||
const rows = useMemo(() => pii.detector_models || [], [pii.detector_models])
|
||||
// Names currently in the default set; the toggle adds/removes against this.
|
||||
const defaults = useMemo(() => pii.default_detectors || [], [pii.default_detectors])
|
||||
// Track which rows are mid-save to disable just that toggle (optimistic).
|
||||
const [busy, setBusy] = useState(() => new Set())
|
||||
|
||||
const toggleDefault = async (name, on) => {
|
||||
const next = on
|
||||
? [...new Set([...defaults, name])]
|
||||
: defaults.filter(d => d !== name)
|
||||
setBusy(prev => new Set(prev).add(name))
|
||||
try {
|
||||
const body = await settingsApi.save({ pii_default_detectors: next })
|
||||
if (body && body.success === false) throw new Error(body.error || 'unknown error')
|
||||
addToast?.(on ? `${name} added to default detectors` : `${name} removed from default detectors`, 'success')
|
||||
onChanged?.()
|
||||
} catch (err) {
|
||||
addToast?.(`Failed to save: ${err.message}`, 'error')
|
||||
} finally {
|
||||
setBusy(prev => { const n = new Set(prev); n.delete(name); return n })
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="card" style={{ padding: 'var(--spacing-md)', marginBottom: 'var(--spacing-md)' }}>
|
||||
<div style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between', marginBottom: 'var(--spacing-sm)', gap: 'var(--spacing-sm)', flexWrap: 'wrap' }}>
|
||||
<span style={{ fontSize: '0.875rem', fontWeight: 600 }}>Detector models</span>
|
||||
<button
|
||||
className="btn btn-secondary btn-sm"
|
||||
onClick={() => navigate('/app/model-editor?template=secret-filter', { state: fromState(location, 'Middleware') })}
|
||||
title="Add a NER or pattern detector model"
|
||||
>
|
||||
<i className="fas fa-plus" /> Add detector model
|
||||
</button>
|
||||
</div>
|
||||
<div style={{ fontSize: '0.8125rem', color: 'var(--color-text-secondary)', marginBottom: 'var(--spacing-sm)' }}>
|
||||
These token_classify models do the scanning. Toggle <strong>Default</strong> on to apply a
|
||||
detector to any PII-enabled model that names none of its own (chiefly cloud-proxy / MITM models).
|
||||
Per-model <code>pii.detectors</code> always overrides. Edit a detector to change which entities it
|
||||
flags and what action it takes.
|
||||
</div>
|
||||
|
||||
<div className="table-container">
|
||||
<table className="table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Detector model</th>
|
||||
<th style={{ width: 110 }}>Type</th>
|
||||
<th style={{ width: 120 }}>Backend</th>
|
||||
<th style={{ width: 110 }}>Default</th>
|
||||
<th style={{ width: 80 }}>Edit</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{rows.map(d => (
|
||||
<tr key={d.name}>
|
||||
<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.8125rem', fontWeight: 600 }}>
|
||||
{d.missing
|
||||
? <span title="This default detector names a model that is not loaded.">{d.name}</span>
|
||||
: <Link to={`/app/model-editor/${encodeURIComponent(d.name)}`} state={fromState(location, 'Middleware')} title={`Edit ${d.name}.yaml`}>{d.name}</Link>}
|
||||
</td>
|
||||
<td>{detectorTypeBadge(d.type)}</td>
|
||||
<td style={{ fontFamily: 'var(--font-mono)', fontSize: '0.75rem', color: 'var(--color-text-muted)' }}>{d.backend || '—'}</td>
|
||||
<td>
|
||||
<Toggle
|
||||
checked={!!d.default}
|
||||
disabled={busy.has(d.name)}
|
||||
onChange={(v) => toggleDefault(d.name, v)}
|
||||
/>
|
||||
</td>
|
||||
<td>
|
||||
{d.missing ? (
|
||||
<span style={{ fontSize: '0.6875rem', color: 'var(--color-text-muted)' }}>—</span>
|
||||
) : (
|
||||
<Link
|
||||
to={`/app/model-editor/${encodeURIComponent(d.name)}`}
|
||||
state={fromState(location, 'Middleware')}
|
||||
className="btn btn-secondary btn-sm"
|
||||
style={{ fontSize: '0.6875rem', padding: '2px 8px' }}
|
||||
title={`Edit ${d.name}.yaml`}
|
||||
>
|
||||
<i className="fas fa-pen-to-square" /> Edit
|
||||
</Link>
|
||||
)}
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
{rows.length === 0 && (
|
||||
<tr>
|
||||
<td colSpan={5} style={{ textAlign: 'center', color: 'var(--color-text-muted)', padding: 'var(--spacing-md)' }}>
|
||||
No detector models loaded. Add one with the button above (a token_classify NER model
|
||||
or a built-in secret pattern model).
|
||||
</td>
|
||||
</tr>
|
||||
)}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// decisionActiveSet rebuilds the Set of active labels from a
|
||||
// DecisionRecord's comma-joined `label` column. Used by both the
|
||||
// collapsed-row score suffix and the expanded-row bar rendering.
|
||||
@@ -1011,7 +1055,7 @@ function EventsTab({ events }) {
|
||||
<div className="empty-state-icon"><i className="fas fa-list-ul" /></div>
|
||||
<h2 className="empty-state-title">No events</h2>
|
||||
<p className="empty-state-text">
|
||||
Events appear here when the PII filter matches a pattern, when the MITM proxy decides whether
|
||||
Events appear here when a PII detector flags an entity, when the MITM proxy decides whether
|
||||
to intercept a hostname, or when an intercepted request finishes. Request bodies are never
|
||||
stored — use the API and backend traces for that.
|
||||
</p>
|
||||
|
||||
@@ -31,13 +31,23 @@ const SECTION_COLORS = {
|
||||
mitm: 'var(--color-warning)', pii: 'var(--color-error)', other: 'var(--color-text-muted)',
|
||||
}
|
||||
|
||||
function flattenConfig(obj, prefix = '') {
|
||||
// flattenConfig turns a parsed YAML config into a flat { 'a.b.c': value }
|
||||
// map keyed by the same dotted paths the field registry uses. leafPaths is
|
||||
// the set of registered schema leaf paths: recursion STOPS at any of them so
|
||||
// a map-typed field (e.g. pii_detection.entity_actions, a {GROUP: action}
|
||||
// object) is stored whole at its own path. Without this guard a map's value
|
||||
// was scattered into `pii_detection.entity_actions.SSN` etc. — paths that
|
||||
// match no registered field — so the editor rendered neither the field nor
|
||||
// its values, hiding per-entity policy like SSN→block from the operator.
|
||||
function flattenConfig(obj, leafPaths, prefix = '') {
|
||||
const result = {}
|
||||
if (!obj || typeof obj !== 'object') return result
|
||||
for (const [key, val] of Object.entries(obj)) {
|
||||
const path = prefix ? `${prefix}.${key}` : key
|
||||
if (val !== null && typeof val === 'object' && !Array.isArray(val)) {
|
||||
Object.assign(result, flattenConfig(val, path))
|
||||
if (leafPaths && leafPaths.has(path)) {
|
||||
result[path] = val
|
||||
} else if (val !== null && typeof val === 'object' && !Array.isArray(val)) {
|
||||
Object.assign(result, flattenConfig(val, leafPaths, path))
|
||||
} else {
|
||||
result[path] = val
|
||||
}
|
||||
@@ -82,6 +92,16 @@ export default function ModelEditor() {
|
||||
const { addToast } = useOutletContext()
|
||||
const { sections, fields, loading: metaLoading, error: metaError } = useConfigMetadata()
|
||||
|
||||
// Registered schema leaf paths. flattenConfig stops recursing at these so
|
||||
// map-typed fields (e.g. pii_detection.entity_actions) bind as a whole
|
||||
// object to their registered editor instead of vanishing into sub-paths.
|
||||
const leafPaths = useMemo(() => new Set(fields.map(f => f.path)), [fields])
|
||||
|
||||
// The parsed (not-yet-flattened) config loaded from the server. Flattening
|
||||
// is deferred to a separate effect keyed on leafPaths so the schema metadata
|
||||
// can arrive after the config without a fetch race re-clobbering values.
|
||||
const [loadedConfig, setLoadedConfig] = useState(null)
|
||||
|
||||
const isCreateMode = !name
|
||||
const [selectedTemplate, setSelectedTemplate] = useState(null)
|
||||
|
||||
@@ -123,7 +143,9 @@ export default function ModelEditor() {
|
||||
}
|
||||
}, [isCreateMode, searchParams, handleSelectTemplate])
|
||||
|
||||
// Load raw YAML config (edit mode only)
|
||||
// Load raw YAML config (edit mode only). This only fetches + parses; the
|
||||
// flatten-into-form-values step is the separate effect below so it can
|
||||
// re-run when the schema metadata (leafPaths) resolves without re-fetching.
|
||||
useEffect(() => {
|
||||
if (!name) return
|
||||
modelsApi.getEditConfig(name)
|
||||
@@ -131,26 +153,29 @@ export default function ModelEditor() {
|
||||
const raw = data?.config || ''
|
||||
setYamlText(raw)
|
||||
setSavedYamlText(raw)
|
||||
|
||||
// Parse YAML to get only the fields actually present in the file
|
||||
try {
|
||||
const parsed = YAML.parse(raw)
|
||||
const flat = flattenConfig(parsed || {})
|
||||
const active = new Set(Object.keys(flat))
|
||||
setValues(flat)
|
||||
setInitialValues(structuredClone(flat))
|
||||
setActiveFieldPaths(active)
|
||||
setLoadedConfig(YAML.parse(raw) || {})
|
||||
} catch {
|
||||
// If YAML parsing fails, start with empty state
|
||||
setValues({})
|
||||
setInitialValues({})
|
||||
setActiveFieldPaths(new Set())
|
||||
setLoadedConfig({})
|
||||
}
|
||||
})
|
||||
.catch(err => addToast(`Failed to load config: ${err.message}`, 'error'))
|
||||
.finally(() => setConfigLoading(false))
|
||||
}, [name, addToast])
|
||||
|
||||
// Flatten the loaded config into form values. Keyed on leafPaths so a late
|
||||
// schema-metadata resolution re-flattens (keeping map fields whole) WITHOUT
|
||||
// re-fetching — avoiding a two-fetch race that could clobber values. Only
|
||||
// fires on (re)load: loadedConfig changes per model, leafPaths is stable
|
||||
// once metadata is in, so this never stomps in-progress edits.
|
||||
useEffect(() => {
|
||||
if (loadedConfig === null) return
|
||||
const flat = flattenConfig(loadedConfig, leafPaths)
|
||||
setValues(flat)
|
||||
setInitialValues(structuredClone(flat))
|
||||
setActiveFieldPaths(new Set(Object.keys(flat)))
|
||||
}, [loadedConfig, leafPaths])
|
||||
|
||||
// Build field lookup
|
||||
const fieldsByPath = useMemo(() => {
|
||||
const map = {}
|
||||
@@ -325,7 +350,7 @@ export default function ModelEditor() {
|
||||
try {
|
||||
const parsed = YAML.parse(yamlText)
|
||||
parsedName = parsed?.name ?? null
|
||||
const flat = flattenConfig(parsed || {})
|
||||
const flat = flattenConfig(parsed || {}, leafPaths)
|
||||
setValues(flat)
|
||||
setInitialValues(structuredClone(flat))
|
||||
setActiveFieldPaths(new Set(Object.keys(flat)))
|
||||
|
||||
@@ -36,6 +36,7 @@ const FILTERS = [
|
||||
{ key: 'rerank', labelKey: 'filters.rerank', icon: 'fa-sort' },
|
||||
{ key: 'detection', labelKey: 'filters.detection', icon: 'fa-bullseye' },
|
||||
{ key: 'vad', labelKey: 'filters.vad', icon: 'fa-wave-square' },
|
||||
{ key: 'token_classify', labelKey: 'filters.ner', icon: 'fa-tags' },
|
||||
]
|
||||
|
||||
export default function Models() {
|
||||
|
||||
@@ -75,6 +75,8 @@ const TYPE_COLORS = {
|
||||
detection: { bg: 'var(--color-info-light)', color: 'var(--color-data-8)' },
|
||||
model_load: { bg: 'var(--color-error-light)', color: 'var(--color-data-2)' },
|
||||
vector_store: { bg: 'var(--color-accent-light)', color: 'var(--color-data-7)' },
|
||||
token_classify: { bg: 'var(--color-info-light)', color: 'var(--color-data-3)' },
|
||||
pattern_pii: { bg: 'var(--color-error-light)', color: 'var(--color-data-2)' },
|
||||
}
|
||||
|
||||
function typeBadgeStyle(type) {
|
||||
|
||||
1
core/http/react-ui/src/utils/capabilities.js
vendored
1
core/http/react-ui/src/utils/capabilities.js
vendored
@@ -22,3 +22,4 @@ export const CAP_SPEAKER_RECOGNITION = 'FLAG_SPEAKER_RECOGNITION'
|
||||
export const CAP_AUDIO_TRANSFORM = 'FLAG_AUDIO_TRANSFORM'
|
||||
export const CAP_REALTIME_AUDIO = 'FLAG_REALTIME_AUDIO'
|
||||
export const CAP_SCORE = 'FLAG_SCORE'
|
||||
export const CAP_TOKEN_CLASSIFY = 'FLAG_TOKEN_CLASSIFY'
|
||||
|
||||
28
core/http/react-ui/src/utils/modelTemplates.js
vendored
28
core/http/react-ui/src/utils/modelTemplates.js
vendored
@@ -146,22 +146,38 @@ const MODEL_TEMPLATES = [
|
||||
id: 'mitm',
|
||||
label: 'MITM Intercept',
|
||||
icon: 'fa-shield-halved',
|
||||
description: 'Bind a hostname to this config for the cloudproxy MITM listener. PII filtering and pattern overrides flow from this config when the host is intercepted.',
|
||||
description: 'Bind a hostname to this config for the cloudproxy MITM listener. PII filtering (the NER detectors listed here) is applied to intercepted request bodies for the host.',
|
||||
// The mitm- name prefix is a convention, not a contract — the
|
||||
// dispatcher looks up by host, not name. Prefixing keeps the
|
||||
// config out of the way of callable model names so a chat client
|
||||
// accidentally requesting "anthropic" doesn't hit a backendless
|
||||
// intercept config.
|
||||
//
|
||||
// pii.patterns is pre-seeded with an empty list so the override
|
||||
// editor is visible by default — admins typically want to tighten
|
||||
// a couple of pattern actions when intercepting a cloud provider.
|
||||
// An empty list serializes out and the redactor ignores it.
|
||||
// pii.detectors is pre-seeded empty so the detector picker is visible
|
||||
// by default — admins point it at a token_classify model whose
|
||||
// pii_detection block defines the policy.
|
||||
fields: {
|
||||
'name': 'mitm-anthropic',
|
||||
'mitm.hosts': ['api.anthropic.com'],
|
||||
'pii.enabled': true,
|
||||
'pii.patterns': [],
|
||||
'pii.detectors': [],
|
||||
},
|
||||
},
|
||||
{
|
||||
id: 'secret-filter',
|
||||
label: 'Secret Pattern Detector',
|
||||
icon: 'fa-key',
|
||||
description: 'An in-process token_classify detector that flags high-entropy secrets (API keys, tokens) with bounded restricted-regex patterns — no backend, no GGUF, zero VRAM. Enable the built-in provider patterns below and/or add your own under PII Detection. Reference it from a model\'s pii.detectors, or toggle it on as a default detector on the Middleware page.',
|
||||
fields: {
|
||||
'name': 'secret-filter',
|
||||
'backend': 'pattern',
|
||||
'known_usecases': ['token_classify'],
|
||||
'pii_detection.default_action': 'block',
|
||||
'pii_detection.builtins': [
|
||||
'anthropic_api_key', 'openai_api_key', 'github_token', 'github_pat',
|
||||
'aws_access_key', 'google_api_key', 'slack_token', 'stripe_key',
|
||||
'jwt', 'private_key_block',
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@@ -22,8 +22,8 @@ import (
|
||||
|
||||
func RegisterAnthropicRoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
application *application.Application) {
|
||||
|
||||
application *application.Application,
|
||||
) {
|
||||
// Anthropic Messages API endpoint
|
||||
var natsClient mcpTools.MCPNATSClient
|
||||
if d := application.Distributed(); d != nil {
|
||||
@@ -36,8 +36,6 @@ func RegisterAnthropicRoutes(app *echo.Echo,
|
||||
application.TemplatesEvaluator(),
|
||||
application.ApplicationConfig(),
|
||||
natsClient,
|
||||
application.PIIRedactor(),
|
||||
application.PIIEvents(),
|
||||
)
|
||||
|
||||
messagesMiddleware := []echo.MiddlewareFunc{
|
||||
@@ -69,7 +67,7 @@ func RegisterAnthropicRoutes(app *echo.Echo,
|
||||
},
|
||||
),
|
||||
middleware.AdmissionControl(application.AdmissionLimiter(), application.PIIEvents()),
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.Anthropic(), application.FallbackUser()),
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.Anthropic(), application.FallbackUser(), pii.WithNERResolver(application.PIINERResolver()), pii.WithPolicyResolver(application.PIIPolicyResolver())),
|
||||
}
|
||||
|
||||
// Main Anthropic endpoint
|
||||
|
||||
@@ -299,53 +299,85 @@ func buildAdmissionStatus(app *application.Application) map[string]any {
|
||||
}
|
||||
|
||||
// buildPIIStatus builds the pii section of /api/middleware/status. It
|
||||
// reads the live redactor, walks every model config, and reports the
|
||||
// resolved enabled state plus any per-pattern overrides — that's what
|
||||
// the admin page renders side-by-side so the operator can see at a
|
||||
// glance which models are protected.
|
||||
//
|
||||
// Returns a sentinel "disabled" payload when the redactor is nil
|
||||
// (--disable-pii), letting the page show "filter switched off" rather
|
||||
// than a confusing empty state.
|
||||
// walks every model config and reports the resolved enabled state plus
|
||||
// the NER detector models each one references — that's what the admin
|
||||
// page renders so the operator can see at a glance which models are
|
||||
// protected and by which detectors. The detection policy itself
|
||||
// (entity→action, min score) lives on each detector model's
|
||||
// pii_detection block.
|
||||
func buildPIIStatus(app *application.Application) map[string]any {
|
||||
redactor := app.PIIRedactor()
|
||||
if redactor == nil {
|
||||
return map[string]any{
|
||||
"enabled_globally": false,
|
||||
"reason": "--disable-pii",
|
||||
"patterns": []any{},
|
||||
"models": []any{},
|
||||
}
|
||||
}
|
||||
|
||||
patterns := redactor.Patterns()
|
||||
patternList := make([]map[string]any, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
patternList = append(patternList, map[string]any{
|
||||
"id": p.ID,
|
||||
"description": p.Description,
|
||||
"action": string(p.Action),
|
||||
"disabled": p.Disabled,
|
||||
"max_match_length": p.MaxMatchLength,
|
||||
})
|
||||
}
|
||||
|
||||
appCfg := app.ApplicationConfig()
|
||||
models := []map[string]any{}
|
||||
for _, cfg := range app.ModelConfigLoader().GetAllModelsConfigs() {
|
||||
// Only list models PII filtering can actually apply to (reachable
|
||||
// through a text-accepting endpoint with a PII adapter wired).
|
||||
// Skips VAD/STT/embedding/image-only models and the token_classify
|
||||
// detector models themselves, which are the filters, not consumers.
|
||||
if !cfg.PIIFilterApplies() {
|
||||
continue
|
||||
}
|
||||
explicit := cfg.PII.Enabled != nil
|
||||
ownDetectors := cfg.PIIDetectors()
|
||||
// Resolve through the shared policy so the table reflects the EFFECTIVE
|
||||
// state, including the instance-wide default detector — what the
|
||||
// request path actually does.
|
||||
enabled, detectors := app.ResolvePIIPolicy(&cfg)
|
||||
|
||||
entry := map[string]any{
|
||||
"name": cfg.Name,
|
||||
"backend": cfg.Backend,
|
||||
"enabled": cfg.PIIIsEnabled(),
|
||||
"overrides": cfg.PIIPatternOverrides(),
|
||||
"enabled": enabled,
|
||||
"detectors": detectors,
|
||||
"explicit": explicit,
|
||||
// Why is this on? backend default (cloud-proxy) vs an explicit YAML
|
||||
// toggle. Helps admins understand the resolved state without
|
||||
// reading source.
|
||||
"default_for_backend": !explicit && cfg.Backend == "cloud-proxy",
|
||||
// The detectors came from the global default, not this model's YAML.
|
||||
"detectors_from_default": enabled && len(ownDetectors) == 0 && len(detectors) > 0,
|
||||
}
|
||||
// explicit-set tells the UI whether the resolved state came
|
||||
// from the YAML or the backend-prefix default. Helps admins
|
||||
// understand "why is this on?" without reading source.
|
||||
entry["explicit"] = cfg.PII.Enabled != nil
|
||||
entry["default_for_backend"] = cfg.Backend == "cloud-proxy"
|
||||
models = append(models, entry)
|
||||
}
|
||||
|
||||
// Detector models: the token_classify "filter" models themselves (NER and
|
||||
// in-process pattern matchers), which PIIFilterApplies deliberately omits
|
||||
// from the consumer list above. The Filtering tab renders these as a table
|
||||
// with a per-row toggle marking membership in the instance-wide default
|
||||
// detector set, so admins manage defaults without retyping model names.
|
||||
defaultSet := map[string]bool{}
|
||||
for _, d := range appCfg.PIIDefaultDetectors {
|
||||
defaultSet[d] = true
|
||||
}
|
||||
detectorModels := []map[string]any{}
|
||||
for _, cfg := range app.ModelConfigLoader().GetAllModelsConfigs() {
|
||||
if !cfg.HasUsecases(config.FLAG_TOKEN_CLASSIFY) {
|
||||
continue
|
||||
}
|
||||
typ := "ner"
|
||||
if cfg.IsPatternDetector() {
|
||||
typ = "pattern"
|
||||
}
|
||||
detectorModels = append(detectorModels, map[string]any{
|
||||
"name": cfg.Name,
|
||||
"backend": cfg.Backend,
|
||||
"type": typ,
|
||||
// Whether this detector is in the instance-wide default set.
|
||||
"default": defaultSet[cfg.Name],
|
||||
})
|
||||
delete(defaultSet, cfg.Name)
|
||||
}
|
||||
// Surface any default detector that names a model that is no longer loaded
|
||||
// (or lost the token_classify usecase) so the admin can still toggle it off.
|
||||
for name := range defaultSet {
|
||||
detectorModels = append(detectorModels, map[string]any{
|
||||
"name": name,
|
||||
"backend": "",
|
||||
"type": "unknown",
|
||||
"default": true,
|
||||
"missing": true,
|
||||
})
|
||||
}
|
||||
|
||||
recentCount := 0
|
||||
if app.PIIEvents() != nil {
|
||||
if n, err := app.PIIEvents().Count(context.Background()); err == nil {
|
||||
@@ -356,8 +388,10 @@ func buildPIIStatus(app *application.Application) map[string]any {
|
||||
return map[string]any{
|
||||
"enabled_globally": true,
|
||||
"default_enabled_for_backends": []string{"cloud-proxy"},
|
||||
"patterns": patternList,
|
||||
"models": models,
|
||||
"detector_models": detectorModels,
|
||||
"recent_event_count": recentCount,
|
||||
// Instance-wide default policy (the Default PII policy editor).
|
||||
"default_detectors": appCfg.PIIDefaultDetectors,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,13 +10,15 @@ import (
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/ollama"
|
||||
"github.com/mudler/LocalAI/core/http/middleware"
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piiadapter"
|
||||
"github.com/mudler/LocalAI/pkg/distributedhdr"
|
||||
)
|
||||
|
||||
func RegisterOllamaRoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
application *application.Application) {
|
||||
|
||||
application *application.Application,
|
||||
) {
|
||||
traceMiddleware := middleware.TraceMiddleware(application)
|
||||
usageMiddleware := middleware.UsageMiddleware(application.StatsRecorder(), application.FallbackUser())
|
||||
nodeHeaderMiddleware := middleware.ExposeNodeHeader(application.ApplicationConfig())
|
||||
@@ -35,6 +37,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OllamaChatRequest) }),
|
||||
setOllamaChatRequestContext(application.ApplicationConfig()),
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.OllamaChat(), application.FallbackUser(), pii.WithNERResolver(application.PIINERResolver()), pii.WithPolicyResolver(application.PIIPolicyResolver())),
|
||||
}
|
||||
app.POST("/api/chat", chatHandler, chatMiddleware...)
|
||||
|
||||
@@ -52,6 +55,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OllamaGenerateRequest) }),
|
||||
setOllamaGenerateRequestContext(application.ApplicationConfig()),
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.OllamaGenerate(), application.FallbackUser(), pii.WithNERResolver(application.PIINERResolver()), pii.WithPolicyResolver(application.PIIPolicyResolver())),
|
||||
}
|
||||
app.POST("/api/generate", generateHandler, generateMiddleware...)
|
||||
|
||||
@@ -67,6 +71,7 @@ func RegisterOllamaRoutes(app *echo.Echo,
|
||||
traceMiddleware,
|
||||
re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)),
|
||||
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OllamaEmbedRequest) }),
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.OllamaEmbed(), application.FallbackUser(), pii.WithNERResolver(application.PIINERResolver()), pii.WithPolicyResolver(application.PIIPolicyResolver())),
|
||||
}
|
||||
app.POST("/api/embed", embedHandler, embedMiddleware...)
|
||||
app.POST("/api/embeddings", embedHandler, embedMiddleware...)
|
||||
|
||||
@@ -16,7 +16,8 @@ import (
|
||||
|
||||
func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
re *middleware.RequestExtractor,
|
||||
application *application.Application) {
|
||||
application *application.Application,
|
||||
) {
|
||||
// openAI compatible API endpoint
|
||||
traceMiddleware := middleware.TraceMiddleware(application)
|
||||
usageMiddleware := middleware.UsageMiddleware(application.StatsRecorder(), application.FallbackUser())
|
||||
@@ -42,7 +43,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
}
|
||||
|
||||
// chat
|
||||
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, application.LocalAIAssistant(), application.PIIRedactor(), application.PIIEvents())
|
||||
chatHandler := openai.ChatEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), natsClient, application.LocalAIAssistant())
|
||||
chatMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
@@ -91,7 +92,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
// configs honour the routed target (e.g., a router fans out to
|
||||
// claude-strict; that model's pii block applies, not the
|
||||
// router model's).
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.OpenAI(), application.FallbackUser()),
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.OpenAI(), application.FallbackUser(), pii.WithNERResolver(application.PIINERResolver()), pii.WithPolicyResolver(application.PIIPolicyResolver())),
|
||||
}
|
||||
app.POST("/v1/chat/completions", chatHandler, chatMiddleware...)
|
||||
app.POST("/chat/completions", chatHandler, chatMiddleware...)
|
||||
@@ -112,12 +113,13 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.OpenAICompletion(), application.FallbackUser(), pii.WithNERResolver(application.PIINERResolver()), pii.WithPolicyResolver(application.PIIPolicyResolver())),
|
||||
}
|
||||
app.POST("/v1/edits", editHandler, editMiddleware...)
|
||||
app.POST("/edits", editHandler, editMiddleware...)
|
||||
|
||||
// completion
|
||||
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig(), application.PIIRedactor(), application.PIIEvents())
|
||||
completionHandler := openai.CompletionEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.TemplatesEvaluator(), application.ApplicationConfig())
|
||||
completionMiddleware := []echo.MiddlewareFunc{
|
||||
nodeHeaderMiddleware,
|
||||
usageMiddleware,
|
||||
@@ -133,6 +135,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.OpenAICompletion(), application.FallbackUser(), pii.WithNERResolver(application.PIINERResolver()), pii.WithPolicyResolver(application.PIIPolicyResolver())),
|
||||
}
|
||||
app.POST("/v1/completions", completionHandler, completionMiddleware...)
|
||||
app.POST("/completions", completionHandler, completionMiddleware...)
|
||||
@@ -155,6 +158,7 @@ func RegisterOpenAIRoutes(app *echo.Echo,
|
||||
return next(c)
|
||||
}
|
||||
},
|
||||
pii.RequestMiddleware(application.PIIRedactor(), application.PIIEvents(), piiadapter.OpenAICompletion(), application.FallbackUser(), pii.WithNERResolver(application.PIINERResolver()), pii.WithPolicyResolver(application.PIIPolicyResolver())),
|
||||
}
|
||||
app.POST("/v1/embeddings", embeddingHandler, embeddingMiddleware...)
|
||||
app.POST("/embeddings", embeddingHandler, embeddingMiddleware...)
|
||||
|
||||
@@ -6,58 +6,30 @@ import (
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// RegisterPIIRoutes wires the read-only routing-PII endpoints. They
|
||||
// surface (a) the active pattern set so admins can verify what is
|
||||
// being filtered, (b) the recent PIIEvent log so they can audit what
|
||||
// has been redacted, and (c) a dry-run "test" endpoint so an admin
|
||||
// can paste candidate text and see what the redactor would do without
|
||||
// sending a real request.
|
||||
// RegisterPIIRoutes wires the read-only PII audit endpoint. The
|
||||
// detection itself runs request-side from the chat middleware
|
||||
// (routes/openai.go) and the MITM input path, driven by per-model NER
|
||||
// detectors; this endpoint is observation-side only.
|
||||
//
|
||||
// The redactor itself runs from the chat middleware in routes/openai.go;
|
||||
// these endpoints are observation- and configuration-side only.
|
||||
// The legacy regex tier (pattern catalogue + per-pattern action editor
|
||||
// + dry-run/decide oracles) was removed — policy now lives on each
|
||||
// detector model's pii_detection block, so there is nothing global to
|
||||
// list or mutate here.
|
||||
func RegisterPIIRoutes(e *echo.Echo, app *application.Application) {
|
||||
if app.PIIRedactor() == nil {
|
||||
stub := func(c echo.Context) error {
|
||||
if app.PIIEvents() == nil {
|
||||
e.GET("/api/pii/events", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusServiceUnavailable, map[string]string{
|
||||
"error": "PII filter is disabled (--disable-pii)",
|
||||
"error": "PII subsystem unavailable",
|
||||
})
|
||||
}
|
||||
e.GET("/api/pii/patterns", stub)
|
||||
e.GET("/api/pii/events", stub)
|
||||
e.POST("/api/pii/test", stub)
|
||||
e.POST("/api/pii/decide", stub)
|
||||
e.POST("/api/pii/patterns/persist", stub)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// GetPIIPatternsEndpoint godoc
|
||||
// @Summary List the active PII patterns
|
||||
// @Description Returns the configured pattern set with their actions. Available without auth.
|
||||
// @Tags pii
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/pii/patterns [get]
|
||||
e.GET("/api/pii/patterns", func(c echo.Context) error {
|
||||
patterns := app.PIIRedactor().Patterns()
|
||||
out := make([]map[string]any, 0, len(patterns))
|
||||
for _, p := range patterns {
|
||||
out = append(out, map[string]any{
|
||||
"id": p.ID,
|
||||
"description": p.Description,
|
||||
"action": string(p.Action),
|
||||
"disabled": p.Disabled,
|
||||
"max_match_length": p.MaxMatchLength,
|
||||
})
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{"patterns": out})
|
||||
})
|
||||
|
||||
// GetPIIEventsEndpoint godoc
|
||||
// @Summary List recent middleware events
|
||||
// @Description The event log is shared between the PII filter and the MITM proxy: PII redactions, proxy_connect (intercept decisions), and proxy_traffic (per-request byte counts) all flow through the same store. Filter by kind to narrow the view. Admin-only when auth is on; available to the local user in single-user mode.
|
||||
@@ -65,8 +37,9 @@ func RegisterPIIRoutes(e *echo.Echo, app *application.Application) {
|
||||
// @Produce json
|
||||
// @Param correlation_id query string false "Correlation ID join key"
|
||||
// @Param user_id query string false "User id"
|
||||
// @Param pattern_id query string false "Pattern id (e.g. email, ssn)"
|
||||
// @Param pattern_id query string false "Detector group id (e.g. ner:EMAIL, pattern:ANTHROPIC_KEY)"
|
||||
// @Param kind query string false "Event kind: pii | proxy_connect | proxy_traffic"
|
||||
// @Param origin query string false "Redaction origin: middleware | proxy | pii_analyze | pii_redact"
|
||||
// @Param limit query int false "Max events" default(100)
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/pii/events [get]
|
||||
@@ -91,6 +64,7 @@ func RegisterPIIRoutes(e *echo.Echo, app *application.Application) {
|
||||
UserID: c.QueryParam("user_id"),
|
||||
PatternID: c.QueryParam("pattern_id"),
|
||||
Kind: pii.EventKind(c.QueryParam("kind")),
|
||||
Origin: c.QueryParam("origin"),
|
||||
Limit: limit,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -99,162 +73,11 @@ func RegisterPIIRoutes(e *echo.Echo, app *application.Application) {
|
||||
return c.JSON(http.StatusOK, map[string]any{"events": events})
|
||||
})
|
||||
|
||||
// PostPIITestEndpoint godoc
|
||||
// @Summary Dry-run the PII redactor against text
|
||||
// @Description Useful for admins tuning patterns. Returns the redacted text, matched spans, and whether the input would have been blocked.
|
||||
// @Tags pii
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param body body map[string]string true "JSON {\"text\":\"...\"}"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/pii/test [post]
|
||||
e.POST("/api/pii/test", func(c echo.Context) error {
|
||||
var body struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if err := c.Bind(&body); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid JSON"})
|
||||
}
|
||||
res := app.PIIRedactor().Redact(body.Text)
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"redacted": res.Redacted,
|
||||
"spans": res.Spans,
|
||||
"blocked": res.Blocked,
|
||||
"masked": res.Masked,
|
||||
})
|
||||
})
|
||||
|
||||
// POST /api/pii/decide — programmatic PII decision oracle for
|
||||
// external routers. Returns findings + suggested action without
|
||||
// mutating the caller's request or recording an audit event.
|
||||
// Production hot path — admin-only, matching /api/pii/events.
|
||||
decideHandler := localai.PIIDecideEndpoint(app.PIIRedactor())
|
||||
e.POST("/api/pii/decide", func(c echo.Context) error {
|
||||
viewer := resolveUsageUser(c, app)
|
||||
if viewer == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"})
|
||||
}
|
||||
if viewer.Role != auth.RoleAdmin {
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"error": "admin access required"})
|
||||
}
|
||||
return decideHandler(c)
|
||||
})
|
||||
|
||||
// PutPIIPatternActionEndpoint godoc
|
||||
// @Summary Change a pattern's action in-process
|
||||
// @Description Mutates the named pattern's action (mask|block|allow). Transient — restored to YAML defaults on restart. Admin-only.
|
||||
// @Tags pii
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "Pattern id"
|
||||
// @Param body body map[string]string true "JSON {\"action\":\"mask|block|allow\"}"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/pii/patterns/{id} [put]
|
||||
e.PUT("/api/pii/patterns/:id", func(c echo.Context) error {
|
||||
viewer := resolveUsageUser(c, app)
|
||||
if viewer == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"})
|
||||
}
|
||||
if viewer.Role != auth.RoleAdmin {
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"error": "admin access required"})
|
||||
}
|
||||
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "pattern id is required"})
|
||||
}
|
||||
// Either field is optional. The body must set at least one;
|
||||
// otherwise the call is a no-op and the client probably means
|
||||
// to PUT something.
|
||||
var body struct {
|
||||
Action *string `json:"action,omitempty"`
|
||||
Disabled *bool `json:"disabled,omitempty"`
|
||||
}
|
||||
if err := c.Bind(&body); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid JSON"})
|
||||
}
|
||||
if body.Action == nil && body.Disabled == nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": "must specify action and/or disabled"})
|
||||
}
|
||||
if body.Action != nil {
|
||||
if err := app.PIIRedactor().SetAction(id, pii.Action(*body.Action)); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
if body.Disabled != nil {
|
||||
if err := app.PIIRedactor().SetDisabled(id, *body.Disabled); err != nil {
|
||||
return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"id": id,
|
||||
"action": body.Action,
|
||||
"disabled": body.Disabled,
|
||||
"persisted": false,
|
||||
})
|
||||
})
|
||||
|
||||
// PostPIIPatternsPersistEndpoint godoc
|
||||
// @Summary Persist current pattern overrides to disk
|
||||
// @Description Snapshots the live redactor's per-pattern (action, disabled) state into runtime_settings.json so the next process start re-applies it. Admin-only. Pairs with PUT /api/pii/patterns/:id which only mutates in-process.
|
||||
// @Tags pii
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/pii/patterns/persist [post]
|
||||
e.POST("/api/pii/patterns/persist", func(c echo.Context) error {
|
||||
viewer := resolveUsageUser(c, app)
|
||||
if viewer == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"})
|
||||
}
|
||||
if viewer.Role != auth.RoleAdmin {
|
||||
return c.JSON(http.StatusForbidden, map[string]string{"error": "admin access required"})
|
||||
}
|
||||
|
||||
appCfg := app.ApplicationConfig()
|
||||
existing, err := appCfg.ReadPersistedSettings()
|
||||
if err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "read settings: " + err.Error()})
|
||||
}
|
||||
// Only persist patterns whose live state differs from the YAML
|
||||
// default — that way an operator can compare runtime_settings.json
|
||||
// at a glance and see only the deltas they applied.
|
||||
defaults, dErr := pii.LoadConfig(appCfg.PIIConfigPath)
|
||||
if dErr != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "reload defaults: " + dErr.Error()})
|
||||
}
|
||||
defaultByID := make(map[string]pii.Pattern, len(defaults))
|
||||
for _, d := range defaults {
|
||||
defaultByID[d.ID] = d
|
||||
}
|
||||
overrides := map[string]config.PIIPatternRuntimeOverride{}
|
||||
for _, p := range app.PIIRedactor().Patterns() {
|
||||
d, ok := defaultByID[p.ID]
|
||||
ov := config.PIIPatternRuntimeOverride{}
|
||||
changed := false
|
||||
if !ok || p.Action != d.Action {
|
||||
action := string(p.Action)
|
||||
ov.Action = &action
|
||||
changed = true
|
||||
}
|
||||
if !ok || p.Disabled != d.Disabled {
|
||||
disabled := p.Disabled
|
||||
ov.Disabled = &disabled
|
||||
changed = true
|
||||
}
|
||||
if changed {
|
||||
overrides[p.ID] = ov
|
||||
}
|
||||
}
|
||||
existing.PIIPatternOverrides = &overrides
|
||||
if err := appCfg.WritePersistedSettings(existing); err != nil {
|
||||
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "write settings: " + err.Error()})
|
||||
}
|
||||
// Mirror onto the live ApplicationConfig so a subsequent reload
|
||||
// without a process restart sees the same map.
|
||||
appCfg.PIIPatternOverrides = overrides
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"persisted": true,
|
||||
"override_count": len(overrides),
|
||||
})
|
||||
})
|
||||
// Synchronous redaction service: scan a string and either report the
|
||||
// detected entities (analyze) or apply the policy (redact). Unlike the
|
||||
// admin-only events log above, these are an inference-tier service gated
|
||||
// by the pii_filter feature (any authenticated user), so a client can use
|
||||
// LocalAI's PII engine without routing a full chat request through it.
|
||||
e.POST("/api/pii/analyze", localai.PIIAnalyzeEndpoint(app))
|
||||
e.POST("/api/pii/redact", localai.PIIRedactEndpoint(app))
|
||||
}
|
||||
|
||||
@@ -56,6 +56,7 @@ var usecaseFilters = map[string]config.ModelConfigUsecase{
|
||||
config.UsecaseAudioTransform: config.FLAG_AUDIO_TRANSFORM,
|
||||
config.UsecaseDiarization: config.FLAG_DIARIZATION,
|
||||
config.UsecaseRealtimeAudio: config.FLAG_REALTIME_AUDIO,
|
||||
config.UsecaseTokenClassify: config.FLAG_TOKEN_CLASSIFY,
|
||||
}
|
||||
|
||||
// extractHFRepo tries to find a HuggingFace repo ID from model overrides or URLs.
|
||||
|
||||
74
core/schema/pii.go
Normal file
74
core/schema/pii.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package schema
|
||||
|
||||
// PIIAnalyzeRequest is the body for POST /api/pii/analyze and
|
||||
// POST /api/pii/redact. The two endpoints share a request shape; only the
|
||||
// response differs (analyze never mutates text, redact applies policy).
|
||||
//
|
||||
// Detector selection is one of two ways:
|
||||
// - Detectors: explicit detector-model names (the primary path).
|
||||
// - Model: a consuming model name whose effective PII policy is used when
|
||||
// Detectors is empty — "what would this model do with this text?". The
|
||||
// policy resolves exactly as for the inline middleware: the model's own
|
||||
// pii.detectors, else the instance-wide pii_default_detectors, and
|
||||
// nothing when the model has PII disabled.
|
||||
//
|
||||
// One of the two must resolve to at least one detector, else the call is a
|
||||
// 400 — including a PII-enabled model with no detectors anywhere: the
|
||||
// middleware would scan nothing, and saying so loudly beats implying a clean
|
||||
// scan. The detection policy (mask/block/allow per entity group, min score)
|
||||
// lives on each detector model's own pii_detection block, exactly as for the
|
||||
// inline chat middleware.
|
||||
type PIIAnalyzeRequest struct {
|
||||
// Text is the string to scan. Bounded only by the server's global HTTP
|
||||
// body limit.
|
||||
Text string `json:"text"`
|
||||
// Detectors names the detector models to run (NER and/or pattern). Takes
|
||||
// precedence over Model.
|
||||
Detectors []string `json:"detectors,omitempty"`
|
||||
// Model is a consuming model whose effective PII policy (own
|
||||
// pii.detectors, else the instance default detectors; PII must be
|
||||
// enabled) is used when Detectors is empty.
|
||||
Model string `json:"model,omitempty"`
|
||||
// Reveal includes the per-entity hash_prefix in the response. Honoured
|
||||
// only for admin callers; ignored otherwise. The raw matched value is
|
||||
// never returned regardless.
|
||||
Reveal bool `json:"reveal,omitempty"`
|
||||
}
|
||||
|
||||
// PIIEntity is one detected span. EntityType is the detector group (e.g.
|
||||
// "EMAIL", "ANTHROPIC_KEY"); Source is the detector tier that produced it
|
||||
// ("ner" or "pattern"). Start/End are half-open byte offsets into the request
|
||||
// Text. Action is the policy action that fired after the overlap merge
|
||||
// (mask | block | allow). HashPrefix is present only for admin + reveal.
|
||||
type PIIEntity struct {
|
||||
EntityType string `json:"entity_type"`
|
||||
Source string `json:"source"`
|
||||
Start int `json:"start"`
|
||||
End int `json:"end"`
|
||||
Score float32 `json:"score"`
|
||||
Action string `json:"action"`
|
||||
HashPrefix string `json:"hash_prefix,omitempty"`
|
||||
}
|
||||
|
||||
// PIIAnalyzeResponse is returned by POST /api/pii/analyze (always 200). It
|
||||
// reports detections without mutating the text. Blocked is true when at
|
||||
// least one entity's action is block — i.e. the redact endpoint would reject
|
||||
// this text.
|
||||
type PIIAnalyzeResponse struct {
|
||||
Entities []PIIEntity `json:"entities"`
|
||||
Blocked bool `json:"blocked"`
|
||||
CorrelationID string `json:"correlation_id,omitempty"`
|
||||
}
|
||||
|
||||
// PIIRedactResponse is returned by POST /api/pii/redact when nothing blocks
|
||||
// (200). RedactedText is the input with masked spans replaced; Masked is true
|
||||
// when at least one span was replaced. When a block action fires the endpoint
|
||||
// returns 400 instead (with an error of type "pii_blocked" and the offending
|
||||
// entities), never a redacted body.
|
||||
type PIIRedactResponse struct {
|
||||
RedactedText string `json:"redacted_text"`
|
||||
Entities []PIIEntity `json:"entities"`
|
||||
Blocked bool `json:"blocked"`
|
||||
Masked bool `json:"masked"`
|
||||
CorrelationID string `json:"correlation_id,omitempty"`
|
||||
}
|
||||
@@ -10,8 +10,6 @@ import (
|
||||
"github.com/labstack/echo/v4"
|
||||
corebackend "github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
pkggrpc "github.com/mudler/LocalAI/pkg/grpc"
|
||||
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
@@ -19,41 +17,14 @@ import (
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// BuildStreamFilter constructs the per-request streaming PII filter
|
||||
// for a cloud-proxy forward. Returns nil when the request isn't
|
||||
// streaming, PII is disabled for this model, or no redactor is wired
|
||||
// up — callers pass the result through unchanged. correlationID is
|
||||
// caller-supplied because the OpenAI and Anthropic endpoints read it
|
||||
// from different headers.
|
||||
func BuildStreamFilter(c echo.Context, cfg *config.ModelConfig, isStream bool, piiRedactor *pii.Redactor, piiEvents pii.EventStore, correlationID string) *pii.StreamFilter {
|
||||
if !isStream || piiRedactor == nil || !cfg.PIIIsEnabled() {
|
||||
return nil
|
||||
}
|
||||
userID := ""
|
||||
if u := auth.GetUser(c); u != nil {
|
||||
userID = u.ID
|
||||
}
|
||||
var overrides map[string]pii.Action
|
||||
if raw := cfg.PIIPatternOverrides(); len(raw) > 0 {
|
||||
overrides = make(map[string]pii.Action, len(raw))
|
||||
for ovid, action := range raw {
|
||||
switch pii.Action(action) {
|
||||
case pii.ActionMask, pii.ActionBlock, pii.ActionAllow:
|
||||
overrides[ovid] = pii.Action(action)
|
||||
}
|
||||
}
|
||||
}
|
||||
return pii.NewStreamFilter(piiRedactor, overrides, piiEvents, correlationID, userID)
|
||||
}
|
||||
|
||||
// ForwardViaBackend loads the cloud-proxy gRPC backend, ships the
|
||||
// request via the Forward RPC, and pumps the response back to the
|
||||
// client through the SSE-aware PII pipeline.
|
||||
// client. PII redaction runs request-side (the NER middleware + MITM
|
||||
// input path); the response is forwarded unmodified.
|
||||
func ForwardViaBackend(
|
||||
c echo.Context,
|
||||
cfg *config.ModelConfig,
|
||||
body []byte,
|
||||
filter *pii.StreamFilter,
|
||||
loader *model.ModelLoader,
|
||||
appConfig *config.ApplicationConfig,
|
||||
) (resultErr error) {
|
||||
@@ -176,7 +147,7 @@ func ForwardViaBackend(
|
||||
return passthroughError(c, statusCode, contentType, bodyReader)
|
||||
}
|
||||
if isStream {
|
||||
return forwardStream(c, bodyReader, cfg.Proxy.Provider, filter)
|
||||
return forwardStream(c, bodyReader)
|
||||
}
|
||||
return forwardBuffered(c, statusCode, contentType, bodyReader)
|
||||
}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
package cloudproxy
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("BuildStreamFilter", func() {
|
||||
var (
|
||||
c echo.Context
|
||||
cfg *config.ModelConfig
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
e := echo.New()
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
c = e.NewContext(req, rec)
|
||||
piiOn := true
|
||||
cfg = &config.ModelConfig{
|
||||
Backend: "cloud-proxy",
|
||||
PII: config.PIIConfig{Enabled: &piiOn},
|
||||
}
|
||||
})
|
||||
|
||||
// Three guards must each independently force a nil return — proves
|
||||
// the gate is a logical AND, not an order-dependent short-circuit
|
||||
// that silently activates one branch.
|
||||
It("returns nil when isStream is false", func() {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
r := pii.NewRedactor(patterns)
|
||||
Expect(BuildStreamFilter(c, cfg, false, r, nil, "corr-1")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil when piiRedactor is nil", func() {
|
||||
Expect(BuildStreamFilter(c, cfg, true, nil, nil, "corr-1")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns nil when the model has PII disabled", func() {
|
||||
piiOff := false
|
||||
cfg.PII.Enabled = &piiOff
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
r := pii.NewRedactor(patterns)
|
||||
Expect(BuildStreamFilter(c, cfg, true, r, nil, "corr-1")).To(BeNil())
|
||||
})
|
||||
|
||||
It("returns a configured filter when all preconditions hold", func() {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
r := pii.NewRedactor(patterns)
|
||||
store := pii.NewMemoryEventStore(8)
|
||||
filter := BuildStreamFilter(c, cfg, true, r, store, "corr-xyz")
|
||||
Expect(filter).NotTo(BeNil())
|
||||
})
|
||||
|
||||
// Empty correlationID is allowed — some entry points don't have one.
|
||||
// The filter must still construct so the stream can flow.
|
||||
It("constructs a filter even when correlationID is empty", func() {
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
r := pii.NewRedactor(patterns)
|
||||
Expect(BuildStreamFilter(c, cfg, true, r, nil, "")).NotTo(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piiadapter"
|
||||
"github.com/mudler/LocalAI/pkg/httpclient"
|
||||
@@ -24,8 +23,14 @@ import (
|
||||
|
||||
// PIIHandlerOptions configures NewPIIHandler.
|
||||
type PIIHandlerOptions struct {
|
||||
// Redactor is the regex PII redactor. nil disables redaction.
|
||||
Redactor *pii.Redactor
|
||||
// DetectorsByHost maps an intercepted host (lower-cased) to the NER
|
||||
// detector configs that should scan request bodies bound for it. The
|
||||
// configs are resolved at listener-start from each host's owning
|
||||
// model's pii.detectors + the detector models' pii_detection policy
|
||||
// (a model-config edit needs a MITM restart, as hosts already do). A
|
||||
// host absent from the map (or with an empty slice) is forwarded
|
||||
// unredacted. Detector errors at request time fail closed.
|
||||
DetectorsByHost map[string][]pii.NERConfig
|
||||
|
||||
// EventStore receives PIIEvent rows. nil discards events.
|
||||
EventStore pii.EventStore
|
||||
@@ -42,13 +47,6 @@ type PIIHandlerOptions struct {
|
||||
// upstream URL. Identity by default; tests inject a httptest
|
||||
// listener address.
|
||||
DialHost func(host string) string
|
||||
|
||||
// HostsWithPIIDisabled lists destination hosts whose request
|
||||
// bodies should NOT run through the redactor. TLS termination,
|
||||
// upstream forwarding, and audit events still happen — only the
|
||||
// regex pass is bypassed. Useful for telemetry/probe endpoints
|
||||
// whose bodies aren't PII-shaped.
|
||||
HostsWithPIIDisabled []string
|
||||
}
|
||||
|
||||
func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
|
||||
@@ -76,16 +74,9 @@ func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
|
||||
dialHost = func(h string) string { return h }
|
||||
}
|
||||
|
||||
patternAction := map[string]pii.Action{}
|
||||
if opts.Redactor != nil {
|
||||
for _, p := range opts.Redactor.Patterns() {
|
||||
patternAction[p.ID] = p.Action
|
||||
}
|
||||
}
|
||||
|
||||
piiDisabled := make(map[string]bool, len(opts.HostsWithPIIDisabled))
|
||||
for _, h := range opts.HostsWithPIIDisabled {
|
||||
piiDisabled[strings.ToLower(strings.TrimSpace(h))] = true
|
||||
detectorsByHost := make(map[string][]pii.NERConfig, len(opts.DetectorsByHost))
|
||||
for h, cfgs := range opts.DetectorsByHost {
|
||||
detectorsByHost[strings.ToLower(strings.TrimSpace(h))] = cfgs
|
||||
}
|
||||
|
||||
d := &piiDispatcher{
|
||||
@@ -96,26 +87,22 @@ func NewPIIHandler(opts PIIHandlerOptions) InterceptHandler {
|
||||
// API keys such as Anthropic's x-api-key, which Go does NOT
|
||||
// strip on cross-host redirects — to an unvetted host. Surface
|
||||
// it as an error (handled as a 502) instead.
|
||||
client: httpclient.New(httpclient.WithTransport(transport)),
|
||||
redactor: opts.Redactor,
|
||||
store: opts.EventStore,
|
||||
patternAction: patternAction,
|
||||
corrHeader: corrHeader,
|
||||
dialHost: dialHost,
|
||||
piiDisabled: piiDisabled,
|
||||
client: httpclient.New(httpclient.WithTransport(transport)),
|
||||
detectorsByHost: detectorsByHost,
|
||||
store: opts.EventStore,
|
||||
corrHeader: corrHeader,
|
||||
dialHost: dialHost,
|
||||
}
|
||||
return d.serve
|
||||
}
|
||||
|
||||
type piiDispatcher struct {
|
||||
client *http.Client
|
||||
redactor *pii.Redactor
|
||||
store pii.EventStore
|
||||
patternAction map[string]pii.Action
|
||||
corrHeader string
|
||||
dialHost func(host string) string
|
||||
piiDisabled map[string]bool
|
||||
eventSeq atomic.Uint64
|
||||
client *http.Client
|
||||
detectorsByHost map[string][]pii.NERConfig
|
||||
store pii.EventStore
|
||||
corrHeader string
|
||||
dialHost func(host string) string
|
||||
eventSeq atomic.Uint64
|
||||
}
|
||||
|
||||
func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host string) {
|
||||
@@ -144,11 +131,17 @@ func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host strin
|
||||
}
|
||||
|
||||
shape := classifyRequestShape(host, r.URL.Path)
|
||||
if d.redactor != nil && shape != shapeUnknown && !d.piiDisabled[strings.ToLower(host)] {
|
||||
redacted, blocked, err := d.redactRequest(body, shape, correlationID)
|
||||
cfgs := d.detectorsByHost[strings.ToLower(host)]
|
||||
if len(cfgs) > 0 && shape != shapeUnknown {
|
||||
redacted, blocked, err := d.redactRequest(r.Context(), body, shape, cfgs, correlationID)
|
||||
switch {
|
||||
case err != nil:
|
||||
xlog.Debug("mitm: redact request failed; forwarding unchanged", "host", host, "path", r.URL.Path, "error", err)
|
||||
// Fail closed: a detector outage must not silently forward the
|
||||
// request unredacted — the operator configured this host's
|
||||
// model with detectors precisely to catch this PII.
|
||||
xlog.Error("mitm: NER redaction failed; blocking request (fail-closed)", "host", host, "path", r.URL.Path, "error", err)
|
||||
writePIIBlocked(w, correlationID)
|
||||
return
|
||||
case blocked:
|
||||
writePIIBlocked(w, correlationID)
|
||||
return
|
||||
@@ -185,12 +178,10 @@ func (d *piiDispatcher) serve(w http.ResponseWriter, r *http.Request, host strin
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// Response/output redaction is out of scope for now — the MITM proxy
|
||||
// only scans request bodies (input). SSE responses pass through
|
||||
// unmodified.
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if shape != shapeUnknown && d.redactor != nil && isSSE(contentType) {
|
||||
d.streamWithPII(w, resp.Body, shape, correlationID)
|
||||
return
|
||||
}
|
||||
|
||||
if isSSE(contentType) {
|
||||
flusher, _ := w.(http.Flusher)
|
||||
buf := make([]byte, 32*1024)
|
||||
@@ -232,7 +223,7 @@ func classifyRequestShape(host, path string) requestShape {
|
||||
return shapeUnknown
|
||||
}
|
||||
|
||||
func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlationID string) ([]byte, bool, error) {
|
||||
func (d *piiDispatcher) redactRequest(ctx context.Context, body []byte, shape requestShape, cfgs []pii.NERConfig, correlationID string) ([]byte, bool, error) {
|
||||
var parsed any
|
||||
var adapter pii.Adapter
|
||||
switch shape {
|
||||
@@ -259,13 +250,21 @@ func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlati
|
||||
return body, false, nil
|
||||
}
|
||||
|
||||
// One scan over the joined messages so the NER tier keeps
|
||||
// conversational context (see pii.RedactNERSegments); results map
|
||||
// back per message with local offsets.
|
||||
segTexts := make([]string, len(texts))
|
||||
for i, st := range texts {
|
||||
segTexts[i] = st.Text
|
||||
}
|
||||
results, err := pii.RedactNERSegments(ctx, segTexts, cfgs)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("ner detect: %w", err)
|
||||
}
|
||||
|
||||
updates := make([]pii.ScannedText, 0, len(texts))
|
||||
blocked := false
|
||||
for _, st := range texts {
|
||||
if st.Text == "" {
|
||||
continue
|
||||
}
|
||||
res := d.redactor.RedactWithOverrides(st.Text, nil)
|
||||
for i, res := range results {
|
||||
if len(res.Spans) == 0 {
|
||||
continue
|
||||
}
|
||||
@@ -273,7 +272,7 @@ func (d *piiDispatcher) redactRequest(body []byte, shape requestShape, correlati
|
||||
if res.Blocked {
|
||||
blocked = true
|
||||
}
|
||||
updates = append(updates, pii.ScannedText{Index: st.Index, Text: res.Redacted})
|
||||
updates = append(updates, pii.ScannedText{Index: texts[i].Index, Text: res.Redacted})
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
@@ -295,13 +294,14 @@ func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) {
|
||||
ev := pii.PIIEvent{
|
||||
ID: fmt.Sprintf("mitm_%s_%d", correlationID, d.eventSeq.Add(1)),
|
||||
Kind: pii.KindPII,
|
||||
Origin: pii.OriginProxy,
|
||||
CorrelationID: correlationID,
|
||||
Direction: pii.DirectionIn,
|
||||
PatternID: span.Pattern,
|
||||
ByteOffset: span.Start,
|
||||
Length: span.End - span.Start,
|
||||
HashPrefix: span.HashPrefix,
|
||||
Action: d.patternAction[span.Pattern],
|
||||
Action: span.Action,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := d.store.Record(context.Background(), ev); err != nil {
|
||||
@@ -310,49 +310,6 @@ func (d *piiDispatcher) recordEvents(spans []pii.Span, correlationID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *piiDispatcher) streamWithPII(w http.ResponseWriter, src io.Reader, shape requestShape, correlationID string) {
|
||||
flusher, _ := w.(http.Flusher)
|
||||
filter := pii.NewStreamFilter(d.redactor, nil, d.store, correlationID, "")
|
||||
|
||||
provider := ssewire.OpenAI
|
||||
if shape == shapeAnthropicMessages {
|
||||
provider = ssewire.Anthropic
|
||||
}
|
||||
|
||||
emit := func(s string) {
|
||||
_, _ = w.Write([]byte(s))
|
||||
if flusher != nil {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
scanner := ssewire.NewScanner(src)
|
||||
for scanner.Scan() {
|
||||
ev := scanner.Event()
|
||||
if ssewire.IsTerminalMarker(ev.DataLine, provider) {
|
||||
if residual := filter.Drain(); residual != "" {
|
||||
emit(ssewire.SynthResidualEvent(provider, residual))
|
||||
}
|
||||
emit(ev.Raw)
|
||||
continue
|
||||
}
|
||||
out := ev.Raw
|
||||
if ev.DataLine != "" {
|
||||
rewritten, drop := ssewire.RewritePayload(ev.DataLine, provider, filter)
|
||||
if drop {
|
||||
continue
|
||||
}
|
||||
if rewritten != ev.DataLine {
|
||||
out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1)
|
||||
}
|
||||
}
|
||||
emit(out)
|
||||
}
|
||||
if residual := filter.Drain(); residual != "" {
|
||||
emit(ssewire.SynthResidualEvent(provider, residual))
|
||||
}
|
||||
}
|
||||
|
||||
func writePIIBlocked(w http.ResponseWriter, correlationID string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
@@ -19,34 +19,58 @@ import (
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// startPIITestRig is the same shape as startMITMTestRig but plugs
|
||||
// in the production PII handler instead of the passthrough fixture.
|
||||
// The "host" the client thinks it's reaching is forced to
|
||||
// api.anthropic.com so the request shape classifier matches.
|
||||
// substringDetector is a deterministic pii.NERDetector for tests: it
|
||||
// reports an entity for every occurrence of each configured substring,
|
||||
// with byte offsets into the scanned text. Lets the MITM tests drive
|
||||
// request redaction without a real token-classification backend.
|
||||
type substringDetector struct{ groups map[string]string } // substring -> entity group
|
||||
|
||||
func (d substringDetector) Detect(_ context.Context, text string) ([]pii.NEREntity, error) {
|
||||
var out []pii.NEREntity
|
||||
for sub, group := range d.groups {
|
||||
for idx := 0; ; {
|
||||
i := strings.Index(text[idx:], sub)
|
||||
if i < 0 {
|
||||
break
|
||||
}
|
||||
start := idx + i
|
||||
out = append(out, pii.NEREntity{Group: group, Start: start, End: start + len(sub), Score: 1})
|
||||
idx = start + len(sub)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// testDetectorCfg flags emails (mask) and a known secret token (block).
|
||||
func testDetectorCfg() pii.NERConfig {
|
||||
return pii.NERConfig{
|
||||
Detector: substringDetector{groups: map[string]string{
|
||||
"alice@example.com": "EMAIL",
|
||||
"bob@example.org": "EMAIL",
|
||||
"sk-abcdefghijklmnopqrstuvwxyz1234": "PASSWORD",
|
||||
}},
|
||||
EntityActions: map[string]pii.Action{"EMAIL": pii.ActionMask, "PASSWORD": pii.ActionBlock},
|
||||
}
|
||||
}
|
||||
|
||||
// startPIITestRig plugs the production PII handler into a CONNECT proxy,
|
||||
// with the upstream playing the role of api.anthropic.com. Request
|
||||
// bodies bound for api.anthropic.com run through the NER detector above.
|
||||
func startPIITestRig(upstream http.Handler) (*http.Client, string, *fakeStore, func()) {
|
||||
// Upstream fake — plays the role of api.anthropic.com.
|
||||
ts := httptest.NewTLSServer(upstream)
|
||||
upstreamCertPool := x509.NewCertPool()
|
||||
upstreamCertPool.AddCert(ts.Certificate())
|
||||
upstreamURL, _ := url.Parse(ts.URL)
|
||||
|
||||
// Compiled patterns required for the redactor to actually fire
|
||||
// (DefaultPatterns alone returns Pattern structs without regex).
|
||||
patterns, err := pii.Compile(pii.DefaultPatterns())
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
redactor := pii.NewRedactor(patterns)
|
||||
store := &fakeStore{}
|
||||
|
||||
ca, err := NewInMemoryCA()
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred())
|
||||
|
||||
// DialHost remaps the upstream dial target to the httptest
|
||||
// fake while leaving the classifier-facing host
|
||||
// ("api.anthropic.com") untouched. ServerName=example.com is
|
||||
// what httptest.NewTLSServer issues its cert for.
|
||||
upstreamHost := upstreamURL.Host
|
||||
prodHandler := NewPIIHandler(PIIHandlerOptions{
|
||||
Redactor: redactor,
|
||||
DetectorsByHost: map[string][]pii.NERConfig{
|
||||
"api.anthropic.com": {testDetectorCfg()},
|
||||
},
|
||||
EventStore: store,
|
||||
UpstreamTLS: &tls.Config{
|
||||
RootCAs: upstreamCertPool,
|
||||
@@ -79,8 +103,6 @@ func startPIITestRig(upstream http.Handler) (*http.Client, string, *fakeStore, f
|
||||
srv.Stop()
|
||||
ts.Close()
|
||||
}
|
||||
// We point requests at api.anthropic.com so classifyRequestShape
|
||||
// matches; the wrappedHandler retargets to the upstream fake.
|
||||
return client, "https://api.anthropic.com", store, cleanup
|
||||
}
|
||||
|
||||
@@ -101,7 +123,7 @@ func (s *fakeStore) Close() error { return nil }
|
||||
func (s *fakeStore) recorded() int { return len(s.events) }
|
||||
|
||||
var _ = Describe("PIIHandler", func() {
|
||||
It("redacts request email", func() {
|
||||
It("redacts request email via NER", func() {
|
||||
var receivedBody []byte
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedBody, _ = io.ReadAll(r.Body)
|
||||
@@ -119,15 +141,11 @@ var _ = Describe("PIIHandler", func() {
|
||||
Expect(resp.StatusCode).To(Equal(200))
|
||||
|
||||
Expect(string(receivedBody)).NotTo(ContainSubstring("alice@example.com"), "upstream received unredacted body")
|
||||
Expect(string(receivedBody)).To(ContainSubstring("[REDACTED:email]"), "upstream did not see redaction marker")
|
||||
Expect(string(receivedBody)).To(ContainSubstring("[REDACTED:ner:EMAIL]"), "upstream did not see redaction marker")
|
||||
Expect(store.recorded()).NotTo(BeZero(), "no PIIEvent recorded for the email match")
|
||||
})
|
||||
|
||||
It("refuses to follow an upstream redirect", func() {
|
||||
// A 3xx from the upstream would otherwise be followed, replaying
|
||||
// the request (and its provider API key, e.g. Anthropic's
|
||||
// x-api-key which Go does NOT strip on cross-host redirects) to
|
||||
// the Location host. The refused redirect surfaces as a 502.
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "https://evil.example.com/steal", http.StatusFound)
|
||||
})
|
||||
@@ -142,7 +160,7 @@ var _ = Describe("PIIHandler", func() {
|
||||
Expect(resp.StatusCode).To(Equal(http.StatusBadGateway), "refused redirect must surface as 502, not be followed")
|
||||
})
|
||||
|
||||
It("blocks api key in request", func() {
|
||||
It("blocks a detected secret in the request", func() {
|
||||
upstreamCalled := false
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
upstreamCalled = true
|
||||
@@ -156,46 +174,13 @@ var _ = Describe("PIIHandler", func() {
|
||||
resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body))
|
||||
Expect(err).NotTo(HaveOccurred(), "client.Post")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
Expect(resp.StatusCode).To(Equal(400), "api_key_prefix has Block default")
|
||||
Expect(resp.StatusCode).To(Equal(400), "PASSWORD entity action is block")
|
||||
Expect(upstreamCalled).To(BeFalse(), "upstream was called despite block — proxy should short-circuit")
|
||||
body2, _ := io.ReadAll(resp.Body)
|
||||
Expect(string(body2)).To(ContainSubstring("pii_blocked"))
|
||||
})
|
||||
|
||||
It("streaming redaction", func() {
|
||||
// Anthropic-shape SSE; "alice@" + "example.com" splits the
|
||||
// email across chunks so the StreamFilter has to buffer.
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.WriteHeader(200)
|
||||
flusher := w.(http.Flusher)
|
||||
chunks := []string{
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"contact me at alice@"}}`,
|
||||
`{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"example.com any time"}}`,
|
||||
`{"type":"message_stop"}`,
|
||||
}
|
||||
for _, c := range chunks {
|
||||
_, _ = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", "content_block_delta", c)
|
||||
flusher.Flush()
|
||||
}
|
||||
})
|
||||
|
||||
client, base, _, cleanup := startPIITestRig(upstream)
|
||||
defer cleanup()
|
||||
|
||||
body := `{"model":"claude-3-5-sonnet","max_tokens":100,"stream":true,"messages":[{"role":"user","content":"hi"}]}`
|
||||
resp, err := client.Post(base+"/v1/messages", "application/json", strings.NewReader(body))
|
||||
Expect(err).NotTo(HaveOccurred(), "Post")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
out, _ := io.ReadAll(resp.Body)
|
||||
outStr := string(out)
|
||||
Expect(outStr).NotTo(ContainSubstring("alice@example.com"), "email leaked through MITM stream")
|
||||
Expect(outStr).To(ContainSubstring("[REDACTED:email]"), "redaction marker missing from MITM stream")
|
||||
})
|
||||
|
||||
It("non-chat path passes through", func() {
|
||||
// A path the classifier doesn't recognise (e.g. an OAuth
|
||||
// callback) must forward the body verbatim, no PII parsing.
|
||||
var receivedBody []byte
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
receivedBody, _ = io.ReadAll(r.Body)
|
||||
@@ -216,14 +201,12 @@ var _ = Describe("PIIHandler", func() {
|
||||
|
||||
var _ = Describe("redactRequest", func() {
|
||||
It("handles anthropic shape", func() {
|
||||
patterns, _ := pii.Compile(pii.DefaultPatterns())
|
||||
r := pii.NewRedactor(patterns)
|
||||
body := []byte(`{"model":"claude","max_tokens":10,"messages":[{"role":"user","content":"reach me at bob@example.org"}]}`)
|
||||
|
||||
d := &piiDispatcher{redactor: r, patternAction: map[string]pii.Action{}}
|
||||
out, blocked, err := d.redactRequest(body, shapeAnthropicMessages, "corr-1")
|
||||
d := &piiDispatcher{}
|
||||
out, blocked, err := d.redactRequest(context.Background(), body, shapeAnthropicMessages, []pii.NERConfig{testDetectorCfg()}, "corr-1")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(blocked).To(BeFalse(), "email is mask, not block — blocked should be false")
|
||||
Expect(blocked).To(BeFalse(), "EMAIL is mask, not block — blocked should be false")
|
||||
var parsed map[string]any
|
||||
Expect(json.Unmarshal(out, &parsed)).To(Succeed())
|
||||
msgs := parsed["messages"].([]any)
|
||||
@@ -273,9 +256,6 @@ var _ = Describe("Proxy events", func() {
|
||||
})
|
||||
|
||||
It("tunneled host emits connect event only", func() {
|
||||
// A non-allowlisted CONNECT must record a proxy_connect with
|
||||
// Intercepted=false and NOT a proxy_traffic event (tunneled
|
||||
// bytes never reach the dispatcher).
|
||||
upstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprint(w, "passthrough")
|
||||
})
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// Package cloudproxy stitches the cloud-proxy gRPC backend to the
|
||||
// HTTP edge: model rewrite, body shaping, and SSE-aware PII filtering
|
||||
// on the response. The outbound HTTP request itself lives inside the
|
||||
// cloud-proxy backend binary (backend/go/cloud-proxy), not here — this
|
||||
// package is the core-side glue.
|
||||
// HTTP edge: model rewrite and body shaping. The outbound HTTP request
|
||||
// itself lives inside the cloud-proxy backend binary
|
||||
// (backend/go/cloud-proxy), not here — this package is the core-side
|
||||
// glue. PII redaction runs request-side (the NER middleware + MITM
|
||||
// input path); response/output is forwarded unmodified.
|
||||
package cloudproxy
|
||||
|
||||
import (
|
||||
@@ -10,11 +11,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/services/cloudproxy/ssewire"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
@@ -61,65 +59,30 @@ func forwardBuffered(c echo.Context, statusCode int, contentType string, body io
|
||||
return err
|
||||
}
|
||||
|
||||
// forwardStream applies SSE-aware PII rewriting as the response flows
|
||||
// to the client. provider selects the dialect (openai vs anthropic);
|
||||
// it comes from cfg.Proxy.Provider on the cloud-proxy backend.
|
||||
func forwardStream(c echo.Context, body io.Reader, provider string, filter *pii.StreamFilter) error {
|
||||
// forwardStream relays the upstream SSE response to the client,
|
||||
// flushing per read so events arrive in real time. Response/output PII
|
||||
// redaction is out of scope for now, so the stream is forwarded
|
||||
// unmodified.
|
||||
func forwardStream(c echo.Context, body io.Reader) error {
|
||||
c.Response().Header().Set("Content-Type", "text/event-stream")
|
||||
c.Response().Header().Set("Cache-Control", "no-cache")
|
||||
c.Response().Header().Set("Connection", "keep-alive")
|
||||
c.Response().WriteHeader(http.StatusOK)
|
||||
|
||||
emit := func(line string) error {
|
||||
_, err := fmt.Fprint(c.Response().Writer, line)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Response().Flush()
|
||||
return nil
|
||||
}
|
||||
|
||||
flushResidual := func() {
|
||||
if filter == nil {
|
||||
return
|
||||
}
|
||||
residual := filter.Drain()
|
||||
if residual == "" {
|
||||
return
|
||||
}
|
||||
if line := ssewire.SynthResidualEvent(ssewire.Provider(provider), residual); line != "" {
|
||||
_ = emit(line)
|
||||
}
|
||||
}
|
||||
|
||||
prov := ssewire.Provider(provider)
|
||||
scanner := ssewire.NewScanner(body)
|
||||
for scanner.Scan() {
|
||||
ev := scanner.Event()
|
||||
if ssewire.IsTerminalMarker(ev.DataLine, prov) {
|
||||
flushResidual()
|
||||
_ = emit(ev.Raw)
|
||||
continue
|
||||
}
|
||||
out := ev.Raw
|
||||
if filter != nil && ev.DataLine != "" {
|
||||
rewritten, drop := ssewire.RewritePayload(ev.DataLine, prov, filter)
|
||||
if drop {
|
||||
continue
|
||||
}
|
||||
if rewritten != ev.DataLine {
|
||||
// strings.Replace with n=1 touches only the data line,
|
||||
// preserving any "event:"/"id:" preamble.
|
||||
out = strings.Replace(ev.Raw, ev.DataLine, rewritten, 1)
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, rErr := body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, wErr := c.Response().Writer.Write(buf[:n]); wErr != nil {
|
||||
return nil
|
||||
}
|
||||
c.Response().Flush()
|
||||
}
|
||||
if err := emit(out); err != nil {
|
||||
if rErr != nil {
|
||||
if rErr != io.EOF {
|
||||
xlog.Debug("cloudproxy: stream read error", "error", rErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||
xlog.Debug("cloudproxy: stream read error", "error", err)
|
||||
}
|
||||
flushResidual()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
// Package ssewire holds the SSE-format helpers shared between
|
||||
// the request-shape cloud proxy (core/services/cloudproxy) and the
|
||||
// TLS-terminating MITM proxy (core/services/cloudproxy/mitm). Both
|
||||
// run a pii.StreamFilter over per-token text extracted from
|
||||
// provider-specific JSON chunks; this package owns the JSON shapes
|
||||
// so a future provider addition is one edit, not two.
|
||||
package ssewire
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// Provider is the upstream wire format an SSE stream conforms to.
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
OpenAI Provider = "openai"
|
||||
Anthropic Provider = "anthropic"
|
||||
)
|
||||
|
||||
// Event is one SSE event with its exact wire bytes preserved in
|
||||
// Raw (so unmodified events round-trip byte-for-byte) and the
|
||||
// extracted JSON payload from the data: line in DataLine.
|
||||
type Event struct {
|
||||
Raw string
|
||||
DataLine string
|
||||
}
|
||||
|
||||
// Scanner reads SSE events one at a time from an upstream body.
|
||||
type Scanner struct {
|
||||
r *bufio.Reader
|
||||
ev Event
|
||||
err error
|
||||
}
|
||||
|
||||
func NewScanner(r io.Reader) *Scanner {
|
||||
return &Scanner{r: bufio.NewReaderSize(r, 64*1024)}
|
||||
}
|
||||
|
||||
func (s *Scanner) Scan() bool {
|
||||
var raw strings.Builder
|
||||
var dataLine string
|
||||
for {
|
||||
line, err := s.r.ReadString('\n')
|
||||
if line != "" {
|
||||
raw.WriteString(line)
|
||||
trimmed := strings.TrimRight(line, "\r\n")
|
||||
if trimmed == "" {
|
||||
if raw.Len() == len(line) {
|
||||
raw.Reset()
|
||||
continue
|
||||
}
|
||||
s.ev = Event{Raw: raw.String(), DataLine: dataLine}
|
||||
return true
|
||||
}
|
||||
if strings.HasPrefix(trimmed, "data:") && dataLine == "" {
|
||||
payload := strings.TrimPrefix(trimmed, "data:")
|
||||
payload = strings.TrimPrefix(payload, " ")
|
||||
dataLine = payload
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
s.err = err
|
||||
if raw.Len() > 0 {
|
||||
s.ev = Event{Raw: raw.String(), DataLine: dataLine}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Scanner) Event() Event { return s.ev }
|
||||
func (s *Scanner) Err() error { return s.err }
|
||||
|
||||
// IsTerminalMarker reports whether the data line is the per-provider
|
||||
// end-of-stream sentinel. The streaming PII filter must drain its
|
||||
// residue before the caller forwards a terminal marker — clients
|
||||
// stop reading after it.
|
||||
func IsTerminalMarker(dataLine string, provider Provider) bool {
|
||||
if dataLine == "" {
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(dataLine) == "[DONE]" {
|
||||
return true
|
||||
}
|
||||
if provider == Anthropic {
|
||||
var probe struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(dataLine), &probe); err == nil {
|
||||
return probe.Type == "message_stop"
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RewritePayload runs the data line's content-bearing field through
|
||||
// the streaming filter. drop=true tells the caller to suppress the
|
||||
// SSE event entirely (the filter buffered the whole token while
|
||||
// disambiguating a pattern boundary).
|
||||
func RewritePayload(dataLine string, provider Provider, filter *pii.StreamFilter) (rewritten string, drop bool) {
|
||||
if strings.TrimSpace(dataLine) == "[DONE]" {
|
||||
return dataLine, false
|
||||
}
|
||||
switch provider {
|
||||
case Anthropic:
|
||||
return rewriteAnthropic(dataLine, filter)
|
||||
default:
|
||||
return rewriteOpenAI(dataLine, filter)
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteOpenAI(dataLine string, filter *pii.StreamFilter) (string, bool) {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal([]byte(dataLine), &m); err != nil {
|
||||
return dataLine, false
|
||||
}
|
||||
choices, ok := m["choices"].([]any)
|
||||
if !ok || len(choices) == 0 {
|
||||
return dataLine, false
|
||||
}
|
||||
first, ok := choices[0].(map[string]any)
|
||||
if !ok {
|
||||
return dataLine, false
|
||||
}
|
||||
delta, ok := first["delta"].(map[string]any)
|
||||
if !ok {
|
||||
return dataLine, false
|
||||
}
|
||||
content, ok := delta["content"].(string)
|
||||
if !ok || content == "" {
|
||||
return dataLine, false
|
||||
}
|
||||
rewritten := filter.Push(content)
|
||||
if rewritten == "" {
|
||||
return "", true
|
||||
}
|
||||
if rewritten == content {
|
||||
return dataLine, false
|
||||
}
|
||||
delta["content"] = rewritten
|
||||
out, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return dataLine, false
|
||||
}
|
||||
return string(out), false
|
||||
}
|
||||
|
||||
func rewriteAnthropic(dataLine string, filter *pii.StreamFilter) (string, bool) {
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal([]byte(dataLine), &m); err != nil {
|
||||
return dataLine, false
|
||||
}
|
||||
if t, _ := m["type"].(string); t != "content_block_delta" {
|
||||
return dataLine, false
|
||||
}
|
||||
delta, ok := m["delta"].(map[string]any)
|
||||
if !ok {
|
||||
return dataLine, false
|
||||
}
|
||||
if dt, _ := delta["type"].(string); dt != "text_delta" {
|
||||
return dataLine, false
|
||||
}
|
||||
text, ok := delta["text"].(string)
|
||||
if !ok || text == "" {
|
||||
return dataLine, false
|
||||
}
|
||||
rewritten := filter.Push(text)
|
||||
if rewritten == "" {
|
||||
return "", true
|
||||
}
|
||||
if rewritten == text {
|
||||
return dataLine, false
|
||||
}
|
||||
delta["text"] = rewritten
|
||||
out, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return dataLine, false
|
||||
}
|
||||
return string(out), false
|
||||
}
|
||||
|
||||
// SynthResidualEvent builds a provider-shaped SSE event carrying
|
||||
// the streaming filter's drained tail so the response body remains
|
||||
// a valid event stream after the proxy splices in held-back text.
|
||||
func SynthResidualEvent(provider Provider, text string) string {
|
||||
switch provider {
|
||||
case Anthropic:
|
||||
payload := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{"type": "text_delta", "text": text},
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return "event: content_block_delta\ndata: " + string(b) + "\n\n"
|
||||
default:
|
||||
payload := map[string]any{
|
||||
"object": "chat.completion.chunk",
|
||||
"choices": []map[string]any{
|
||||
{"index": 0, "delta": map[string]string{"content": text}},
|
||||
},
|
||||
}
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return "data: " + string(b) + "\n\n"
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package ssewire
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestSsewire(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "ssewire test suite")
|
||||
}
|
||||
@@ -1,114 +0,0 @@
|
||||
package ssewire
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Scanner contract: returns one Event per double-newline-terminated
|
||||
// SSE block, preserving the raw bytes (so unmodified events round-trip
|
||||
// exactly) and extracting the first data: payload as DataLine.
|
||||
|
||||
var _ = Describe("Scanner", func() {
|
||||
It("scans a basic event", func() {
|
||||
in := "event: foo\ndata: hello\n\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false on a well-formed event; err=%v", s.Err())
|
||||
ev := s.Event()
|
||||
Expect(ev.Raw).To(Equal(in))
|
||||
Expect(ev.DataLine).To(Equal("hello"))
|
||||
Expect(s.Scan()).To(BeFalse(), "Scan should return false after the only event")
|
||||
})
|
||||
|
||||
It("handles CRLF", func() {
|
||||
// Some upstreams emit CRLF instead of LF. The scanner trims
|
||||
// trailing \r off the data line so DataLine carries the same
|
||||
// bytes whichever line ending the producer chose.
|
||||
in := "event: foo\r\ndata: hello\r\n\r\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false on CRLF event; err=%v", s.Err())
|
||||
Expect(s.Event().DataLine).To(Equal("hello"))
|
||||
})
|
||||
|
||||
It("scans multiple events", func() {
|
||||
in := "data: one\n\ndata: two\n\ndata: three\n\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
got := []string{}
|
||||
for s.Scan() {
|
||||
got = append(got, s.Event().DataLine)
|
||||
}
|
||||
Expect(got).To(Equal([]string{"one", "two", "three"}))
|
||||
})
|
||||
|
||||
It("handles empty data payload", func() {
|
||||
// "data:" with no payload is valid SSE — DataLine should be empty
|
||||
// and Scan should still surface the event so callers can decide.
|
||||
in := "data:\n\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false on empty data payload; err=%v", s.Err())
|
||||
Expect(s.Event().DataLine).To(Equal(""))
|
||||
})
|
||||
|
||||
It("skips leading blank lines", func() {
|
||||
// A producer that prints a blank "keep-alive" before the first
|
||||
// real event must not produce a phantom event.
|
||||
in := "\n\n\ndata: real\n\n"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false; err=%v", s.Err())
|
||||
Expect(s.Event().DataLine).To(Equal("real"))
|
||||
})
|
||||
|
||||
It("handles mid-event EOF", func() {
|
||||
// EOF mid-event still surfaces the partial event with whatever
|
||||
// data was extracted — the StreamFilter+caller decides how to
|
||||
// handle a truncated upstream rather than silently dropping it.
|
||||
in := "data: half"
|
||||
s := NewScanner(strings.NewReader(in))
|
||||
Expect(s.Scan()).To(BeTrue(), "Scan returned false on partial event")
|
||||
ev := s.Event()
|
||||
Expect(ev.DataLine).To(Equal("half"))
|
||||
Expect(s.Scan()).To(BeFalse(), "Scan should not surface a second event after EOF")
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("IsTerminalMarker", func() {
|
||||
cases := []struct {
|
||||
name string
|
||||
dataLine string
|
||||
provider Provider
|
||||
want bool
|
||||
}{
|
||||
{"openai DONE", "[DONE]", OpenAI, true},
|
||||
{"openai DONE with whitespace", " [DONE] ", OpenAI, true},
|
||||
{"anthropic DONE also recognised", "[DONE]", Anthropic, true},
|
||||
{"anthropic message_stop", `{"type":"message_stop"}`, Anthropic, true},
|
||||
{"anthropic content_block_delta is not terminal", `{"type":"content_block_delta"}`, Anthropic, false},
|
||||
{"openai chat.completion.chunk is not terminal", `{"object":"chat.completion.chunk"}`, OpenAI, false},
|
||||
{"openai message_stop is not terminal (wrong provider)", `{"type":"message_stop"}`, OpenAI, false},
|
||||
{"empty data", "", OpenAI, false},
|
||||
{"non-json garbage", "garbage", Anthropic, false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
It(c.name, func() {
|
||||
Expect(IsTerminalMarker(c.dataLine, c.provider)).To(Equal(c.want))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
var _ = Describe("SynthResidualEvent", func() {
|
||||
It("anthropic", func() {
|
||||
got := SynthResidualEvent(Anthropic, "tail")
|
||||
Expect(strings.HasPrefix(got, "event: content_block_delta\ndata:")).To(BeTrue(), "Anthropic residual event missing event/data lines: %q", got)
|
||||
Expect(strings.HasSuffix(got, "\n\n")).To(BeTrue(), "Anthropic residual event missing trailing blank line: %q", got)
|
||||
Expect(got).To(ContainSubstring(`"text":"tail"`))
|
||||
})
|
||||
|
||||
It("openai", func() {
|
||||
got := SynthResidualEvent(OpenAI, "tail")
|
||||
Expect(strings.HasPrefix(got, "data: ")).To(BeTrue(), "OpenAI residual event missing data: prefix: %q", got)
|
||||
Expect(strings.HasSuffix(got, "\n\n")).To(BeTrue(), "OpenAI residual event missing trailing blank line: %q", got)
|
||||
Expect(got).To(ContainSubstring(`"content":"tail"`))
|
||||
})
|
||||
})
|
||||
@@ -6,12 +6,13 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"dario.cat/mergo"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/config/meta"
|
||||
"github.com/mudler/LocalAI/core/gallery"
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
"github.com/mudler/LocalAI/pkg/utils"
|
||||
@@ -114,9 +115,7 @@ func (s *ConfigService) PatchConfig(_ context.Context, name string, patch map[st
|
||||
if existingMap == nil {
|
||||
existingMap = map[string]any{}
|
||||
}
|
||||
if err := mergo.Merge(&existingMap, patch, mergo.WithOverride); err != nil {
|
||||
return nil, fmt.Errorf("merge configs: %w", err)
|
||||
}
|
||||
patchMerge(existingMap, patch, mapLeafFieldPaths(), "")
|
||||
yamlData, err := yaml.Marshal(existingMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal merged YAML: %w", err)
|
||||
@@ -142,6 +141,55 @@ func (s *ConfigService) PatchConfig(_ context.Context, name string, patch map[st
|
||||
return &updated, nil
|
||||
}
|
||||
|
||||
// mapLeafFieldPaths returns the set of dotted config paths whose schema type is
|
||||
// a map that the editor edits as one complete value (e.g.
|
||||
// pii_detection.entity_actions, roles, engine_args). A PATCH must REPLACE these
|
||||
// wholesale rather than union them: the deep-merge only adds and overrides
|
||||
// keys, so a map entry the admin deleted in the editor would otherwise silently
|
||||
// survive. Derived from the config schema so it stays correct as map fields are
|
||||
// added. (UIType comes from reflection, independent of any registry override.)
|
||||
func mapLeafFieldPaths() map[string]struct{} {
|
||||
md := meta.BuildConfigMetadata(reflect.TypeFor[config.ModelConfig]())
|
||||
out := make(map[string]struct{})
|
||||
for _, f := range md.Fields {
|
||||
if f.UIType == "map" {
|
||||
out[f.Path] = struct{}{}
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// patchMerge deep-merges src into dst with the same shape as the previous
|
||||
// mergo.WithOverride behaviour — scalars and slices replace; nested
|
||||
// struct-maps (e.g. pii_detection, parameters) recurse so unknown sibling keys
|
||||
// the editor doesn't model survive — EXCEPT that any path in mapLeaves is
|
||||
// replaced wholesale, and removed when the patch sets it empty, so deletions
|
||||
// inside a map field persist to disk.
|
||||
func patchMerge(dst, src map[string]any, mapLeaves map[string]struct{}, prefix string) {
|
||||
for k, sv := range src {
|
||||
path := k
|
||||
if prefix != "" {
|
||||
path = prefix + "." + k
|
||||
}
|
||||
if _, isLeaf := mapLeaves[path]; isLeaf {
|
||||
if m, ok := sv.(map[string]any); ok && len(m) == 0 {
|
||||
delete(dst, k) // emptied map field -> drop it from the YAML
|
||||
} else {
|
||||
dst[k] = sv
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Recurse into struct-like nesting so dst-only sibling keys survive.
|
||||
if sm, ok := sv.(map[string]any); ok {
|
||||
if dm, ok2 := dst[k].(map[string]any); ok2 {
|
||||
patchMerge(dm, sm, mapLeaves, path)
|
||||
continue
|
||||
}
|
||||
}
|
||||
dst[k] = sv
|
||||
}
|
||||
}
|
||||
|
||||
// EditYAML replaces the YAML for an installed model, with optional rename
|
||||
// support. ml may be nil; when set, EditYAML calls ml.ShutdownModel(oldName)
|
||||
// after a successful write so the next inference picks up the new config.
|
||||
|
||||
@@ -107,6 +107,64 @@ var _ = Describe("ConfigService", func() {
|
||||
_, err := svc.PatchConfig(ctx, "qwen", map[string]any{})
|
||||
Expect(err).To(MatchError(ErrEmptyBody))
|
||||
})
|
||||
|
||||
It("replaces a map field wholesale so deleted entries do not survive", func() {
|
||||
// A detector model with a populated entity_actions map. The editor
|
||||
// removes SSN and re-sends the remaining map; a naive deep-merge
|
||||
// would re-add SSN (it only adds/overrides keys, never deletes).
|
||||
writeModelYAML(svc, dir, "ner", map[string]any{
|
||||
"backend": "llama-cpp",
|
||||
"known_usecases": []any{"token_classify"},
|
||||
"pii_detection": map[string]any{
|
||||
"default_action": "mask",
|
||||
"entity_actions": map[string]any{"SSN": "block", "EMAIL": "mask"},
|
||||
},
|
||||
})
|
||||
|
||||
_, err := svc.PatchConfig(ctx, "ner", map[string]any{
|
||||
"pii_detection": map[string]any{
|
||||
"default_action": "mask",
|
||||
"entity_actions": map[string]any{"EMAIL": "mask"},
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
raw, err := os.ReadFile(filepath.Join(dir, "ner.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var got map[string]any
|
||||
Expect(yaml.Unmarshal(raw, &got)).To(Succeed())
|
||||
pii := got["pii_detection"].(map[string]any)
|
||||
ea := pii["entity_actions"].(map[string]any)
|
||||
Expect(ea).To(HaveKeyWithValue("EMAIL", "mask"))
|
||||
Expect(ea).NotTo(HaveKey("SSN"), "deleted map entry must not survive the patch")
|
||||
// The scalar sibling in the same nested block is still preserved.
|
||||
Expect(pii).To(HaveKeyWithValue("default_action", "mask"))
|
||||
})
|
||||
|
||||
It("drops a map field entirely when the patch empties it", func() {
|
||||
writeModelYAML(svc, dir, "ner", map[string]any{
|
||||
"backend": "llama-cpp",
|
||||
"known_usecases": []any{"token_classify"},
|
||||
"pii_detection": map[string]any{
|
||||
"default_action": "mask",
|
||||
"entity_actions": map[string]any{"SSN": "block"},
|
||||
},
|
||||
})
|
||||
|
||||
_, err := svc.PatchConfig(ctx, "ner", map[string]any{
|
||||
"pii_detection": map[string]any{
|
||||
"entity_actions": map[string]any{},
|
||||
},
|
||||
})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
raw, err := os.ReadFile(filepath.Join(dir, "ner.yaml"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
var got map[string]any
|
||||
Expect(yaml.Unmarshal(raw, &got)).To(Succeed())
|
||||
pii := got["pii_detection"].(map[string]any)
|
||||
Expect(pii).NotTo(HaveKey("entity_actions"))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("EditYAML", func() {
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// FileConfig is the on-disk schema for pii.yaml. Each Pattern entry
|
||||
// overrides the matching default by ID; missing fields fall back to
|
||||
// the default. Unknown IDs are rejected at load time so an admin who
|
||||
// fat-fingers a pattern name gets a clear error rather than a silent
|
||||
// no-op.
|
||||
type FileConfig struct {
|
||||
Patterns []FilePattern `yaml:"patterns"`
|
||||
}
|
||||
|
||||
type FilePattern struct {
|
||||
ID string `yaml:"id"`
|
||||
Action Action `yaml:"action"`
|
||||
}
|
||||
|
||||
// LoadConfig reads pii.yaml from path and merges it on top of
|
||||
// DefaultPatterns(). path == "" returns the defaults compiled and
|
||||
// ready. The returned slice is already Compile()'d, so callers can
|
||||
// pass it straight to NewRedactor.
|
||||
func LoadConfig(path string) ([]Pattern, error) {
|
||||
defaults := DefaultPatterns()
|
||||
if path == "" {
|
||||
return Compile(defaults)
|
||||
}
|
||||
|
||||
raw, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pii: read config %q: %w", path, err)
|
||||
}
|
||||
var cfg FileConfig
|
||||
if err := yaml.Unmarshal(raw, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("pii: parse config %q: %w", path, err)
|
||||
}
|
||||
|
||||
overrides := make(map[string]Action, len(cfg.Patterns))
|
||||
known := make(map[string]bool, len(defaults))
|
||||
for _, d := range defaults {
|
||||
known[d.ID] = true
|
||||
}
|
||||
for _, p := range cfg.Patterns {
|
||||
if !known[p.ID] {
|
||||
return nil, fmt.Errorf("pii: unknown pattern id %q in %q", p.ID, path)
|
||||
}
|
||||
if p.Action == "" {
|
||||
continue
|
||||
}
|
||||
switch p.Action {
|
||||
case ActionMask, ActionBlock, ActionAllow:
|
||||
overrides[p.ID] = p.Action
|
||||
default:
|
||||
return nil, fmt.Errorf("pii: invalid action %q for pattern %q", p.Action, p.ID)
|
||||
}
|
||||
}
|
||||
|
||||
merged := make([]Pattern, len(defaults))
|
||||
for i, d := range defaults {
|
||||
if a, ok := overrides[d.ID]; ok {
|
||||
d.Action = a
|
||||
}
|
||||
merged[i] = d
|
||||
}
|
||||
return Compile(merged)
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("LoadConfig", func() {
|
||||
It("returns defaults when no path given", func() {
|
||||
patterns, err := LoadConfig("")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(patterns).To(HaveLen(len(DefaultPatterns())))
|
||||
})
|
||||
|
||||
It("overrides action", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
path := filepath.Join(dir, "pii.yaml")
|
||||
body := []byte(`patterns:
|
||||
- id: email
|
||||
action: block
|
||||
- id: ssn
|
||||
action: allow
|
||||
`)
|
||||
Expect(os.WriteFile(path, body, 0o600)).To(Succeed())
|
||||
patterns, err := LoadConfig(path)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
got := map[string]Action{}
|
||||
for _, p := range patterns {
|
||||
got[p.ID] = p.Action
|
||||
}
|
||||
Expect(got["email"]).To(Equal(ActionBlock))
|
||||
Expect(got["ssn"]).To(Equal(ActionAllow))
|
||||
// Unmentioned patterns keep their default action.
|
||||
Expect(got["credit_card"]).To(Equal(ActionMask), "credit_card default action lost")
|
||||
})
|
||||
|
||||
It("rejects unknown id", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
path := filepath.Join(dir, "pii.yaml")
|
||||
Expect(os.WriteFile(path, []byte("patterns:\n - id: nonsense\n action: mask\n"), 0o600)).To(Succeed())
|
||||
_, err := LoadConfig(path)
|
||||
Expect(err).To(HaveOccurred(), "expected error on unknown pattern id")
|
||||
})
|
||||
|
||||
It("rejects invalid action", func() {
|
||||
dir := GinkgoT().TempDir()
|
||||
path := filepath.Join(dir, "pii.yaml")
|
||||
Expect(os.WriteFile(path, []byte("patterns:\n - id: email\n action: lolwhat\n"), 0o600)).To(Succeed())
|
||||
_, err := LoadConfig(path)
|
||||
Expect(err).To(HaveOccurred(), "expected error on invalid action")
|
||||
})
|
||||
})
|
||||
@@ -19,28 +19,71 @@ import (
|
||||
// drag the http/middleware package into pii's import graph and create
|
||||
// a cycle (http/middleware will import this one).
|
||||
const (
|
||||
ctxKeyCorrelationID = "routing.correlation_id"
|
||||
ctxKeyPIIEventID = "routing.pii_event_id"
|
||||
ctxKeyCorrelationID = "routing.correlation_id"
|
||||
ctxKeyPIIEventID = "routing.pii_event_id"
|
||||
// Must match the constants in core/http/middleware/request.go.
|
||||
// Echoing them across packages would create an import cycle
|
||||
// (http/middleware imports this package). Drift is caught by
|
||||
// integration tests against the chat route.
|
||||
ctxKeyParsedRequest = "LOCALAI_REQUEST"
|
||||
ctxKeyModelConfig = "MODEL_CONFIG"
|
||||
ctxKeyParsedRequest = "LOCALAI_REQUEST"
|
||||
ctxKeyModelConfig = "MODEL_CONFIG"
|
||||
)
|
||||
|
||||
// ModelPIIConfig is the duck-typed view this middleware needs of the
|
||||
// per-model PII configuration carried on the echo context. *config.ModelConfig
|
||||
// satisfies it via PIIIsEnabled / PIIPatternOverrides; the indirection
|
||||
// keeps the pii package from importing core/config.
|
||||
// per-model PII configuration carried on the echo context.
|
||||
// *config.ModelConfig satisfies it via PIIIsEnabled / PIIDetectors; the
|
||||
// indirection keeps the pii package from importing core/config.
|
||||
//
|
||||
// Consumers of the override map: the action returned from PIIPatternOverrides
|
||||
// is the raw YAML string (e.g. "block"). Validation against the canonical
|
||||
// ActionMask/Block/Allow constants happens here, so a typo in a model
|
||||
// YAML logs and is ignored rather than panicking.
|
||||
// PIIDetectors lists the token-classification models whose detections
|
||||
// drive redaction for this (consuming) model. The detection policy lives
|
||||
// on each named detector model — resolved via NERDetectorResolver — so
|
||||
// this consuming view carries no per-entity actions of its own.
|
||||
type ModelPIIConfig interface {
|
||||
PIIIsEnabled() bool
|
||||
PIIPatternOverrides() map[string]string
|
||||
PIIDetectors() []string
|
||||
}
|
||||
|
||||
// NERDetectorResolver resolves a detector model name to a ready-to-use
|
||||
// NERConfig — the detector plus the policy (min score, entity→action
|
||||
// map, default action) read from that model's own pii_detection block.
|
||||
// ok is false when the name can't supply a detector (unknown model, not
|
||||
// a token_classify model, or load failure); the middleware fails closed
|
||||
// in that case. Supplied by the application layer, which owns the model
|
||||
// loader and the core/backend dependency, keeping the pii package free of
|
||||
// both. A nil resolver (or the option being unset) disables the NER tier.
|
||||
type NERDetectorResolver func(modelName string) (NERConfig, bool)
|
||||
|
||||
// Option configures optional RequestMiddleware behaviour. Threaded as
|
||||
// variadic options so adding the NER tier doesn't break the existing
|
||||
// four-argument call sites (routes and tests).
|
||||
type Option func(*mwOptions)
|
||||
|
||||
type mwOptions struct {
|
||||
nerResolver NERDetectorResolver
|
||||
policyResolver PolicyResolver
|
||||
}
|
||||
|
||||
// PolicyResolver returns the effective (enabled, detectors) for the model
|
||||
// carried on the request context, layering instance-wide PII defaults over the
|
||||
// per-model config. Supplied by the application layer (which owns core/config),
|
||||
// keeping this package decoupled from it — the middleware passes the raw
|
||||
// context value through as `any`. When unset, the middleware falls back to the
|
||||
// duck-typed ModelPIIConfig (explicit per-model config only, no global default).
|
||||
type PolicyResolver func(modelCfg any) (enabled bool, detectors []string)
|
||||
|
||||
// WithPolicyResolver overrides how the middleware decides enablement and the
|
||||
// detector list, so the instance-wide default detector / default-on usecases
|
||||
// apply. Without it the middleware reads ModelPIIConfig off the context.
|
||||
func WithPolicyResolver(r PolicyResolver) Option {
|
||||
return func(o *mwOptions) { o.policyResolver = r }
|
||||
}
|
||||
|
||||
// WithNERResolver enables the NER tier. When a request's model lists
|
||||
// pii.detectors, the middleware resolves each to a NERConfig and runs
|
||||
// RedactNER (the union of all detectors' hits, merged). Without this
|
||||
// option, or when a model lists no detectors, redaction is a no-op.
|
||||
func WithNERResolver(r NERDetectorResolver) Option {
|
||||
return func(o *mwOptions) { o.nerResolver = r }
|
||||
}
|
||||
|
||||
// ScannedText is one piece of user text from the request. Index is
|
||||
@@ -84,30 +127,32 @@ type Adapter struct {
|
||||
// no-auth identity. The middleware writes ctxKeyPIIEventID on the echo
|
||||
// context so the usage middleware can later cross-reference the event
|
||||
// with the UsageRecord.
|
||||
func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fallbackUser *auth.User) echo.MiddlewareFunc {
|
||||
func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fallbackUser *auth.User, opts ...Option) echo.MiddlewareFunc {
|
||||
var o mwOptions
|
||||
for _, opt := range opts {
|
||||
opt(&o)
|
||||
}
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
if redactor == nil || len(redactor.Patterns()) == 0 || adapter.Scan == nil {
|
||||
if redactor == nil || adapter.Scan == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// Per-model gating: redaction is opt-in per model. If the
|
||||
// resolved config disables PII for this model (the default
|
||||
// for non-proxy backends), pass through immediately. We do
|
||||
// this before parsing the request so a disabled model
|
||||
// doesn't pay the regex scan cost.
|
||||
if cfg, ok := c.Get(ctxKeyModelConfig).(ModelPIIConfig); ok {
|
||||
if !cfg.PIIIsEnabled() {
|
||||
return next(c)
|
||||
}
|
||||
} else {
|
||||
// No ModelPIIConfig on context → fail-closed: skip
|
||||
// redaction. This protects routes that wire the
|
||||
// middleware before SetModelAndConfig runs (or non-chat
|
||||
// routes that don't carry a model). The middleware was
|
||||
// previously fail-open, applying the global redactor
|
||||
// unconditionally; the new contract is per-model
|
||||
// opt-in, and a missing model is treated as disabled.
|
||||
// Per-model gating: redaction is opt-in per model. The policy
|
||||
// resolver (when wired) layers instance-wide defaults over the
|
||||
// per-model config; otherwise we read the per-model config
|
||||
// directly. A missing config (non-chat routes, or middleware
|
||||
// wired before SetModelAndConfig) or a not-enabled result passes
|
||||
// through.
|
||||
rawCfg := c.Get(ctxKeyModelConfig)
|
||||
var enabled bool
|
||||
var detectors []string
|
||||
if o.policyResolver != nil {
|
||||
enabled, detectors = o.policyResolver(rawCfg)
|
||||
} else if cfg, ok := rawCfg.(ModelPIIConfig); ok {
|
||||
enabled, detectors = cfg.PIIIsEnabled(), cfg.PIIDetectors()
|
||||
}
|
||||
if !enabled {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
@@ -116,6 +161,12 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
return next(c)
|
||||
}
|
||||
|
||||
// A PII-enabled model with no detectors (or no resolver wired)
|
||||
// has nothing to scan with — pass through.
|
||||
if len(detectors) == 0 || o.nerResolver == nil {
|
||||
return next(c)
|
||||
}
|
||||
|
||||
user := auth.GetUser(c)
|
||||
if user == nil {
|
||||
user = fallbackUser
|
||||
@@ -126,24 +177,19 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
}
|
||||
correlationID, _ := c.Get(ctxKeyCorrelationID).(string)
|
||||
|
||||
// Resolve per-model action overrides once per request. The
|
||||
// raw map is YAML strings; convert to the typed Action set
|
||||
// and silently drop unknown values rather than failing the
|
||||
// request — model YAML typos shouldn't take chat down.
|
||||
var overrides map[string]Action
|
||||
if cfg, ok := c.Get(ctxKeyModelConfig).(ModelPIIConfig); ok {
|
||||
if raw := cfg.PIIPatternOverrides(); len(raw) > 0 {
|
||||
overrides = make(map[string]Action, len(raw))
|
||||
for id, action := range raw {
|
||||
switch Action(action) {
|
||||
case ActionMask, ActionBlock, ActionAllow:
|
||||
overrides[id] = Action(action)
|
||||
default:
|
||||
xlog.Warn("pii: ignoring unknown action in per-model override",
|
||||
"pattern", id, "action", action)
|
||||
}
|
||||
}
|
||||
// Resolve each named detector to its NERConfig (detector +
|
||||
// the policy from that model's own pii_detection block). A
|
||||
// configured detector that can't be resolved fails closed:
|
||||
// serving the request without the semantic check the operator
|
||||
// asked for is exactly the leak this tier exists to prevent.
|
||||
cfgs := make([]NERConfig, 0, len(detectors))
|
||||
for _, name := range detectors {
|
||||
nc, ok := o.nerResolver(name)
|
||||
if !ok {
|
||||
xlog.Error("pii: configured detector model could not be resolved; blocking request (fail-closed)", "detector", name)
|
||||
return blockNERUnavailable(c, store, correlationID, userID)
|
||||
}
|
||||
cfgs = append(cfgs, nc)
|
||||
}
|
||||
|
||||
texts := adapter.Scan(parsed)
|
||||
@@ -151,24 +197,38 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
var blocked bool
|
||||
var firstEventID string
|
||||
|
||||
for _, st := range texts {
|
||||
if st.Text == "" {
|
||||
continue
|
||||
}
|
||||
res := redactor.RedactWithOverrides(st.Text, overrides)
|
||||
// Scan the request as ONE document (messages joined) so the NER
|
||||
// tier keeps conversational context — whether "4421" is a PIN is
|
||||
// decided by the question in the previous message. The spans come
|
||||
// back per message with local offsets for in-place rewriting.
|
||||
segTexts := make([]string, len(texts))
|
||||
for i, st := range texts {
|
||||
segTexts[i] = st.Text
|
||||
}
|
||||
// Fail closed: a detector outage at request time must NOT
|
||||
// silently serve the request. The NER tier was explicitly
|
||||
// configured for this model, so the semantic check is part
|
||||
// of the contract.
|
||||
segResults, nerErr := RedactNERSegments(c.Request().Context(), segTexts, cfgs)
|
||||
if nerErr != nil {
|
||||
xlog.Error("pii: NER detector failed; blocking request (fail-closed)", "error", nerErr)
|
||||
return blockNERUnavailable(c, store, correlationID, userID)
|
||||
}
|
||||
|
||||
for i, res := range segResults {
|
||||
st := texts[i]
|
||||
if len(res.Spans) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Persist one event per span so admins can see exactly
|
||||
// which patterns fired in which positions. The action
|
||||
// recorded is the resolved one (after override), so the
|
||||
// events log reflects what actually happened to the
|
||||
// request, not the global default.
|
||||
// Persist one event per detected span. The action recorded
|
||||
// is the one that actually fired (carried on the span after
|
||||
// the overlap merge), so the events log reflects what
|
||||
// happened to the request.
|
||||
for _, span := range res.Spans {
|
||||
action := actionForSpan(redactor.Patterns(), span.Pattern, overrides)
|
||||
ev := PIIEvent{
|
||||
ID: newEventID(),
|
||||
Origin: OriginMiddleware,
|
||||
CorrelationID: correlationID,
|
||||
UserID: userID,
|
||||
Direction: DirectionIn,
|
||||
@@ -176,7 +236,8 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
ByteOffset: span.Start,
|
||||
Length: span.End - span.Start,
|
||||
HashPrefix: span.HashPrefix,
|
||||
Action: action,
|
||||
Action: span.Action,
|
||||
Score: span.Score,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
if firstEventID == "" {
|
||||
@@ -223,24 +284,85 @@ func RequestMiddleware(redactor *Redactor, store EventStore, adapter Adapter, fa
|
||||
}
|
||||
}
|
||||
|
||||
func actionForPattern(patterns []Pattern, id string) Action {
|
||||
for _, p := range patterns {
|
||||
if p.ID == id {
|
||||
return p.Action
|
||||
// nerUnavailablePattern is the sentinel PatternID recorded on the
|
||||
// fail-closed audit event when a model's configured NER tier cannot
|
||||
// run. It is not a real regex pattern — it marks a request blocked
|
||||
// because the encoder/NER check was unavailable (model unresolved or
|
||||
// backend error), so the events log distinguishes it from a content
|
||||
// block (which carries a real pattern ID).
|
||||
const nerUnavailablePattern = "__ner_unavailable__"
|
||||
|
||||
// blockNERUnavailable records a fail-closed audit event and returns the
|
||||
// response used when a model has an NER tier configured but it could
|
||||
// not run. Failing closed is deliberate for a PII filter: if the
|
||||
// semantic check the operator asked for cannot execute, refusing the
|
||||
// request is safer than serving it with only the cheap regex tier. The
|
||||
// 503 (vs the 400 used for a content block) tells clients and operators
|
||||
// this was a dependency outage, not sensitive data in the request.
|
||||
func blockNERUnavailable(c echo.Context, store EventStore, correlationID, userID string) error {
|
||||
ev := PIIEvent{
|
||||
ID: newEventID(),
|
||||
Kind: KindPII,
|
||||
Origin: OriginMiddleware,
|
||||
CorrelationID: correlationID,
|
||||
UserID: userID,
|
||||
Direction: DirectionIn,
|
||||
PatternID: nerUnavailablePattern,
|
||||
Action: ActionBlock,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
if store != nil {
|
||||
if err := store.Record(context.Background(), ev); err != nil {
|
||||
xlog.Error("pii: failed to record NER-unavailable event", "error", err)
|
||||
}
|
||||
}
|
||||
return ActionMask
|
||||
c.Set(ctxKeyPIIEventID, ev.ID)
|
||||
return c.JSON(http.StatusServiceUnavailable, map[string]any{
|
||||
"error": map[string]string{
|
||||
"message": "request blocked: PII NER check is configured but unavailable",
|
||||
"type": "pii_ner_unavailable",
|
||||
},
|
||||
"correlation_id": correlationID,
|
||||
"pii_event_id": ev.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// actionForSpan returns the resolved action for a span, preferring a
|
||||
// per-request override over the pattern's stored action. Used so the
|
||||
// PIIEvent log reflects the action that actually fired (e.g., a model
|
||||
// upgraded email from mask to block — the event row says "block").
|
||||
func actionForSpan(patterns []Pattern, id string, overrides map[string]Action) Action {
|
||||
if action, ok := overrides[id]; ok {
|
||||
return action
|
||||
// validAction converts a raw YAML action string to the typed Action,
|
||||
// returning "" for anything that isn't a known action.
|
||||
func validAction(raw string) Action {
|
||||
switch Action(raw) {
|
||||
case ActionMask, ActionBlock, ActionAllow:
|
||||
return Action(raw)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
return actionForPattern(patterns, id)
|
||||
}
|
||||
|
||||
// validActionOr is validAction with a fallback for empty/invalid input.
|
||||
func validActionOr(raw string, fallback Action) Action {
|
||||
if a := validAction(raw); a != "" {
|
||||
return a
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
// validActions converts a raw entity-group->action map to typed
|
||||
// Actions, dropping (and logging) unknown actions so a model YAML typo
|
||||
// is ignored rather than taking the request down — mirroring how the
|
||||
// per-pattern overrides are validated above.
|
||||
func validActions(raw map[string]string) map[string]Action {
|
||||
if len(raw) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]Action, len(raw))
|
||||
for group, action := range raw {
|
||||
if a := validAction(action); a != "" {
|
||||
out[group] = a
|
||||
} else {
|
||||
xlog.Warn("pii: ignoring unknown NER entity action", "group", group, "action", action)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func newEventID() string {
|
||||
@@ -248,3 +370,8 @@ func newEventID() string {
|
||||
_, _ = rand.Read(b[:])
|
||||
return "pii_" + hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
// NewEventID mints a fresh random event id in the package's standard shape.
|
||||
// Exported so callers outside this package (the analyze/redact API handlers)
|
||||
// record events with ids indistinguishable from the in-band middleware's.
|
||||
func NewEventID() string { return newEventID() }
|
||||
|
||||
@@ -3,12 +3,12 @@ package pii
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/mudler/LocalAI/core/http/auth"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -56,21 +56,15 @@ func setRequestOnContext(req *fakeRequest) echo.MiddlewareFunc {
|
||||
}
|
||||
|
||||
// fakeModelPIIConfig satisfies the duck-typed ModelPIIConfig interface
|
||||
// the middleware expects on the echo context. The real implementation
|
||||
// lives on *config.ModelConfig; using a fake here keeps these tests
|
||||
// out of the core/config import graph.
|
||||
// the middleware expects on the echo context (PIIIsEnabled + PIIDetectors).
|
||||
type fakeModelPIIConfig struct {
|
||||
enabled bool
|
||||
overrides map[string]string
|
||||
detectors []string
|
||||
}
|
||||
|
||||
func (f fakeModelPIIConfig) PIIIsEnabled() bool { return f.enabled }
|
||||
func (f fakeModelPIIConfig) PIIPatternOverrides() map[string]string { return f.overrides }
|
||||
func (f fakeModelPIIConfig) PIIIsEnabled() bool { return f.enabled }
|
||||
func (f fakeModelPIIConfig) PIIDetectors() []string { return f.detectors }
|
||||
|
||||
// withModelConfig wires a ModelPIIConfig onto the context so the
|
||||
// middleware's per-model gate doesn't fail-closed during tests. Pass
|
||||
// enabled=true for the default test path; explicit-false tests should
|
||||
// use the gating spec further down instead.
|
||||
func withModelConfig(cfg fakeModelPIIConfig) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) error {
|
||||
@@ -80,230 +74,257 @@ func withModelConfig(cfg fakeModelPIIConfig) echo.MiddlewareFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRedactor(ids ...string) *Redactor {
|
||||
patterns, err := Compile(pick(DefaultPatterns(), ids))
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return NewRedactor(patterns)
|
||||
// resolverFor returns a NERDetectorResolver that maps each named model to
|
||||
// the supplied NERConfig. Names absent from the map resolve to (zero,
|
||||
// false) so the middleware fails closed — mirroring an unresolvable model.
|
||||
func resolverFor(byName map[string]NERConfig) NERDetectorResolver {
|
||||
return func(name string) (NERConfig, bool) {
|
||||
cfg, ok := byName[name]
|
||||
return cfg, ok
|
||||
}
|
||||
}
|
||||
|
||||
var _ = Describe("RequestMiddleware", func() {
|
||||
It("masks email", func() {
|
||||
red := newTestRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
user := &auth.User{ID: "user-1", Name: "alice"}
|
||||
func serve(body *fakeRequest, cfg fakeModelPIIConfig, mw echo.MiddlewareFunc, withConfig bool) (*httptest.ResponseRecorder, *bool) {
|
||||
called := new(bool)
|
||||
e := echo.New()
|
||||
chain := []echo.MiddlewareFunc{setRequestOnContext(body)}
|
||||
if withConfig {
|
||||
chain = append(chain, withModelConfig(cfg))
|
||||
}
|
||||
chain = append(chain, mw)
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
*called = true
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, chain...)
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
return w, called
|
||||
}
|
||||
|
||||
body := &fakeRequest{Messages: []string{"contact me at alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
func nerCfg(action Action, entities ...NEREntity) NERConfig {
|
||||
return NERConfig{
|
||||
Detector: &stubNERDetector{entities: entities},
|
||||
DefaultAction: action,
|
||||
}
|
||||
}
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw, func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
// Inject the user as if upstream auth ran.
|
||||
return func(c echo.Context) error {
|
||||
c.Set("auth_user", user)
|
||||
return next(c)
|
||||
}
|
||||
})
|
||||
var _ = Describe("RequestMiddleware (NER)", func() {
|
||||
store := func() EventStore { return NewMemoryEventStore(0) }
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
It("masks a detected entity end-to-end", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice today"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"privacy-filter": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"privacy-filter"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String())
|
||||
Expect(body.Messages[0]).NotTo(ContainSubstring("alice@example.com"), "request body should be redacted in place")
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:email]"))
|
||||
|
||||
events, err := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(err).NotTo(HaveOccurred(), "list events")
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:PER]"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].PatternID).To(Equal("email"))
|
||||
Expect(events[0].PatternID).To(Equal("ner:PER"))
|
||||
Expect(events[0].Direction).To(Equal(DirectionIn))
|
||||
})
|
||||
|
||||
It("blocks api key", func() {
|
||||
red := newTestRedactor("api_key_prefix")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
body := &fakeRequest{Messages: []string{"my key is sk-abcdefghijklmnopqrstuvwxyz0123456789"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
handlerCalled := false
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
handlerCalled = true
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusBadRequest), "expected 400 on block; body=%s", w.Body.String())
|
||||
Expect(handlerCalled).To(BeFalse(), "handler must not run when request is blocked")
|
||||
// Ensure the matched value never appears in the response body.
|
||||
Expect(w.Body.String()).NotTo(ContainSubstring("abcdefghijklmnopqrstuvwxyz0123456789"), "blocked response leaks the matched value")
|
||||
It("blocks (400) when a detected entity's action is block", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"my password is hunter2 ok"}}
|
||||
cfg := NERConfig{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: "PASSWORD", Start: 15, End: 22, Score: 0.99}}},
|
||||
EntityActions: map[string]Action{"PASSWORD": ActionBlock},
|
||||
}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
|
||||
w, called := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusBadRequest), "body=%s", w.Body.String())
|
||||
Expect(*called).To(BeFalse(), "handler must not run when blocked")
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(w.Body.Bytes(), &resp)).To(Succeed())
|
||||
errBlock, ok := resp["error"].(map[string]any)
|
||||
Expect(ok).To(BeTrue())
|
||||
errBlock, _ := resp["error"].(map[string]any)
|
||||
Expect(errBlock["type"]).To(Equal("pii_blocked"))
|
||||
})
|
||||
|
||||
It("allow leaves text intact but still records an event", func() {
|
||||
patterns, _ := Compile([]Pattern{{
|
||||
ID: "email", Description: "Email", Action: ActionAllow, MaxMatchLength: 254,
|
||||
}})
|
||||
red := NewRedactor(patterns)
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
It("allow leaves text intact but records an event", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"hi at alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
cfg := NERConfig{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: "EMAIL", Start: 6, End: 23, Score: 0.9}}},
|
||||
EntityActions: map[string]Action{"EMAIL": ActionAllow},
|
||||
}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
// allow does NOT mutate the body — the model still sees the email.
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "allow should leave text intact")
|
||||
// ...but the detection is still recorded for audit.
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1), "allow should still record a PIIEvent")
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].Action).To(Equal(ActionAllow))
|
||||
})
|
||||
|
||||
It("no match passes through", func() {
|
||||
red := newTestRedactor()
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
It("passes through on no match", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"perfectly innocent text"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": nerCfg(ActionMask)})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(Equal("perfectly innocent text"), "body should be untouched")
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(BeEmpty(), "expected 0 events on no-match input")
|
||||
Expect(body.Messages[0]).To(Equal("perfectly innocent text"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("skips when model config disabled", func() {
|
||||
// Per-model gating is the new contract: a model with PIIIsEnabled
|
||||
// returning false must bypass redaction entirely, even if the
|
||||
// global redactor has matching patterns.
|
||||
red := newTestRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: false}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
It("skips when the model has PII disabled", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: false, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "disabled model must not redact")
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(BeEmpty(), "disabled model must produce no events")
|
||||
Expect(body.Messages[0]).To(Equal("Hi I'm Alice"), "disabled model must not redact")
|
||||
})
|
||||
|
||||
It("fails closed without model config", func() {
|
||||
// Routes that wire the middleware before SetModelAndConfig, or
|
||||
// non-chat routes lacking a model, hit this path. The contract
|
||||
// is fail-closed: pass through without redaction so a missing
|
||||
// model can't accidentally leak through global defaults.
|
||||
red := newTestRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
|
||||
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
// Note: no withModelConfig in the chain.
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
It("passes through when the model lists no detectors", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "missing ModelPIIConfig should fail-closed (no redaction)")
|
||||
Expect(body.Messages[0]).To(Equal("Hi I'm Alice"))
|
||||
})
|
||||
|
||||
It("applies per-model override", func() {
|
||||
// email defaults to mask. A per-model override upgrades it to
|
||||
// block. The middleware short-circuits with 400, the request
|
||||
// body is never touched, and the events log records action=block.
|
||||
red := newTestRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
It("fails closed without a model config", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{}, mw, false) // no model config on context
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(Equal("Hi I'm Alice"), "missing ModelPIIConfig should pass through")
|
||||
})
|
||||
|
||||
It("unions multiple detectors", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Alice at acme"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"names": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 0, End: 5, Score: 0.9}),
|
||||
"orgs": nerCfg(ActionMask, NEREntity{Group: "ORG", Start: 9, End: 13, Score: 0.9}),
|
||||
})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"names", "orgs"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:PER]"))
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:ORG]"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(2))
|
||||
})
|
||||
|
||||
It("fails closed (503) when a detector errors", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
|
||||
mw := RequestMiddleware(red, store, fakeAdapter(), nil)
|
||||
cfg := NERConfig{Detector: &stubNERDetector{err: errors.New("backend offline")}, DefaultAction: ActionMask}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
|
||||
w, called := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
e := echo.New()
|
||||
handlerCalled := false
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
handlerCalled = true
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body),
|
||||
withModelConfig(fakeModelPIIConfig{
|
||||
enabled: true,
|
||||
overrides: map[string]string{"email": "block"},
|
||||
}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusBadRequest), "expected 400 from override-block; body=%s", w.Body.String())
|
||||
Expect(handlerCalled).To(BeFalse(), "handler must not run when override blocks")
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(w.Code).To(Equal(http.StatusServiceUnavailable), "body=%s", w.Body.String())
|
||||
Expect(*called).To(BeFalse())
|
||||
Expect(body.Messages[0]).To(ContainSubstring("alice@example.com"), "request body must be untouched on a fail-closed block")
|
||||
var resp map[string]any
|
||||
Expect(json.Unmarshal(w.Body.Bytes(), &resp)).To(Succeed())
|
||||
errBlock, _ := resp["error"].(map[string]any)
|
||||
Expect(errBlock["type"]).To(Equal("pii_ner_unavailable"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].Action).To(Equal(ActionBlock), "event must record the resolved (override) action")
|
||||
Expect(events[0].PatternID).To(Equal(nerUnavailablePattern))
|
||||
})
|
||||
|
||||
It("fails closed (503) when a configured detector can't be resolved", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"contact alice@example.com"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{}))) // "missing" not present
|
||||
w, called := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"missing"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusServiceUnavailable))
|
||||
Expect(*called).To(BeFalse())
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].PatternID).To(Equal(nerUnavailablePattern))
|
||||
})
|
||||
|
||||
It("nil redactor is passthrough", func() {
|
||||
body := &fakeRequest{Messages: []string{"alice@example.com"}}
|
||||
mw := RequestMiddleware(nil, nil, fakeAdapter(), nil)
|
||||
|
||||
e := echo.New()
|
||||
e.POST("/chat", func(c echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]string{"ok": "yes"})
|
||||
}, setRequestOnContext(body), withModelConfig(fakeModelPIIConfig{enabled: true}), mw)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat", strings.NewReader(`{}`))
|
||||
w := httptest.NewRecorder()
|
||||
e.ServeHTTP(w, req)
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(Equal("alice@example.com"), "nil redactor must be a no-op")
|
||||
})
|
||||
|
||||
It("WithPolicyResolver enables a model the per-model config left off (global default)", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice today"}}
|
||||
// The per-model config is disabled with no detectors; the policy
|
||||
// resolver (instance-wide default) turns it on and supplies one.
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"global-pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})),
|
||||
WithPolicyResolver(func(_ any) (bool, []string) { return true, []string{"global-pf"} }))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: false}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String())
|
||||
Expect(body.Messages[0]).To(ContainSubstring("[REDACTED:ner:PER]"))
|
||||
})
|
||||
|
||||
It("WithPolicyResolver returning disabled short-circuits an otherwise-enabled model", func() {
|
||||
st := store()
|
||||
body := &fakeRequest{Messages: []string{"Hi I'm Alice today"}}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{
|
||||
"pf": nerCfg(ActionMask, NEREntity{Group: "PER", Start: 6, End: 11, Score: 0.95}),
|
||||
})),
|
||||
WithPolicyResolver(func(_ any) (bool, []string) { return false, nil }))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK))
|
||||
Expect(body.Messages[0]).To(Equal("Hi I'm Alice today"), "resolver disabled => no redaction")
|
||||
})
|
||||
|
||||
It("scans all messages as one document so earlier-message context applies", func() {
|
||||
st := store()
|
||||
// The detector (pinAfterCard) only recognises "4421" when "card"
|
||||
// appears earlier in the SAME text it is handed — so this only
|
||||
// masks if the middleware joins the messages before scanning.
|
||||
body := &fakeRequest{Messages: []string{
|
||||
"What are the last four digits of your card?",
|
||||
"it is 4421 ok",
|
||||
}}
|
||||
cfg := NERConfig{Detector: &funcNERDetector{fn: pinAfterCard}, DefaultAction: ActionMask}
|
||||
mw := RequestMiddleware(&Redactor{}, st, fakeAdapter(), nil,
|
||||
WithNERResolver(resolverFor(map[string]NERConfig{"pf": cfg})))
|
||||
w, _ := serve(body, fakeModelPIIConfig{enabled: true, detectors: []string{"pf"}}, mw, true)
|
||||
|
||||
Expect(w.Code).To(Equal(http.StatusOK), "body=%s", w.Body.String())
|
||||
Expect(body.Messages[0]).To(Equal("What are the last four digits of your card?"), "question untouched")
|
||||
Expect(body.Messages[1]).To(Equal("it is [REDACTED:ner:PIN] ok"))
|
||||
events, _ := st.List(context.Background(), ListQuery{Limit: 100})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].ByteOffset).To(Equal(6), "event offsets are message-local")
|
||||
})
|
||||
})
|
||||
|
||||
@@ -28,6 +28,10 @@ type NEREntity struct {
|
||||
Start int
|
||||
End int
|
||||
Score float32
|
||||
// Text is the matched substring as the detector saw it. Carried for
|
||||
// debug logging only (the persisted PIIEvent never stores the raw
|
||||
// value); the redactor re-slices the original text for masking.
|
||||
Text string
|
||||
}
|
||||
|
||||
// NERConfig configures the encoder tier for one redactor invocation.
|
||||
@@ -56,8 +60,22 @@ type NERConfig struct {
|
||||
// entities silently" — useful when the model returns a broad
|
||||
// taxonomy but the admin only cares about a subset.
|
||||
DefaultAction Action
|
||||
|
||||
// Source labels where this detector's hits come from. It becomes the
|
||||
// PatternID prefix on events and the [REDACTED:<id>] mask, so neural NER
|
||||
// detections (Source "ner") and deterministic pattern-matcher detections
|
||||
// (Source "pattern") are told apart in the events log and to the model.
|
||||
// Empty defaults to "ner" for backward compatibility.
|
||||
Source string
|
||||
}
|
||||
|
||||
// Detector source labels (the PatternID prefix). Kept short and stable —
|
||||
// they appear in the events log and the [REDACTED:...] mask.
|
||||
const (
|
||||
SourceNER = "ner"
|
||||
SourcePattern = "pattern"
|
||||
)
|
||||
|
||||
// ResolveAction returns the action configured for a detected entity
|
||||
// group, falling back to DefaultAction. Returns ("", false) when the
|
||||
// entity should be ignored entirely (no override + no default).
|
||||
@@ -71,13 +89,39 @@ func (c NERConfig) ResolveAction(group string) (Action, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// nerPatternID returns the synthetic pattern ID that audit rows carry
|
||||
// for NER hits. Prefixing with "ner:" keeps these distinguishable from
|
||||
// regex pattern IDs in the events tab and in filter queries; admins
|
||||
// can switch off a single entity type with the same Disabled-pattern
|
||||
// machinery used for regex.
|
||||
func nerPatternID(group string) string {
|
||||
return "ner:" + group
|
||||
// NERConfigFromRaw builds a typed NERConfig from a detector plus the raw
|
||||
// policy strings carried on a detector model's pii_detection config. An
|
||||
// empty or invalid default_action becomes ActionMask — the safe-by-default
|
||||
// policy for a PII filter (a detected entity is masked unless an admin
|
||||
// downgrades it). Unknown per-entity actions are dropped (and logged by
|
||||
// validActions). This is the single conversion point the application-layer
|
||||
// resolver uses, so the detector model's policy reaches the redactor in
|
||||
// exactly one shape. source labels the detector kind (SourceNER /
|
||||
// SourcePattern) and becomes the PatternID prefix; empty defaults to
|
||||
// SourceNER.
|
||||
func NERConfigFromRaw(detector NERDetector, minScore float32, defaultAction string, entityActions map[string]string, source string) NERConfig {
|
||||
if source == "" {
|
||||
source = SourceNER
|
||||
}
|
||||
return NERConfig{
|
||||
Detector: detector,
|
||||
MinScore: minScore,
|
||||
DefaultAction: validActionOr(defaultAction, ActionMask),
|
||||
EntityActions: validActions(entityActions),
|
||||
Source: source,
|
||||
}
|
||||
}
|
||||
|
||||
// patternID returns the synthetic pattern ID that audit rows and masks carry
|
||||
// for this detector's hits, e.g. "ner:EMAIL" or "pattern:ANTHROPIC_KEY". The
|
||||
// source prefix keeps neural and deterministic detections distinguishable in
|
||||
// the events tab and in pattern_id filter queries.
|
||||
func (c NERConfig) patternID(group string) string {
|
||||
source := c.Source
|
||||
if source == "" {
|
||||
source = SourceNER
|
||||
}
|
||||
return source + ":" + group
|
||||
}
|
||||
|
||||
// errNERDetector is a NERDetector that always returns the wrapped
|
||||
|
||||
@@ -9,8 +9,7 @@ import (
|
||||
)
|
||||
|
||||
// stubNERDetector returns a fixed slice of entities and tracks call
|
||||
// count so tests can assert the detector isn't called when text is
|
||||
// empty / no patterns / detector disabled.
|
||||
// count so tests can assert the detector isn't called when text is empty.
|
||||
type stubNERDetector struct {
|
||||
entities []NEREntity
|
||||
err error
|
||||
@@ -22,43 +21,39 @@ func (s *stubNERDetector) Detect(_ context.Context, _ string) ([]NEREntity, erro
|
||||
return s.entities, s.err
|
||||
}
|
||||
|
||||
var _ = Describe("RedactWithNER", func() {
|
||||
It("nil detector is regex-only", func() {
|
||||
// When the NER tier is disabled (Detector == nil) the redactor
|
||||
// must behave exactly like the existing regex-only path — no
|
||||
// detector call, same Result shape, no error.
|
||||
r := NewRedactor([]Pattern{pickEmail()})
|
||||
res, err := r.RedactWithNER(context.Background(), "ping me at alice@example.com", nil, NERConfig{})
|
||||
var _ = Describe("RedactNER", func() {
|
||||
It("no detectors is a no-op", func() {
|
||||
res, err := RedactNER(context.Background(), "ping me at alice@example.com", nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"), "regex tier should still run when Detector is nil")
|
||||
Expect(res.Redacted).To(Equal("ping me at alice@example.com"))
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("applies entity actions", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "PER", Start: 6, End: 11, Score: 0.95}, // "Alice" in "Hi I'm Alice today"
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Hi I'm Alice today", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Hi I'm Alice today", []NERConfig{{
|
||||
Detector: det,
|
||||
EntityActions: map[string]Action{"PER": ActionMask},
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(det.calls).To(Equal(1))
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:PER]"))
|
||||
Expect(res.Spans).To(HaveLen(1))
|
||||
Expect(res.Spans[0].Pattern).To(Equal("ner:PER"))
|
||||
Expect(res.Spans[0].Action).To(Equal(ActionMask))
|
||||
})
|
||||
|
||||
It("filters below MinScore", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "PER", Start: 0, End: 5, Score: 0.20},
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Alice", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Alice", []NERConfig{{
|
||||
Detector: det,
|
||||
MinScore: 0.50,
|
||||
EntityActions: map[string]Action{"PER": ActionMask},
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(Equal("Alice"), "low-confidence entity should be dropped")
|
||||
})
|
||||
@@ -67,108 +62,120 @@ var _ = Describe("RedactWithNER", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "ORG", Start: 7, End: 11, Score: 0.9}, // "Acme" in "Hello, Acme!"
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Hello, Acme!", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Hello, Acme!", []NERConfig{{
|
||||
Detector: det,
|
||||
DefaultAction: ActionMask,
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:ORG]"), "DefaultAction should apply to ORG")
|
||||
})
|
||||
|
||||
It("drops unconfigured groups with no default", func() {
|
||||
// EntityActions has no entry for ORG and DefaultAction is empty —
|
||||
// the detected entity must be ignored entirely (no audit row, no
|
||||
// redaction).
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "ORG", Start: 0, End: 4, Score: 0.9},
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Acme", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Acme", []NERConfig{{
|
||||
Detector: det,
|
||||
EntityActions: map[string]Action{"PER": ActionMask}, // ORG is unconfigured
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(Equal("Acme"))
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("overlapping hits keep stronger action", func() {
|
||||
// Regex marks 0..10 as mask; NER marks 5..15 as block. After
|
||||
// merge, the union 0..15 keeps the strongest action (block).
|
||||
pat := Pattern{ID: "test", Action: ActionMask, regex: rangeRegex(0, 10)}
|
||||
r := NewRedactor([]Pattern{pat})
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "PER", Start: 5, End: 15, Score: 0.9},
|
||||
}}
|
||||
It("unions multiple detectors and keeps the stronger action on overlap", func() {
|
||||
// Detector A marks 0..10 as mask; detector B marks 5..15 as block.
|
||||
// After merge, the union 0..15 keeps the strongest action (block).
|
||||
detA := &stubNERDetector{entities: []NEREntity{{Group: "A", Start: 0, End: 10, Score: 0.9}}}
|
||||
detB := &stubNERDetector{entities: []NEREntity{{Group: "B", Start: 5, End: 15, Score: 0.9}}}
|
||||
text := "0123456789ABCDEF"
|
||||
res, err := r.RedactWithNER(context.Background(), text, nil, NERConfig{
|
||||
Detector: det,
|
||||
EntityActions: map[string]Action{"PER": ActionBlock},
|
||||
res, err := RedactNER(context.Background(), text, []NERConfig{
|
||||
{Detector: detA, EntityActions: map[string]Action{"A": ActionMask}},
|
||||
{Detector: detB, EntityActions: map[string]Action{"B": ActionBlock}},
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(detA.calls).To(Equal(1))
|
||||
Expect(detB.calls).To(Equal(1))
|
||||
Expect(res.Blocked).To(BeTrue(), "overlapping mask+block should set Blocked=true")
|
||||
})
|
||||
|
||||
It("detector error returns regex result and error", func() {
|
||||
// Fail-open: when the NER detector errors, the redactor still
|
||||
// returns regex-tier hits so an offline NER backend doesn't strip
|
||||
// the cheap protection. Caller can read the error and decide
|
||||
// whether to surface it.
|
||||
det := &stubNERDetector{err: errors.New("backend offline")}
|
||||
r := NewRedactor([]Pattern{pickEmail()})
|
||||
res, err := r.RedactWithNER(context.Background(), "ping alice@example.com", nil, NERConfig{
|
||||
Detector: det,
|
||||
DefaultAction: ActionMask,
|
||||
It("returns a best-effort result and the error when a detector fails (fail-closed contract)", func() {
|
||||
// One healthy detector, one failing. RedactNER returns the healthy
|
||||
// detector's hits AND the error, so the caller can fail closed.
|
||||
good := &stubNERDetector{entities: []NEREntity{{Group: "PER", Start: 0, End: 5, Score: 0.9}}}
|
||||
bad := &stubNERDetector{err: errors.New("backend offline")}
|
||||
res, err := RedactNER(context.Background(), "Alice", []NERConfig{
|
||||
{Detector: good, DefaultAction: ActionMask},
|
||||
{Detector: bad, DefaultAction: ActionMask},
|
||||
})
|
||||
Expect(err).To(HaveOccurred(), "expected detector error to surface")
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"), "regex tier should still apply on NER failure")
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:PER]"), "healthy detector's hits should still apply")
|
||||
})
|
||||
|
||||
It("out-of-bounds offsets are skipped", func() {
|
||||
// A misconfigured / buggy backend could return offsets past the
|
||||
// end of text. The redactor must not panic on slice OOB.
|
||||
It("skips out-of-bounds offsets without panicking", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{
|
||||
{Group: "PER", Start: 0, End: 999, Score: 0.9},
|
||||
{Group: "PER", Start: -1, End: 3, Score: 0.9},
|
||||
{Group: "PER", Start: 5, End: 5, Score: 0.9}, // zero-length
|
||||
}}
|
||||
r := NewRedactor(nil)
|
||||
res, err := r.RedactWithNER(context.Background(), "Alice", nil, NERConfig{
|
||||
res, err := RedactNER(context.Background(), "Alice", []NERConfig{{
|
||||
Detector: det,
|
||||
DefaultAction: ActionMask,
|
||||
})
|
||||
}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(Equal("Alice"))
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
// --- test helpers ---
|
||||
var _ = Describe("NERConfigFromRaw", func() {
|
||||
det := &stubNERDetector{}
|
||||
|
||||
// rangeMatcher is a deterministic regexpMatcher stub: it claims one
|
||||
// fixed range regardless of input. Lets the overlap-merge test
|
||||
// produce a known regex/NER intersection without depending on a real
|
||||
// compiled regex.
|
||||
type rangeMatcher struct{ start, end int }
|
||||
It("defaults an empty default_action to mask and an empty source to ner", func() {
|
||||
cfg := NERConfigFromRaw(det, 0.4, "", nil, "")
|
||||
Expect(cfg.DefaultAction).To(Equal(ActionMask))
|
||||
Expect(cfg.MinScore).To(BeNumerically("~", 0.4, 1e-6))
|
||||
Expect(cfg.Source).To(Equal(SourceNER))
|
||||
Expect(cfg.patternID("EMAIL")).To(Equal("ner:EMAIL"))
|
||||
})
|
||||
|
||||
func (m rangeMatcher) FindAllStringIndex(_ string, _ int) [][]int {
|
||||
return [][]int{{m.start, m.end}}
|
||||
}
|
||||
It("passes through valid actions and drops invalid ones", func() {
|
||||
cfg := NERConfigFromRaw(det, 0, "block", map[string]string{
|
||||
"PASSWORD": "block",
|
||||
"EMAIL": "mask",
|
||||
"BOGUS": "nonsense", // dropped
|
||||
}, SourceNER)
|
||||
Expect(cfg.DefaultAction).To(Equal(ActionBlock))
|
||||
Expect(cfg.EntityActions).To(HaveKeyWithValue("PASSWORD", ActionBlock))
|
||||
Expect(cfg.EntityActions).To(HaveKeyWithValue("EMAIL", ActionMask))
|
||||
Expect(cfg.EntityActions).NotTo(HaveKey("BOGUS"))
|
||||
})
|
||||
|
||||
func rangeRegex(start, end int) regexpMatcher { return rangeMatcher{start: start, end: end} }
|
||||
It("prefixes pattern-detector hits with the pattern source", func() {
|
||||
cfg := NERConfigFromRaw(det, 0, "mask", nil, SourcePattern)
|
||||
Expect(cfg.Source).To(Equal(SourcePattern))
|
||||
Expect(cfg.patternID("ANTHROPIC_KEY")).To(Equal("pattern:ANTHROPIC_KEY"))
|
||||
})
|
||||
})
|
||||
|
||||
// pickEmail returns the compiled "email" pattern from DefaultPatterns
|
||||
// — the NER tests use it as the regex tier's contribution.
|
||||
func pickEmail() Pattern {
|
||||
for _, p := range DefaultPatterns() {
|
||||
if p.ID == "email" {
|
||||
compiled, err := Compile([]Pattern{p})
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return compiled[0]
|
||||
}
|
||||
}
|
||||
Fail("email pattern missing from DefaultPatterns")
|
||||
return Pattern{}
|
||||
}
|
||||
var _ = Describe("NERConfig.ResolveAction", func() {
|
||||
It("prefers an explicit entity action over the default", func() {
|
||||
cfg := NERConfig{EntityActions: map[string]Action{"EMAIL": ActionBlock}, DefaultAction: ActionMask}
|
||||
a, ok := cfg.ResolveAction("EMAIL")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(a).To(Equal(ActionBlock))
|
||||
})
|
||||
|
||||
It("falls back to the default action", func() {
|
||||
cfg := NERConfig{DefaultAction: ActionMask}
|
||||
a, ok := cfg.ResolveAction("ANYTHING")
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(a).To(Equal(ActionMask))
|
||||
})
|
||||
|
||||
It("ignores a group with no override and no default", func() {
|
||||
cfg := NERConfig{}
|
||||
_, ok := cfg.ResolveAction("ANYTHING")
|
||||
Expect(ok).To(BeFalse())
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,188 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// regexpMatcher is a thin wrapper so tests can swap in a deterministic
|
||||
// matcher without touching the regexp package. Real usage uses
|
||||
// regexpMatcherFromPattern; tests can construct fakes.
|
||||
type regexpMatcher interface {
|
||||
FindAllStringIndex(s string, n int) [][]int
|
||||
}
|
||||
|
||||
type goRegexp struct{ r *regexp.Regexp }
|
||||
|
||||
func (g goRegexp) FindAllStringIndex(s string, n int) [][]int {
|
||||
return g.r.FindAllStringIndex(s, n)
|
||||
}
|
||||
|
||||
// DefaultPatterns returns the built-in regex set. Each entry includes
|
||||
// a conservative MaxMatchLength so the streaming filter can size its
|
||||
// tail buffer without re-parsing the regex at runtime.
|
||||
//
|
||||
// Caveats by design:
|
||||
// - The phone pattern matches international and US formats but does
|
||||
// not validate area codes. False positives on numbers that look
|
||||
// phone-like (e.g., timestamps in some formats) are accepted in
|
||||
// return for reliable coverage.
|
||||
// - The credit card pattern requires the Luhn check (verifyLuhn) to
|
||||
// reduce false positives — random 16-digit strings won't match.
|
||||
// - The API-key pattern targets common provider prefixes (sk-, pk-,
|
||||
// xoxb-, ghp_, github_pat_) rather than guessing entropy. Adding
|
||||
// new providers should append a new Pattern, not extend an
|
||||
// existing alternation, so the admin UI can show one row per
|
||||
// provider with its own toggle.
|
||||
func DefaultPatterns() []Pattern {
|
||||
return []Pattern{
|
||||
{
|
||||
ID: "email",
|
||||
Description: "Email address",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 254, // RFC 5321 max
|
||||
},
|
||||
{
|
||||
ID: "phone",
|
||||
Description: "Phone number (international or US format)",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 24,
|
||||
},
|
||||
{
|
||||
ID: "ssn",
|
||||
Description: "US Social Security Number (NNN-NN-NNNN)",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 11,
|
||||
},
|
||||
{
|
||||
ID: "credit_card",
|
||||
Description: "Credit card number (Luhn-verified)",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 19,
|
||||
},
|
||||
{
|
||||
ID: "ipv4",
|
||||
Description: "IPv4 address",
|
||||
Action: ActionMask,
|
||||
MaxMatchLength: 15,
|
||||
},
|
||||
{
|
||||
ID: "api_key_prefix",
|
||||
Description: "Common API key prefixes (sk-, pk-, xoxb-, ghp_, github_pat_)",
|
||||
Action: ActionBlock, // tighter default — leaked credentials are higher harm
|
||||
MaxMatchLength: 200,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// patternRegexps maps Pattern.ID to its compiled regex. Kept separate
|
||||
// from the Pattern struct so DefaultPatterns can be data-only and
|
||||
// tests can swap matchers via Compile().
|
||||
var patternRegexps = map[string]*regexp.Regexp{
|
||||
// Pragmatic email — does not implement RFC 5322 in full (no one
|
||||
// sane does in a regex). Catches the common shape; the encoder
|
||||
// NER tier (future) catches edge cases.
|
||||
"email": regexp.MustCompile(`(?i)[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}`),
|
||||
// US: (123) 456-7890, 123-456-7890, 123.456.7890, 1234567890.
|
||||
// International: +<country>-<area>-<rest> with separators.
|
||||
"phone": regexp.MustCompile(`(?:\+?\d{1,3}[\s\-.]?)?(?:\(\d{3}\)|\d{3})[\s\-.]?\d{3}[\s\-.]?\d{4}`),
|
||||
"ssn": regexp.MustCompile(`\b\d{3}-\d{2}-\d{4}\b`),
|
||||
// 13-19 digit Luhn-eligible runs. The verifier in match() rejects
|
||||
// non-Luhn matches.
|
||||
"credit_card": regexp.MustCompile(`\b(?:\d[ \-]?){13,19}\b`),
|
||||
"ipv4": regexp.MustCompile(`\b(?:\d{1,3}\.){3}\d{1,3}\b`),
|
||||
// Common provider prefixes; each alternative is a separate
|
||||
// well-known marker rather than a permissive entropy match.
|
||||
"api_key_prefix": regexp.MustCompile(`(?:sk-[A-Za-z0-9]{20,}|pk-[A-Za-z0-9]{20,}|xoxb-[A-Za-z0-9\-]{20,}|ghp_[A-Za-z0-9]{20,}|github_pat_[A-Za-z0-9_]{20,})`),
|
||||
}
|
||||
|
||||
// Compile attaches matchers to each pattern. Patterns whose ID is not
|
||||
// in patternRegexps are returned as a typed error so an admin who
|
||||
// adds a custom pattern via config gets a clear "no regex registered"
|
||||
// message instead of silent skip.
|
||||
func Compile(patterns []Pattern) ([]Pattern, error) {
|
||||
out := make([]Pattern, len(patterns))
|
||||
for i, p := range patterns {
|
||||
r, ok := patternRegexps[p.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("pii: no regex registered for pattern id %q", p.ID)
|
||||
}
|
||||
p.regex = goRegexp{r: r}
|
||||
out[i] = p
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// VerifyMatch applies pattern-specific post-checks (e.g. Luhn for
|
||||
// credit_card). Returns the original match or "" to discard it.
|
||||
func VerifyMatch(patternID, candidate string) string {
|
||||
switch patternID {
|
||||
case "credit_card":
|
||||
digits := stripNonDigits(candidate)
|
||||
if len(digits) < 13 || len(digits) > 19 {
|
||||
return ""
|
||||
}
|
||||
if !verifyLuhn(digits) {
|
||||
return ""
|
||||
}
|
||||
case "ipv4":
|
||||
// Each octet must be 0..255. The regex allows 0..999 since
|
||||
// regex isn't great at numeric ranges; we tighten here.
|
||||
for oct := range strings.SplitSeq(candidate, ".") {
|
||||
n := 0
|
||||
for _, c := range oct {
|
||||
if c < '0' || c > '9' {
|
||||
return ""
|
||||
}
|
||||
n = n*10 + int(c-'0')
|
||||
}
|
||||
if n > 255 {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
}
|
||||
return candidate
|
||||
}
|
||||
|
||||
func stripNonDigits(s string) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(s))
|
||||
for _, c := range s {
|
||||
if c >= '0' && c <= '9' {
|
||||
b.WriteRune(c)
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// verifyLuhn implements the Luhn checksum used by credit-card numbers.
|
||||
// Returns true iff the digits pass.
|
||||
func verifyLuhn(digits string) bool {
|
||||
sum := 0
|
||||
double := false
|
||||
for i := len(digits) - 1; i >= 0; i-- {
|
||||
d := int(digits[i] - '0')
|
||||
if double {
|
||||
d *= 2
|
||||
if d > 9 {
|
||||
d -= 9
|
||||
}
|
||||
}
|
||||
sum += d
|
||||
double = !double
|
||||
}
|
||||
return sum%10 == 0
|
||||
}
|
||||
|
||||
// MaxPatternLength returns the longest MaxMatchLength across the input
|
||||
// patterns. Used by the streaming filter to size its tail buffer.
|
||||
func MaxPatternLength(patterns []Pattern) int {
|
||||
max := 0
|
||||
for _, p := range patterns {
|
||||
if p.MaxMatchLength > max {
|
||||
max = p.MaxMatchLength
|
||||
}
|
||||
}
|
||||
return max
|
||||
}
|
||||
@@ -4,212 +4,152 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
)
|
||||
|
||||
// rawHit is one detection — regex-side or NER-side — before
|
||||
// overlap-merging. Lifted to file scope so the regex and NER
|
||||
// collectors can both produce them and feed the same merge/emit step.
|
||||
// rawHit is one detection before overlap-merging. Lifted to file scope so
|
||||
// the NER collector and the merge/emit step can share it.
|
||||
type rawHit struct {
|
||||
patternID string
|
||||
action Action
|
||||
start int
|
||||
end int
|
||||
score float32
|
||||
}
|
||||
|
||||
// Redactor scans text against a configured pattern set and applies the
|
||||
// per-pattern action. The pattern set itself is mutable at runtime via
|
||||
// SetAction (the /api/pii/patterns/:id admin endpoint mutates it
|
||||
// in-place); reads are guarded by a mutex so concurrent requests stay
|
||||
// race-free.
|
||||
type Redactor struct {
|
||||
mu sync.RWMutex
|
||||
patterns []Pattern
|
||||
maxLen int
|
||||
}
|
||||
// Redactor is a stateless handle for the PII subsystem. The regex tier
|
||||
// was removed: detection is driven entirely by per-model NER detectors
|
||||
// (see RedactNER), whose policy lives on each detector model's
|
||||
// pii_detection config. The type is retained (zero-field) as the
|
||||
// on/off sentinel the application wiring and middleware gate on, so a
|
||||
// nil *Redactor still means "PII subsystem unavailable".
|
||||
type Redactor struct{}
|
||||
|
||||
// NewRedactor constructs a redactor from a list of compiled patterns
|
||||
// (use Compile() to compile config-loaded patterns first). nil
|
||||
// patterns is valid and produces a no-op redactor — convenient for the
|
||||
// "PII disabled" deployment.
|
||||
func NewRedactor(patterns []Pattern) *Redactor {
|
||||
return &Redactor{
|
||||
patterns: patterns,
|
||||
maxLen: MaxPatternLength(patterns),
|
||||
}
|
||||
}
|
||||
|
||||
// MaxPatternLength is exposed so the streaming wrapper can size its
|
||||
// tail buffer to match.
|
||||
func (r *Redactor) MaxPatternLength() int { return r.maxLen }
|
||||
|
||||
// Patterns returns a copy of the configured pattern set so callers can
|
||||
// iterate without holding the redactor lock. The compiled regexes are
|
||||
// shared — they are immutable once built.
|
||||
func (r *Redactor) Patterns() []Pattern {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return slices.Clone(r.patterns)
|
||||
}
|
||||
|
||||
// SetAction overrides the action for a single pattern. Used by the
|
||||
// /api/pii/patterns/:id admin endpoint and the set_pii_pattern_action
|
||||
// MCP tool — transient until process restart unless persisted via
|
||||
// --pii-config.
|
||||
// RedactNER runs every configured NER detector over text, unions their
|
||||
// detections, and emits one redacted output. Each NERConfig carries its
|
||||
// own detector and policy (min score, entity→action map, default
|
||||
// action), so a consuming model that references several detector models
|
||||
// gets each model's policy applied to its own hits before the overlap
|
||||
// merge (block > mask > allow) resolves any span two detectors both
|
||||
// claim.
|
||||
//
|
||||
// Publishes a new slice so concurrent Redact callers iterating an
|
||||
// older snapshot don't race on the per-element Action string (Go
|
||||
// strings are not atomic two-word values).
|
||||
func (r *Redactor) SetAction(id string, action Action) error {
|
||||
if action != ActionMask && action != ActionBlock && action != ActionAllow {
|
||||
return fmt.Errorf("unknown action %q (must be mask, block, or allow)", action)
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for i := range r.patterns {
|
||||
if r.patterns[i].ID == id {
|
||||
next := slices.Clone(r.patterns)
|
||||
next[i].Action = action
|
||||
r.patterns = next
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unknown pattern id %q", id)
|
||||
}
|
||||
|
||||
// SetDisabled toggles a pattern's enabled state in the live redactor.
|
||||
// Same COW publish as SetAction.
|
||||
func (r *Redactor) SetDisabled(id string, disabled bool) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for i := range r.patterns {
|
||||
if r.patterns[i].ID == id {
|
||||
next := slices.Clone(r.patterns)
|
||||
next[i].Disabled = disabled
|
||||
r.patterns = next
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("unknown pattern id %q", id)
|
||||
}
|
||||
|
||||
// Redact is a thin wrapper for callers that don't need per-request
|
||||
// action overrides. It applies each pattern's compiled-in default
|
||||
// action.
|
||||
func (r *Redactor) Redact(text string) Result {
|
||||
return r.RedactWithOverrides(text, nil)
|
||||
}
|
||||
|
||||
// RedactWithOverrides scans text and returns the result. The override
|
||||
// map is keyed by pattern id; when present, the value replaces the
|
||||
// pattern's compiled-in action for this call only — the redactor's
|
||||
// stored action is unchanged. Pattern ids missing from the map use
|
||||
// their stored action.
|
||||
// Any detector error is returned alongside a best-effort Result built
|
||||
// from the detectors that did succeed, so the caller can fail closed
|
||||
// (refuse the request) while still seeing what the healthy detectors
|
||||
// found. Configs with a nil Detector are skipped.
|
||||
//
|
||||
// For every match it records a Span (with HashPrefix, never the value)
|
||||
// and applies the resolved Action:
|
||||
// - block: sets Result.Blocked, leaves text intact (caller decides
|
||||
// whether to surface the redacted form).
|
||||
// - mask: replaces the span with maskFor(pattern.ID), sets Result.Masked.
|
||||
// - allow: leaves text intact and sets no flag (the span is still
|
||||
// recorded so the match is auditable).
|
||||
//
|
||||
// Spans are returned in the original input's coordinate system so the
|
||||
// PIIEvent record can be written without re-running the scan.
|
||||
func (r *Redactor) RedactWithOverrides(text string, overrides map[string]Action) Result {
|
||||
return r.redact(context.Background(), text, overrides, NERConfig{})
|
||||
}
|
||||
|
||||
// RedactWithNER is the encoder-tier variant: runs both the regex tier
|
||||
// (with per-pattern overrides) and the NER tier, merges hits, and
|
||||
// emits one redacted output. A nil NERConfig.Detector skips the NER
|
||||
// pass — callers can hand the same path the same NERConfig{} whether
|
||||
// or not the model has NER configured.
|
||||
//
|
||||
// Errors from the NER detector are returned alongside a best-effort
|
||||
// regex-only Result so the caller can decide whether to fail open
|
||||
// (return the regex Result, log the error) or fail closed (refuse the
|
||||
// request). The regex tier never errors.
|
||||
func (r *Redactor) RedactWithNER(ctx context.Context, text string, overrides map[string]Action, nerCfg NERConfig) (Result, error) {
|
||||
if nerCfg.Detector == nil {
|
||||
return r.redact(ctx, text, overrides, nerCfg), nil
|
||||
}
|
||||
hits, err := r.collectRegexHits(text, overrides)
|
||||
if err != nil {
|
||||
return Result{Redacted: text}, err
|
||||
}
|
||||
nerHits, nerErr := collectNERHits(ctx, text, nerCfg)
|
||||
if nerErr != nil {
|
||||
// Return the regex-only result so a NER-backend outage doesn't
|
||||
// strip the cheap protection. Caller decides fail-open vs
|
||||
// fail-closed via the returned error.
|
||||
return mergeAndEmit(text, hits), nerErr
|
||||
}
|
||||
return mergeAndEmit(text, append(hits, nerHits...)), nil
|
||||
}
|
||||
|
||||
// redact is the internal regex-only entry point. RedactWithOverrides
|
||||
// is the public wrapper; RedactWithNER routes through here only when
|
||||
// the NER detector is nil (so the call site doesn't need a separate
|
||||
// "regex-only" code path).
|
||||
func (r *Redactor) redact(_ context.Context, text string, overrides map[string]Action, _ NERConfig) Result {
|
||||
hits, _ := r.collectRegexHits(text, overrides)
|
||||
return mergeAndEmit(text, hits)
|
||||
}
|
||||
|
||||
// collectRegexHits walks the configured pattern set against text and
|
||||
// returns each verified match as a rawHit. The redactor lock is held
|
||||
// only long enough to snapshot the pattern slice — regex evaluation
|
||||
// runs lock-free against the snapshot, so SetAction/SetDisabled don't
|
||||
// stall a long-running Redact.
|
||||
func (r *Redactor) collectRegexHits(text string, overrides map[string]Action) ([]rawHit, error) {
|
||||
r.mu.RLock()
|
||||
patterns := r.patterns
|
||||
r.mu.RUnlock()
|
||||
|
||||
if len(patterns) == 0 || text == "" {
|
||||
return nil, nil
|
||||
// Package-level (no Redactor state): both the in-band request middleware
|
||||
// and the MITM request path call it with their own resolved []NERConfig.
|
||||
func RedactNER(ctx context.Context, text string, cfgs []NERConfig) (Result, error) {
|
||||
if text == "" || len(cfgs) == 0 {
|
||||
return Result{Redacted: text}, nil
|
||||
}
|
||||
var hits []rawHit
|
||||
for _, p := range patterns {
|
||||
if p.regex == nil {
|
||||
// Pattern declared but Compile() not called. Skip rather
|
||||
// than panic; the caller already saw an error from Compile.
|
||||
var firstErr error
|
||||
for _, cfg := range cfgs {
|
||||
if cfg.Detector == nil {
|
||||
continue
|
||||
}
|
||||
if p.Disabled {
|
||||
h, err := collectNERHits(ctx, text, cfg)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
action := p.Action
|
||||
if override, ok := overrides[p.ID]; ok {
|
||||
action = override
|
||||
hits = append(hits, h...)
|
||||
}
|
||||
return mergeAndEmit(text, hits), firstErr
|
||||
}
|
||||
|
||||
// segmentSeparator joins per-message texts into the single document
|
||||
// RedactNERSegments scans. Two newlines read as a paragraph break to the
|
||||
// NER encoder — neutral, in-distribution context — and never carry PII
|
||||
// themselves, so a detected span landing on the separator can only be the
|
||||
// fringe of an entity that started in a real segment.
|
||||
const segmentSeparator = "\n\n"
|
||||
|
||||
// RedactNERSegments scans texts as ONE concatenated document and maps the
|
||||
// detections back to one Result per input text. Scanning the segments
|
||||
// together is what gives the NER tier conversational context: whether
|
||||
// "jdoe_42" is a USERNAME or "4421" is a PIN is decided by the question
|
||||
// asked in the *previous* message, and a bidirectional encoder only sees
|
||||
// that context if both messages are in the same forward pass. (Measured on
|
||||
// privacy-filter-multilingual: "4421" alone detects nothing; preceded by
|
||||
// "What are the last four digits of your card?" it detects PIN at 0.726.)
|
||||
//
|
||||
// Span offsets in each Result are local to its text, so callers rewrite
|
||||
// fields in place exactly as with per-text RedactNER. A hit that crosses a
|
||||
// segment boundary is split and each fragment keeps the hit's action —
|
||||
// conservative, and only possible for an entity the model stretched across
|
||||
// the separator. Error semantics mirror RedactNER: best-effort results
|
||||
// plus the first detector error, so callers can fail closed.
|
||||
func RedactNERSegments(ctx context.Context, texts []string, cfgs []NERConfig) ([]Result, error) {
|
||||
results := make([]Result, len(texts))
|
||||
if len(texts) == 0 || len(cfgs) == 0 {
|
||||
for i := range results {
|
||||
results[i] = Result{Redacted: texts[i]}
|
||||
}
|
||||
idxs := p.regex.FindAllStringIndex(text, -1)
|
||||
for _, idx := range idxs {
|
||||
candidate := text[idx[0]:idx[1]]
|
||||
if VerifyMatch(p.ID, candidate) == "" {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
var joined strings.Builder
|
||||
starts := make([]int, len(texts))
|
||||
ends := make([]int, len(texts))
|
||||
for i, t := range texts {
|
||||
if i > 0 {
|
||||
joined.WriteString(segmentSeparator)
|
||||
}
|
||||
starts[i] = joined.Len()
|
||||
joined.WriteString(t)
|
||||
ends[i] = joined.Len()
|
||||
}
|
||||
doc := joined.String()
|
||||
|
||||
var hits []rawHit
|
||||
var firstErr error
|
||||
for _, cfg := range cfgs {
|
||||
if cfg.Detector == nil {
|
||||
continue
|
||||
}
|
||||
h, err := collectNERHits(ctx, doc, cfg)
|
||||
if err != nil {
|
||||
if firstErr == nil {
|
||||
firstErr = err
|
||||
}
|
||||
continue
|
||||
}
|
||||
hits = append(hits, h...)
|
||||
}
|
||||
|
||||
perSegment := make([][]rawHit, len(texts))
|
||||
for _, h := range hits {
|
||||
for i := range texts {
|
||||
s := max(h.start, starts[i])
|
||||
e := min(h.end, ends[i])
|
||||
if s >= e {
|
||||
continue
|
||||
}
|
||||
hits = append(hits, rawHit{
|
||||
patternID: p.ID,
|
||||
action: action,
|
||||
start: idx[0],
|
||||
end: idx[1],
|
||||
})
|
||||
local := h
|
||||
local.start = s - starts[i]
|
||||
local.end = e - starts[i]
|
||||
perSegment[i] = append(perSegment[i], local)
|
||||
}
|
||||
}
|
||||
return hits, nil
|
||||
for i := range texts {
|
||||
results[i] = mergeAndEmit(texts[i], perSegment[i])
|
||||
}
|
||||
return results, firstErr
|
||||
}
|
||||
|
||||
// collectNERHits invokes the configured NERDetector and converts each
|
||||
// returned entity into a rawHit using the NERConfig's action map.
|
||||
// Entities below MinScore or with no resolved action are dropped — the
|
||||
// detector doesn't know which entity groups the admin cares about, so
|
||||
// the redactor filters here.
|
||||
// the policy filters here.
|
||||
func collectNERHits(ctx context.Context, text string, cfg NERConfig) ([]rawHit, error) {
|
||||
if cfg.Detector == nil || text == "" {
|
||||
return nil, nil
|
||||
@@ -220,42 +160,58 @@ func collectNERHits(ctx context.Context, text string, cfg NERConfig) ([]rawHit,
|
||||
}
|
||||
var hits []rawHit
|
||||
for _, e := range entities {
|
||||
// One DEBUG line per raw detection with the model's confidence, the
|
||||
// byte range, the matched substring, and the policy decision. This is
|
||||
// the lowest-level view of why a request was masked/blocked — e.g. a
|
||||
// phone number scored as SSN — and answers "what was in that range and
|
||||
// how sure was the model" without re-running the detector. DEBUG-gated
|
||||
// because the matched value is sensitive.
|
||||
if e.Score < cfg.MinScore {
|
||||
xlog.Debug("pii/ner: detection dropped (below min score)",
|
||||
"group", e.Group, "score", e.Score, "min_score", cfg.MinScore,
|
||||
"start", e.Start, "end", e.End, "text", e.Text)
|
||||
continue
|
||||
}
|
||||
action, ok := cfg.ResolveAction(e.Group)
|
||||
if !ok {
|
||||
xlog.Debug("pii/ner: detection ignored (no action for group)",
|
||||
"group", e.Group, "score", e.Score,
|
||||
"start", e.Start, "end", e.End, "text", e.Text)
|
||||
continue
|
||||
}
|
||||
if e.Start < 0 || e.End <= e.Start || e.End > len(text) {
|
||||
// Defensive: the backend should return byte offsets into
|
||||
// the original text, but a misconfigured model could
|
||||
// produce garbage. Skip rather than panic on slice OOB.
|
||||
// Defensive: the backend should return byte offsets into the
|
||||
// original text, but a misconfigured model could produce
|
||||
// garbage. Skip rather than panic on slice OOB.
|
||||
xlog.Warn("pii/ner: detection has out-of-range offsets; skipping",
|
||||
"group", e.Group, "start", e.Start, "end", e.End, "text_len", len(text))
|
||||
continue
|
||||
}
|
||||
xlog.Debug("pii/ner: detection accepted",
|
||||
"group", e.Group, "score", e.Score, "action", action,
|
||||
"start", e.Start, "end", e.End, "text", e.Text)
|
||||
hits = append(hits, rawHit{
|
||||
patternID: nerPatternID(e.Group),
|
||||
patternID: cfg.patternID(e.Group),
|
||||
action: action,
|
||||
start: e.Start,
|
||||
end: e.End,
|
||||
score: e.Score,
|
||||
})
|
||||
}
|
||||
return hits, nil
|
||||
}
|
||||
|
||||
// mergeAndEmit handles the overlap-merge + masked-output step that
|
||||
// regex-only and combined regex+NER redactions both perform. Sorts by
|
||||
// mergeAndEmit handles the overlap-merge + masked-output step. Sorts by
|
||||
// start (stable on equal starts by descending action strength), drops
|
||||
// overlapping hits in favour of the stronger action, and walks the
|
||||
// text once to emit replacement spans.
|
||||
// overlapping hits in favour of the stronger action, and walks the text
|
||||
// once to emit replacement spans.
|
||||
func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
if len(hits) == 0 {
|
||||
return Result{Redacted: text}
|
||||
}
|
||||
// Sort and deduplicate overlapping hits — when two patterns claim
|
||||
// the same span (e.g., a credit-card-shaped value also scans as
|
||||
// digits, or NER tags a span the regex also caught), keep the one
|
||||
// with the strongest action. Order: block > mask > allow.
|
||||
// Sort and deduplicate overlapping hits — when two detectors claim
|
||||
// the same span, keep the one with the strongest action. Order:
|
||||
// block > mask > allow.
|
||||
sort.Slice(hits, func(i, j int) bool {
|
||||
if hits[i].start != hits[j].start {
|
||||
return hits[i].start < hits[j].start
|
||||
@@ -270,6 +226,7 @@ func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
if actionRank(h.action) > actionRank(last.action) {
|
||||
last.action = h.action
|
||||
last.patternID = h.patternID
|
||||
last.score = h.score
|
||||
}
|
||||
if h.end > last.end {
|
||||
last.end = h.end
|
||||
@@ -291,6 +248,8 @@ func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
End: h.end,
|
||||
Pattern: h.patternID,
|
||||
HashPrefix: hashPrefix(matched),
|
||||
Action: h.action,
|
||||
Score: h.score,
|
||||
}
|
||||
res.Spans = append(res.Spans, span)
|
||||
|
||||
@@ -315,17 +274,15 @@ func mergeAndEmit(text string, hits []rawHit) Result {
|
||||
|
||||
// maskFor returns the placeholder that replaces a matched span. The
|
||||
// shape "[REDACTED:<id>]" is intentionally stable — it surfaces the
|
||||
// pattern id back to the model, which is sometimes useful (e.g., the
|
||||
// model can say "I see you redacted an email"). Admins who want a
|
||||
// less informative replacement can build one in front of this.
|
||||
// detector group back to the model (e.g. "I see you redacted an email").
|
||||
func maskFor(patternID string) string {
|
||||
return "[REDACTED:" + patternID + "]"
|
||||
}
|
||||
|
||||
// hashPrefix returns the first 8 chars of sha256(value). Two calls
|
||||
// with the same input produce the same prefix so an admin auditing
|
||||
// the PIIEvent log can spot a recurring leak ("the same SSN appears
|
||||
// 200 times this hour") without ever recovering the value.
|
||||
// hashPrefix returns the first 8 chars of sha256(value). Two calls with
|
||||
// the same input produce the same prefix so an admin auditing the
|
||||
// PIIEvent log can spot a recurring leak without ever recovering the
|
||||
// value.
|
||||
func hashPrefix(value string) string {
|
||||
sum := sha256.Sum256([]byte(value))
|
||||
return hex.EncodeToString(sum[:])[:8]
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// Redactor_SetActionConcurrentRedact pins the SetAction copy-on-
|
||||
// write contract: concurrent SetAction must not race with readers
|
||||
// iterating an older patterns snapshot. Run with -race to surface the
|
||||
// regression that motivated the COW (in-place mutation of the
|
||||
// per-element Action string is not atomic).
|
||||
var _ = Describe("Redactor", func() {
|
||||
It("SetAction concurrent with Redact", func() {
|
||||
patterns, err := Compile(DefaultPatterns())
|
||||
Expect(err).NotTo(HaveOccurred(), "compile")
|
||||
r := NewRedactor(patterns)
|
||||
|
||||
const writers = 4
|
||||
const readers = 8
|
||||
const iter = 100
|
||||
|
||||
var wg sync.WaitGroup
|
||||
stop := make(chan struct{})
|
||||
|
||||
for w := 0; w < writers; w++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < iter; i++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
action := ActionMask
|
||||
if i%2 == 0 {
|
||||
action = ActionBlock
|
||||
}
|
||||
_ = r.SetAction("email", action)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for rd := 0; rd < readers; rd++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
text := "contact alice@example.com please"
|
||||
for i := 0; i < iter*2; i++ {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
default:
|
||||
}
|
||||
_ = r.Redact(text)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(stop)
|
||||
})
|
||||
})
|
||||
@@ -1,186 +1,182 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func mustCompile(ids ...string) []Pattern {
|
||||
all := DefaultPatterns()
|
||||
if len(ids) == 0 {
|
||||
out, err := Compile(all)
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return out
|
||||
}
|
||||
pickP := pick(all, ids)
|
||||
out, err := Compile(pickP)
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return out
|
||||
// detect builds a single-detector []NERConfig that reports one entity
|
||||
// over the whole input under the given group/action.
|
||||
func oneShot(group string, action Action, start, end int) []NERConfig {
|
||||
return []NERConfig{{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: group, Start: start, End: end, Score: 1}}},
|
||||
EntityActions: map[string]Action{group: action},
|
||||
}}
|
||||
}
|
||||
|
||||
func pick(all []Pattern, ids []string) []Pattern {
|
||||
keep := map[string]bool{}
|
||||
for _, id := range ids {
|
||||
keep[id] = true
|
||||
}
|
||||
var out []Pattern
|
||||
for _, p := range all {
|
||||
if keep[p.ID] {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
var _ = Describe("RedactNER emission", func() {
|
||||
ctx := context.Background()
|
||||
|
||||
var _ = Describe("Redactor", func() {
|
||||
It("masks email", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
res := r.Redact("Contact me at alice@example.com any time.")
|
||||
Expect(res.Blocked).To(BeFalse(), "email is mask-action by default, should not block")
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:email]"))
|
||||
It("masks with a [REDACTED:ner:GROUP] placeholder and records a hash prefix", func() {
|
||||
res, err := RedactNER(ctx, "Contact me at alice@example.com any time.", oneShot("EMAIL", ActionMask, 14, 31))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Masked).To(BeTrue())
|
||||
Expect(res.Blocked).To(BeFalse())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ner:EMAIL]"))
|
||||
Expect(res.Redacted).NotTo(ContainSubstring("alice@example.com"))
|
||||
Expect(res.Spans).To(HaveLen(1))
|
||||
Expect(res.Spans[0].HashPrefix).NotTo(BeEmpty(), "hash prefix must be set so audits can dedupe leaks")
|
||||
})
|
||||
|
||||
It("masks SSN", func() {
|
||||
r := NewRedactor(mustCompile("ssn"))
|
||||
res := r.Redact("call me about SSN 123-45-6789 please")
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:ssn]"))
|
||||
It("labels pattern-detector hits with the pattern source, not ner", func() {
|
||||
cfgs := []NERConfig{{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: "ANTHROPIC_KEY", Start: 4, End: 24, Score: 1}}},
|
||||
EntityActions: map[string]Action{"ANTHROPIC_KEY": ActionMask},
|
||||
Source: SourcePattern,
|
||||
}}
|
||||
res, err := RedactNER(ctx, "use sk-ant-aaaaaaaaaaaaaaaa now", cfgs)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(ContainSubstring("[REDACTED:pattern:ANTHROPIC_KEY]"))
|
||||
Expect(res.Redacted).NotTo(ContainSubstring("[REDACTED:ner:"))
|
||||
Expect(res.Spans).To(HaveLen(1))
|
||||
Expect(res.Spans[0].Pattern).To(Equal("pattern:ANTHROPIC_KEY"))
|
||||
})
|
||||
|
||||
It("uses Luhn for credit card", func() {
|
||||
r := NewRedactor(mustCompile("credit_card"))
|
||||
|
||||
// 4111 1111 1111 1111 — canonical Luhn-valid Visa test number.
|
||||
good := r.Redact("card: 4111 1111 1111 1111")
|
||||
Expect(good.Spans).To(HaveLen(1))
|
||||
Expect(good.Redacted).To(ContainSubstring("[REDACTED:credit_card]"))
|
||||
|
||||
// 4111 1111 1111 1112 — same shape, fails Luhn. Must NOT match.
|
||||
bad := r.Redact("card: 4111 1111 1111 1112")
|
||||
Expect(bad.Spans).To(BeEmpty(), "Luhn-invalid 16-digit run must not be redacted")
|
||||
Expect(bad.Redacted).To(ContainSubstring("1112"), "Luhn-invalid input should pass through untouched")
|
||||
It("block leaves the matched span intact and sets Blocked", func() {
|
||||
res, err := RedactNER(ctx, "token sk-abcdef here", oneShot("PASSWORD", ActionBlock, 6, 15))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Blocked).To(BeTrue())
|
||||
Expect(res.Redacted).To(ContainSubstring("sk-abcdef"), "block leaves the value intact for the caller to discard")
|
||||
Expect(res.Spans[0].Action).To(Equal(ActionBlock))
|
||||
})
|
||||
|
||||
It("validates IPv4 octets", func() {
|
||||
r := NewRedactor(mustCompile("ipv4"))
|
||||
|
||||
good := r.Redact("server at 192.168.1.10 is up")
|
||||
Expect(good.Spans).To(HaveLen(1))
|
||||
|
||||
// 999.999.999.999 — regex matches but octet > 255 must reject.
|
||||
bad := r.Redact("not an ip: 999.999.999.999")
|
||||
Expect(bad.Spans).To(BeEmpty(), "ipv4 with octet>255 must not match")
|
||||
})
|
||||
|
||||
It("api_key defaults to block", func() {
|
||||
r := NewRedactor(mustCompile("api_key_prefix"))
|
||||
res := r.Redact("here's a token sk-abcdefghijklmnopqrstuvwxyz0123456789 to use")
|
||||
Expect(res.Blocked).To(BeTrue(), "api_key default action is block; Result.Blocked must be true")
|
||||
// The redacted output keeps the matched value when blocking — the
|
||||
// caller is expected to refuse the request, not to forward a partial.
|
||||
Expect(res.Redacted).To(ContainSubstring("sk-abcdefghijklmn"), "blocked actions leave the matched span intact for caller inspection")
|
||||
})
|
||||
|
||||
It("preserves non-matching text", func() {
|
||||
r := NewRedactor(mustCompile()) // all default patterns
|
||||
in := "no PII here at all, just words and numbers like 42 and 1.5"
|
||||
res := r.Redact(in)
|
||||
Expect(res.Redacted).To(Equal(in), "non-PII input should pass through unchanged")
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("handles empty input", func() {
|
||||
r := NewRedactor(mustCompile())
|
||||
res := r.Redact("")
|
||||
Expect(res.Redacted).To(BeEmpty())
|
||||
Expect(res.Blocked).To(BeFalse())
|
||||
It("allow leaves text intact but still records the span", func() {
|
||||
res, err := RedactNER(ctx, "Hello Acme!", oneShot("ORG", ActionAllow, 6, 10))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Masked).To(BeFalse())
|
||||
Expect(res.Blocked).To(BeFalse())
|
||||
Expect(res.Redacted).To(Equal("Hello Acme!"))
|
||||
Expect(res.Spans).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("passes non-matching text through unchanged", func() {
|
||||
det := &stubNERDetector{} // no entities
|
||||
res, err := RedactNER(ctx, "no PII here, just words", []NERConfig{{Detector: det, DefaultAction: ActionMask}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(Equal("no PII here, just words"))
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
})
|
||||
|
||||
It("nil patterns is a no-op", func() {
|
||||
// Disabled-PII deployment: pii.NewRedactor(nil) is a no-op.
|
||||
r := NewRedactor(nil)
|
||||
res := r.Redact("alice@example.com sent it")
|
||||
Expect(res.Redacted).To(Equal("alice@example.com sent it"))
|
||||
It("handles empty input without calling the detector", func() {
|
||||
det := &stubNERDetector{entities: []NEREntity{{Group: "X", Start: 0, End: 1, Score: 1}}}
|
||||
res, err := RedactNER(ctx, "", []NERConfig{{Detector: det, DefaultAction: ActionMask}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res.Redacted).To(BeEmpty())
|
||||
Expect(res.Spans).To(BeEmpty())
|
||||
Expect(det.calls).To(Equal(0))
|
||||
})
|
||||
|
||||
It("hash prefix is stable", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
a := r.Redact("a@b.com")
|
||||
b := r.Redact("hi a@b.com again")
|
||||
It("produces a stable hash prefix for the same matched value", func() {
|
||||
a, _ := RedactNER(ctx, "a@b.com", oneShot("EMAIL", ActionMask, 0, 7))
|
||||
b, _ := RedactNER(ctx, "hi a@b.com", oneShot("EMAIL", ActionMask, 3, 10))
|
||||
Expect(a.Spans).To(HaveLen(1))
|
||||
Expect(b.Spans).To(HaveLen(1))
|
||||
Expect(a.Spans[0].HashPrefix).To(Equal(b.Spans[0].HashPrefix), "same matched value must produce same hash prefix")
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("Compile", func() {
|
||||
It("rejects unknown pattern id", func() {
|
||||
_, err := Compile([]Pattern{{ID: "nonexistent", Action: ActionMask}})
|
||||
Expect(err).To(HaveOccurred(), "Compile must error on unknown pattern id")
|
||||
// funcNERDetector computes entities from the text it is handed — used to
|
||||
// prove the segment scan gives the detector the JOINED document, the way a
|
||||
// context-sensitive encoder behaves.
|
||||
type funcNERDetector struct {
|
||||
fn func(text string) ([]NEREntity, error)
|
||||
}
|
||||
|
||||
func (f *funcNERDetector) Detect(_ context.Context, text string) ([]NEREntity, error) {
|
||||
return f.fn(text)
|
||||
}
|
||||
|
||||
// pinAfterCard mimics the real encoder's context sensitivity: "4421" is a
|
||||
// PIN only when "card" appears earlier in the same document (measured on
|
||||
// privacy-filter-multilingual: alone it detects nothing, with the eliciting
|
||||
// question it detects PIN).
|
||||
func pinAfterCard(text string) ([]NEREntity, error) {
|
||||
i := strings.Index(text, "4421")
|
||||
if i < 0 || !strings.Contains(text[:i], "card") {
|
||||
return nil, nil
|
||||
}
|
||||
return []NEREntity{{Group: "PIN", Start: i, End: i + 4, Score: 0.9}}, nil
|
||||
}
|
||||
|
||||
var _ = Describe("RedactNERSegments", func() {
|
||||
ctx := context.Background()
|
||||
maskCfg := func(d NERDetector) []NERConfig {
|
||||
return []NERConfig{{Detector: d, DefaultAction: ActionMask}}
|
||||
}
|
||||
|
||||
It("scans segments as one document so context crosses messages", func() {
|
||||
det := &funcNERDetector{fn: pinAfterCard}
|
||||
|
||||
// Scanned alone the digits are invisible...
|
||||
alone, err := RedactNER(ctx, "it is 4421 ok", maskCfg(det))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(alone.Spans).To(BeEmpty())
|
||||
|
||||
// ...as a segment after the eliciting question they are detected,
|
||||
// and the span maps back to the second segment with local offsets.
|
||||
res, err := RedactNERSegments(ctx,
|
||||
[]string{"What are the last four digits of your card?", "it is 4421 ok"},
|
||||
maskCfg(det))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res).To(HaveLen(2))
|
||||
Expect(res[0].Spans).To(BeEmpty())
|
||||
Expect(res[0].Redacted).To(Equal("What are the last four digits of your card?"))
|
||||
Expect(res[1].Spans).To(HaveLen(1))
|
||||
Expect(res[1].Spans[0].Start).To(Equal(6))
|
||||
Expect(res[1].Spans[0].End).To(Equal(10))
|
||||
Expect(res[1].Masked).To(BeTrue())
|
||||
Expect(res[1].Redacted).To(Equal("it is [REDACTED:ner:PIN] ok"))
|
||||
})
|
||||
|
||||
It("splits a hit crossing a segment boundary, masking both fragments", func() {
|
||||
det := &funcNERDetector{fn: func(text string) ([]NEREntity, error) {
|
||||
i := strings.Index(text, "22 Baker")
|
||||
j := strings.Index(text, "Street")
|
||||
if i < 0 || j < 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []NEREntity{{Group: "STREET", Start: i, End: j + len("Street"), Score: 0.9}}, nil
|
||||
}}
|
||||
res, err := RedactNERSegments(ctx, []string{"22 Baker", "Street"}, maskCfg(det))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res[0].Redacted).To(Equal("[REDACTED:ner:STREET]"))
|
||||
Expect(res[1].Redacted).To(Equal("[REDACTED:ner:STREET]"))
|
||||
})
|
||||
|
||||
It("returns best-effort results with the first detector error", func() {
|
||||
bad := NERConfig{Detector: &stubNERDetector{err: errors.New("backend down")}, DefaultAction: ActionMask}
|
||||
good := NERConfig{
|
||||
Detector: &stubNERDetector{entities: []NEREntity{{Group: "PER", Start: 0, End: 5, Score: 0.9}}},
|
||||
DefaultAction: ActionMask,
|
||||
}
|
||||
res, err := RedactNERSegments(ctx, []string{"Alice", "rest"}, []NERConfig{bad, good})
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(res[0].Spans).To(HaveLen(1), "healthy detector's hits still apply")
|
||||
})
|
||||
|
||||
It("is a per-text no-op without detectors or texts", func() {
|
||||
res, err := RedactNERSegments(ctx, []string{"a", ""}, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res).To(HaveLen(2))
|
||||
Expect(res[0].Redacted).To(Equal("a"))
|
||||
Expect(res[1].Redacted).To(Equal(""))
|
||||
|
||||
res, err = RedactNERSegments(ctx, nil, maskCfg(&stubNERDetector{}))
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(res).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("MaxPatternLength", func() {
|
||||
It("returns the longest pattern's max length", func() {
|
||||
patterns := mustCompile("email", "ssn")
|
||||
got := MaxPatternLength(patterns)
|
||||
// email is the longer of the two (254). The streaming filter
|
||||
// will use this to size its tail buffer.
|
||||
Expect(got).To(Equal(254))
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("RedactWithOverrides", func() {
|
||||
It("upgrades action", func() {
|
||||
// email is mask by default; the per-model override turns it into a
|
||||
// hard block for one request without mutating the redactor.
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
res := r.RedactWithOverrides("contact alice@example.com",
|
||||
map[string]Action{"email": ActionBlock})
|
||||
Expect(res.Blocked).To(BeTrue(), "override should have set Blocked")
|
||||
// Block leaves the value intact (the caller short-circuits the
|
||||
// request) — the redactor never echoes the matched text.
|
||||
Expect(res.Redacted).To(ContainSubstring("alice@example.com"), "block leaves text intact for the caller to discard")
|
||||
// Stored action is unchanged so a subsequent default Redact still
|
||||
// masks rather than blocks.
|
||||
res2 := r.Redact("contact alice@example.com")
|
||||
Expect(res2.Blocked).To(BeFalse(), "override must not mutate stored action")
|
||||
})
|
||||
|
||||
It("ignores unknown IDs", func() {
|
||||
// An override for a pattern this redactor doesn't know about is a
|
||||
// no-op rather than an error — per-model configs may reference
|
||||
// patterns from a wider catalogue than the active redactor holds.
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
res := r.RedactWithOverrides("contact alice@example.com",
|
||||
map[string]Action{"ssn": ActionBlock})
|
||||
Expect(res.Blocked).To(BeFalse(), "ssn override against email-only redactor must be no-op")
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("SetAction", func() {
|
||||
It("swaps in place", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
Expect(r.SetAction("email", ActionAllow)).To(Succeed())
|
||||
res := r.Redact("contact alice@example.com")
|
||||
Expect(res.Masked).To(BeFalse(), "allow leaves text intact, so nothing is masked")
|
||||
Expect(res.Redacted).To(ContainSubstring("alice@example.com"), "allow should leave the match in place")
|
||||
Expect(res.Spans).To(HaveLen(1), "allow still records the match")
|
||||
Expect(res.Blocked).To(BeFalse(), "SetAction(allow) should not block")
|
||||
})
|
||||
|
||||
It("rejects unknown id", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
Expect(r.SetAction("nonexistent", ActionMask)).NotTo(Succeed(), "expected error for unknown pattern id")
|
||||
})
|
||||
|
||||
It("rejects unknown action", func() {
|
||||
r := NewRedactor(mustCompile("email"))
|
||||
Expect(r.SetAction("email", Action("frobnicate"))).NotTo(Succeed(), "expected error for unknown action")
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -27,7 +27,10 @@ type ListQuery struct {
|
||||
UserID string
|
||||
PatternID string
|
||||
Kind EventKind
|
||||
Limit int
|
||||
// Origin scopes the search to redaction events from one surface
|
||||
// (middleware | proxy | pii_analyze | pii_redact); empty matches any.
|
||||
Origin Origin
|
||||
Limit int
|
||||
}
|
||||
|
||||
// NewMemoryEventStore returns an in-memory ring-buffer event store.
|
||||
@@ -91,6 +94,9 @@ func (s *memoryEventStore) List(_ context.Context, q ListQuery) ([]PIIEvent, err
|
||||
if q.Kind != "" && e.ResolvedKind() != q.Kind {
|
||||
return false
|
||||
}
|
||||
if q.Origin != "" && e.Origin != q.Origin {
|
||||
return false
|
||||
}
|
||||
out = append(out, e)
|
||||
return len(out) >= limit
|
||||
}
|
||||
|
||||
48
core/services/routing/pii/store_test.go
Normal file
48
core/services/routing/pii/store_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("EventStore Origin filter", func() {
|
||||
var store EventStore
|
||||
ctx := context.Background()
|
||||
|
||||
BeforeEach(func() {
|
||||
store = NewMemoryEventStore(0)
|
||||
// Three redaction events from three different surfaces.
|
||||
for _, o := range []Origin{OriginMiddleware, OriginRedactAPI, OriginAnalyzeAPI} {
|
||||
Expect(store.Record(ctx, PIIEvent{
|
||||
ID: NewEventID(),
|
||||
Kind: KindPII,
|
||||
Origin: o,
|
||||
PatternID: "ner:EMAIL",
|
||||
})).To(Succeed())
|
||||
}
|
||||
// An older row with no Origin (pre-field) must not match any origin filter.
|
||||
Expect(store.Record(ctx, PIIEvent{ID: NewEventID(), Kind: KindPII, PatternID: "ner:EMAIL"})).To(Succeed())
|
||||
})
|
||||
|
||||
It("returns only events from the requested origin", func() {
|
||||
got, err := store.List(ctx, ListQuery{Origin: OriginRedactAPI})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Origin).To(Equal(OriginRedactAPI))
|
||||
})
|
||||
|
||||
It("an empty origin matches every event (including pre-field rows)", func() {
|
||||
got, err := store.List(ctx, ListQuery{})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(HaveLen(4))
|
||||
})
|
||||
|
||||
It("does not match a pre-field (empty-origin) row against a concrete origin", func() {
|
||||
got, err := store.List(ctx, ListQuery{Origin: OriginMiddleware})
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Origin).To(Equal(OriginMiddleware))
|
||||
})
|
||||
})
|
||||
@@ -1,198 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// StreamFilter applies the regex PII tier to a streaming response,
|
||||
// chunk by chunk, with a buffered-emit invariant: for any active
|
||||
// pattern with bounded max-length L, the filter never emits the
|
||||
// trailing L-1 characters of the cumulative input until either
|
||||
//
|
||||
// (a) more text arrives that disambiguates the boundary, or
|
||||
// (b) the stream closes (Drain).
|
||||
//
|
||||
// That keeps the redactor honest across chunk splits — an email
|
||||
// arriving as "alice@" + "example.com" still masks the same way as
|
||||
// "alice@example.com" arriving in one piece.
|
||||
//
|
||||
// Action handling in stream mode differs from the request-side
|
||||
// middleware. Earlier chunks of the response are already on the wire
|
||||
// by the time later chunks are scanned, so a "block" can't actually
|
||||
// reject the request. We remap block → mask for redaction purposes
|
||||
// while still recording PIIEvent rows with action="block" so audits
|
||||
// surface the original intent ("the model would have leaked X here,
|
||||
// suppressed in flight"). allow on the output side is a no-op — the
|
||||
// text is left intact, matching its request-side detect-and-log
|
||||
// behaviour.
|
||||
//
|
||||
// StreamFilter is NOT safe for concurrent use across goroutines; one
|
||||
// instance per response stream.
|
||||
type StreamFilter struct {
|
||||
redactor *Redactor
|
||||
maskOverrides map[string]Action // block → mask map used for redaction
|
||||
auditActions map[string]Action // original action per pattern, used for events
|
||||
store EventStore
|
||||
correlationID string
|
||||
userID string
|
||||
holdLen int
|
||||
buffer strings.Builder
|
||||
emittedBytes int
|
||||
}
|
||||
|
||||
// NewStreamFilter constructs a per-response filter. modelOverrides is
|
||||
// the per-model action override map (same shape the request-side
|
||||
// middleware uses); it can be nil when the model only accepts global
|
||||
// defaults.
|
||||
//
|
||||
// store may be nil — events are then computed but not persisted, which
|
||||
// is what the chat handler does when --disable-stats is set.
|
||||
func NewStreamFilter(redactor *Redactor, modelOverrides map[string]Action, store EventStore, correlationID, userID string) *StreamFilter {
|
||||
if redactor == nil {
|
||||
return &StreamFilter{}
|
||||
}
|
||||
|
||||
patterns := redactor.Patterns()
|
||||
|
||||
// auditActions: the action we *would* have applied if this match
|
||||
// occurred on the request side. Honours the per-model override.
|
||||
auditActions := make(map[string]Action, len(patterns))
|
||||
for _, p := range patterns {
|
||||
auditActions[p.ID] = p.Action
|
||||
}
|
||||
for id, action := range modelOverrides {
|
||||
auditActions[id] = action
|
||||
}
|
||||
|
||||
// maskOverrides: the action we actually apply to the stream. Same
|
||||
// as auditActions, but with every block remapped to mask.
|
||||
maskOverrides := make(map[string]Action, len(auditActions))
|
||||
for id, action := range auditActions {
|
||||
if action == ActionBlock {
|
||||
maskOverrides[id] = ActionMask
|
||||
} else {
|
||||
maskOverrides[id] = action
|
||||
}
|
||||
}
|
||||
|
||||
return &StreamFilter{
|
||||
redactor: redactor,
|
||||
maskOverrides: maskOverrides,
|
||||
auditActions: auditActions,
|
||||
store: store,
|
||||
correlationID: correlationID,
|
||||
userID: userID,
|
||||
holdLen: redactor.MaxPatternLength() - 1,
|
||||
}
|
||||
}
|
||||
|
||||
// Push appends new text to the filter's buffer and returns the prefix
|
||||
// safe to emit downstream — the cumulative input minus a tail of
|
||||
// holdLen characters that might still be the start of a longer match.
|
||||
// Returned text has masks already applied.
|
||||
//
|
||||
// Returns an empty string when not enough text has arrived to clear
|
||||
// the hold window.
|
||||
func (sf *StreamFilter) Push(text string) string {
|
||||
if sf.redactor == nil || sf.holdLen <= 0 {
|
||||
return text
|
||||
}
|
||||
sf.buffer.WriteString(text)
|
||||
bufStr := sf.buffer.String()
|
||||
n := len(bufStr)
|
||||
|
||||
if n <= sf.holdLen {
|
||||
return ""
|
||||
}
|
||||
|
||||
emitBoundary := n - sf.holdLen
|
||||
|
||||
// Scan the entire buffer. A match whose start is before the
|
||||
// boundary but whose end runs past it crosses the window — pull
|
||||
// the boundary back to match.start so the pattern stays whole in
|
||||
// the buffer for the next Push to scan again.
|
||||
full := sf.redactor.RedactWithOverrides(bufStr, sf.maskOverrides)
|
||||
for _, span := range full.Spans {
|
||||
if span.Start < emitBoundary && span.End > emitBoundary {
|
||||
emitBoundary = span.Start
|
||||
}
|
||||
}
|
||||
|
||||
// holdLen is byte-sized but a chunk boundary may land mid-codepoint.
|
||||
// Snap back to the nearest rune start so neither the emitted prefix
|
||||
// nor the retained tail contains a split codepoint — otherwise the
|
||||
// next regex scan over an invalid-UTF-8 prefix could mis-match.
|
||||
for emitBoundary > 0 && emitBoundary < n && !utf8.RuneStart(bufStr[emitBoundary]) {
|
||||
emitBoundary--
|
||||
}
|
||||
|
||||
if emitBoundary <= 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
emitted := sf.applyAndEmit(bufStr[:emitBoundary])
|
||||
sf.buffer.Reset()
|
||||
sf.buffer.WriteString(bufStr[emitBoundary:])
|
||||
return emitted
|
||||
}
|
||||
|
||||
// Drain emits whatever's left in the buffer with all matches applied.
|
||||
// Call exactly once when the stream closes — repeat calls return the
|
||||
// empty string.
|
||||
func (sf *StreamFilter) Drain() string {
|
||||
if sf.redactor == nil {
|
||||
return sf.buffer.String()
|
||||
}
|
||||
bufStr := sf.buffer.String()
|
||||
if bufStr == "" {
|
||||
return ""
|
||||
}
|
||||
emitted := sf.applyAndEmit(bufStr)
|
||||
sf.buffer.Reset()
|
||||
return emitted
|
||||
}
|
||||
|
||||
// applyAndEmit runs the redactor over a committed-for-emit fragment,
|
||||
// substitutes mask/block placeholders inline, and records one
|
||||
// PIIEvent per matched span (with the audit action, not the masked
|
||||
// one). ByteOffset is referenced to the cumulative emitted output so
|
||||
// admins can correlate event positions against the streamed body.
|
||||
func (sf *StreamFilter) applyAndEmit(fragment string) string {
|
||||
res := sf.redactor.RedactWithOverrides(fragment, sf.maskOverrides)
|
||||
output := res.Redacted
|
||||
|
||||
if len(res.Spans) > 0 {
|
||||
now := time.Now().UTC()
|
||||
for _, span := range res.Spans {
|
||||
ev := PIIEvent{
|
||||
ID: newStreamEventID(),
|
||||
CorrelationID: sf.correlationID,
|
||||
UserID: sf.userID,
|
||||
Direction: DirectionOut,
|
||||
PatternID: span.Pattern,
|
||||
ByteOffset: sf.emittedBytes + span.Start,
|
||||
Length: span.End - span.Start,
|
||||
HashPrefix: span.HashPrefix,
|
||||
Action: sf.auditActions[span.Pattern],
|
||||
CreatedAt: now,
|
||||
}
|
||||
if sf.store != nil {
|
||||
_ = sf.store.Record(context.Background(), ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sf.emittedBytes += len(fragment)
|
||||
return output
|
||||
}
|
||||
|
||||
func newStreamEventID() string {
|
||||
var b [12]byte
|
||||
_, _ = rand.Read(b[:])
|
||||
return "pii_" + hex.EncodeToString(b[:])
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
package pii
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func newStreamRedactor(ids ...string) *Redactor {
|
||||
all := DefaultPatterns()
|
||||
chosen := all
|
||||
if len(ids) > 0 {
|
||||
chosen = pick(all, ids)
|
||||
}
|
||||
patterns, err := Compile(chosen)
|
||||
ExpectWithOffset(1, err).NotTo(HaveOccurred(), "compile")
|
||||
return NewRedactor(patterns)
|
||||
}
|
||||
|
||||
var _ = Describe("StreamFilter", func() {
|
||||
It("masks across chunks", func() {
|
||||
// The most important streaming test: an email split arbitrarily
|
||||
// across chunk boundaries must mask exactly the same way as one
|
||||
// arriving in a single Push.
|
||||
red := newStreamRedactor("email")
|
||||
sf := NewStreamFilter(red, nil, nil, "", "")
|
||||
|
||||
// "alice@example.com" (17 bytes) split between '@' and 'e'.
|
||||
out := ""
|
||||
out += sf.Push("hi alice@")
|
||||
out += sf.Push("example.com! end")
|
||||
out += sf.Drain()
|
||||
|
||||
Expect(out).NotTo(ContainSubstring("alice@example.com"), "stream leaked email across chunk boundary")
|
||||
Expect(out).To(ContainSubstring("[REDACTED:email]"))
|
||||
})
|
||||
|
||||
It("block becomes mask", func() {
|
||||
// api_key_prefix is block by default. In stream mode the earlier
|
||||
// chunks are already on the wire so block is impossible — the
|
||||
// filter remaps to mask while still recording action="block" so
|
||||
// the audit log keeps the original intent.
|
||||
red := newStreamRedactor("api_key_prefix")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
sf := NewStreamFilter(red, nil, store, "corr-1", "user-1")
|
||||
|
||||
out := sf.Push("here is your token: sk-abcdefghijklmnopqrstuvwxyz0123456789 done")
|
||||
out += sf.Drain()
|
||||
|
||||
Expect(out).NotTo(ContainSubstring("abcdefghijklmnopqrstuvwxyz0123456789"), "block-in-stream must mask, leaked the value")
|
||||
Expect(out).To(ContainSubstring("[REDACTED:api_key_prefix]"))
|
||||
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 10})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].Action).To(Equal(ActionBlock), "audit must record original block action")
|
||||
Expect(events[0].Direction).To(Equal(DirectionOut), "stream events must be DirectionOut")
|
||||
})
|
||||
|
||||
It("no match passthrough", func() {
|
||||
red := newStreamRedactor("email")
|
||||
sf := NewStreamFilter(red, nil, nil, "", "")
|
||||
out := sf.Push("perfectly clean text that should") + sf.Push(" pass through unchanged.") + sf.Drain()
|
||||
Expect(out).To(Equal("perfectly clean text that should pass through unchanged."))
|
||||
})
|
||||
|
||||
It("nil redactor passthrough", func() {
|
||||
// --disable-pii path: NewStreamFilter(nil, ...) returns a filter
|
||||
// that just forwards Push input verbatim.
|
||||
sf := NewStreamFilter(nil, nil, nil, "", "")
|
||||
out := sf.Push("any old text including alice@example.com") + sf.Drain()
|
||||
Expect(out).To(Equal("any old text including alice@example.com"))
|
||||
})
|
||||
|
||||
It("per-model overrides", func() {
|
||||
// email defaults to mask; per-model override upgrades to block.
|
||||
// In stream mode the override still maps to mask placeholder, but
|
||||
// the audit event records action="block".
|
||||
red := newStreamRedactor("email")
|
||||
store := NewMemoryEventStore(0)
|
||||
defer func() { _ = store.Close() }()
|
||||
sf := NewStreamFilter(red, map[string]Action{"email": ActionBlock}, store, "corr-2", "user-2")
|
||||
|
||||
out := sf.Push("contact alice@example.com please") + sf.Drain()
|
||||
Expect(out).NotTo(ContainSubstring("alice@example.com"), "override block-in-stream must mask")
|
||||
events, _ := store.List(context.Background(), ListQuery{Limit: 10})
|
||||
Expect(events).To(HaveLen(1))
|
||||
Expect(events[0].Action).To(Equal(ActionBlock))
|
||||
})
|
||||
|
||||
// StreamFilter_BufferedEmitInvariant feeds the redactor a corpus
|
||||
// one rune at a time, randomly chunked, and asserts:
|
||||
//
|
||||
// 1. Across all (input, splitting) pairs, the cumulative emitted
|
||||
// output never contains any of the secret values that were
|
||||
// embedded in the input.
|
||||
// 2. The output, fully drained, equals what Redact would have
|
||||
// produced on the unsplit input.
|
||||
//
|
||||
// This is the load-bearing property of streaming PII: regardless of
|
||||
// where chunks split, the emitted bytes cannot contain a value that a
|
||||
// single-shot redactor would have masked.
|
||||
It("buffered emit invariant", func() {
|
||||
corpus := []struct {
|
||||
text string
|
||||
secrets []string
|
||||
}{
|
||||
{"contact alice@example.com or bob@example.org", []string{"alice@example.com", "bob@example.org"}},
|
||||
{"my SSN is 123-45-6789 and his is 987-65-4321", []string{"123-45-6789", "987-65-4321"}},
|
||||
{"sk-abcdefghijklmnopqrstuvwxyz0123456789 leaked", []string{"sk-abcdefghijklmnopqrstuvwxyz0123456789"}},
|
||||
{"repeats: alice@example.com / alice@example.com / alice@example.com", []string{"alice@example.com"}},
|
||||
// Multibyte UTF-8 corpora pin the rune-boundary snap in
|
||||
// StreamFilter.Push: holdLen is byte-sized, so a chunk boundary
|
||||
// may land mid-codepoint. Without the snap, the retained tail
|
||||
// has a partial codepoint and the next regex scan can mis-align.
|
||||
// Each entry mixes ASCII secrets with surrounding multibyte text
|
||||
// so a byte-aligned cut would land inside a CJK or accented
|
||||
// character on at least some splits.
|
||||
{"こんにちは alice@example.com さようなら", []string{"alice@example.com"}},
|
||||
{"クレジットカード: 4111-1111-1111-1111 終わり", []string{"4111-1111-1111-1111"}},
|
||||
{"naïve résumé: alice@example.com, façade", []string{"alice@example.com"}},
|
||||
}
|
||||
|
||||
red := newStreamRedactor() // all default patterns
|
||||
rng := rand.New(rand.NewSource(1)) // seeded for reproducibility
|
||||
|
||||
for _, tc := range corpus {
|
||||
for trial := 0; trial < 10; trial++ {
|
||||
sf := NewStreamFilter(red, nil, nil, "", "")
|
||||
var out strings.Builder
|
||||
for i := 0; i < utf8.RuneCountInString(tc.text); {
|
||||
// Random chunk size 1-8 runes, never crossing the end.
|
||||
chunk := 1 + rng.Intn(8)
|
||||
if i+chunk > utf8.RuneCountInString(tc.text) {
|
||||
chunk = utf8.RuneCountInString(tc.text) - i
|
||||
}
|
||||
out.WriteString(sf.Push(stringSlice(tc.text, i, i+chunk)))
|
||||
i += chunk
|
||||
}
|
||||
out.WriteString(sf.Drain())
|
||||
result := out.String()
|
||||
|
||||
// Property 1: no secret value appears anywhere in the
|
||||
// output.
|
||||
for _, secret := range tc.secrets {
|
||||
Expect(result).NotTo(ContainSubstring(secret),
|
||||
fmt.Sprintf("trial %d: secret %q leaked through streaming\n input: %q\n output: %q", trial, secret, tc.text, result))
|
||||
}
|
||||
|
||||
// Property 2: the streamed output equals what a single-shot
|
||||
// Redact would have produced on the same input. (Block
|
||||
// patterns get masked in stream mode, so we compare against
|
||||
// a remapped redaction.)
|
||||
expected := singleShotMaskAll(red, tc.text)
|
||||
Expect(result).To(Equal(expected),
|
||||
fmt.Sprintf("trial %d: stream != single-shot\n input: %q", trial, tc.text))
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// singleShotMaskAll runs the redactor in one pass with all blocks
|
||||
// remapped to mask — the same view the StreamFilter produces.
|
||||
func singleShotMaskAll(red *Redactor, text string) string {
|
||||
patterns := red.Patterns()
|
||||
overrides := make(map[string]Action, len(patterns))
|
||||
for _, p := range patterns {
|
||||
if p.Action == ActionBlock {
|
||||
overrides[p.ID] = ActionMask
|
||||
}
|
||||
}
|
||||
res := red.RedactWithOverrides(text, overrides)
|
||||
return res.Redacted
|
||||
}
|
||||
|
||||
func stringSlice(s string, fromRune, toRune int) string {
|
||||
runes := []rune(s)
|
||||
return string(runes[fromRune:toRune])
|
||||
}
|
||||
@@ -62,10 +62,12 @@ const (
|
||||
// substring slicing; call sites that need to log it strip it via
|
||||
// HashPrefix.
|
||||
type Span struct {
|
||||
Start int
|
||||
End int
|
||||
Pattern string // matches Pattern.ID
|
||||
HashPrefix string // first 8 chars of sha256(matched value); audit-safe
|
||||
Start int
|
||||
End int
|
||||
Pattern string // synthetic detector id, "<source>:<GROUP>" (e.g. "ner:EMAIL", "pattern:ANTHROPIC_KEY")
|
||||
HashPrefix string // first 8 chars of sha256(matched value); audit-safe
|
||||
Action Action // the action that fired for this span (after merge)
|
||||
Score float32 // detector confidence for the (winning) hit, 0..1
|
||||
}
|
||||
|
||||
// Result is what Redact returns. Redacted is the input string after
|
||||
@@ -88,30 +90,6 @@ type Result struct {
|
||||
Masked bool
|
||||
}
|
||||
|
||||
// Pattern is one configurable rule. Description is shown in the admin
|
||||
// UI alongside the pattern; the regex itself stays an implementation
|
||||
// detail (a leak-prone admin showing an SSN regex with a sample value
|
||||
// in the field is a risk we deliberately design around).
|
||||
type Pattern struct {
|
||||
ID string
|
||||
Description string
|
||||
Action Action
|
||||
// Disabled skips the pattern entirely when true — useful for
|
||||
// admins who want to keep a regex around (visible in the UI) but
|
||||
// turn it off without removing the YAML entry. Default-false so
|
||||
// every existing pattern stays active without touching its config.
|
||||
Disabled bool
|
||||
// MaxMatchLength is the longest possible match in characters. The
|
||||
// streaming filter (subsystem 3, follow-up commit) uses this to
|
||||
// size its tail buffer. For regex patterns we compute it at
|
||||
// compile time from the pattern's structure when possible, or set
|
||||
// a conservative upper bound otherwise.
|
||||
MaxMatchLength int
|
||||
|
||||
// internal — populated by Compile().
|
||||
regex regexpMatcher
|
||||
}
|
||||
|
||||
// EventKind classifies a stored audit event. The store is shared by the
|
||||
// PII filter (its original use), the MITM proxy (connect decisions and
|
||||
// per-request traffic counters), and — when subsystem 2 lands — the
|
||||
@@ -135,6 +113,20 @@ const (
|
||||
KindAdmission EventKind = "admission"
|
||||
)
|
||||
|
||||
// Origin labels which surface produced a redaction event, so the events
|
||||
// log distinguishes an inline chat redaction from a MITM-proxy one and
|
||||
// from an explicit /api/pii/{analyze,redact} call. It is set on PII
|
||||
// redaction events only (Kind KindPII); connection/admission events leave
|
||||
// it empty. An empty Origin on an older row reads as "unknown".
|
||||
type Origin = string
|
||||
|
||||
const (
|
||||
OriginMiddleware Origin = "middleware" // in-band chat/completions PII middleware
|
||||
OriginProxy Origin = "proxy" // cloud-proxy MITM input path
|
||||
OriginAnalyzeAPI Origin = "pii_analyze" // POST /api/pii/analyze
|
||||
OriginRedactAPI Origin = "pii_redact" // POST /api/pii/redact
|
||||
)
|
||||
|
||||
// PIIEvent is the persisted record. The Hash field is the first 8 chars
|
||||
// of sha256(matched value) — enough to deduplicate "is this the same
|
||||
// thing as last time" without ever storing the value itself.
|
||||
@@ -146,6 +138,7 @@ const (
|
||||
type PIIEvent struct {
|
||||
ID string `json:"id"`
|
||||
Kind EventKind `json:"kind,omitempty"`
|
||||
Origin Origin `json:"origin,omitempty"`
|
||||
CorrelationID string `json:"correlation_id,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
Direction Direction `json:"direction,omitempty"`
|
||||
@@ -154,7 +147,11 @@ type PIIEvent struct {
|
||||
Length int `json:"length,omitempty"`
|
||||
HashPrefix string `json:"hash_prefix,omitempty"`
|
||||
Action Action `json:"action,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// Score is the detector confidence (0..1) for an NER PII hit. Metadata
|
||||
// only — never the matched value. Lets admins see how sure the model was
|
||||
// about a (possibly false-positive) detection without re-running it.
|
||||
Score float32 `json:"score,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
Host string `json:"host,omitempty"`
|
||||
Intercepted *bool `json:"intercepted,omitempty"`
|
||||
|
||||
119
core/services/routing/piiadapter/ollama.go
Normal file
119
core/services/routing/piiadapter/ollama.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// OllamaChat returns a pii.Adapter for *schema.OllamaChatRequest (POST
|
||||
// /api/chat). It scans each message's text content (Ollama messages carry a
|
||||
// plain string, no multimodal block form) and writes redacted text back.
|
||||
func OllamaChat() pii.Adapter {
|
||||
return pii.Adapter{
|
||||
Scan: func(parsed any) []pii.ScannedText {
|
||||
req, ok := parsed.(*schema.OllamaChatRequest)
|
||||
if !ok || req == nil {
|
||||
return nil
|
||||
}
|
||||
var out []pii.ScannedText
|
||||
for i := range req.Messages {
|
||||
if req.Messages[i].Content != "" {
|
||||
out = append(out, pii.ScannedText{Index: i, Text: req.Messages[i].Content})
|
||||
}
|
||||
}
|
||||
return out
|
||||
},
|
||||
Apply: func(parsed any, updates []pii.ScannedText) {
|
||||
req, ok := parsed.(*schema.OllamaChatRequest)
|
||||
if !ok || req == nil {
|
||||
return
|
||||
}
|
||||
for _, u := range updates {
|
||||
if u.Index >= 0 && u.Index < len(req.Messages) {
|
||||
req.Messages[u.Index].Content = u.Text
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Field selectors for OllamaGenerate (Prompt + System).
|
||||
const (
|
||||
ollamaGenPrompt = iota
|
||||
ollamaGenSystem
|
||||
)
|
||||
|
||||
// OllamaGenerate returns a pii.Adapter for *schema.OllamaGenerateRequest (POST
|
||||
// /api/generate). It scans the Prompt and System strings.
|
||||
func OllamaGenerate() pii.Adapter {
|
||||
return pii.Adapter{
|
||||
Scan: func(parsed any) []pii.ScannedText {
|
||||
req, ok := parsed.(*schema.OllamaGenerateRequest)
|
||||
if !ok || req == nil {
|
||||
return nil
|
||||
}
|
||||
var out []pii.ScannedText
|
||||
if req.Prompt != "" {
|
||||
out = append(out, pii.ScannedText{Index: ollamaGenPrompt, Text: req.Prompt})
|
||||
}
|
||||
if req.System != "" {
|
||||
out = append(out, pii.ScannedText{Index: ollamaGenSystem, Text: req.System})
|
||||
}
|
||||
return out
|
||||
},
|
||||
Apply: func(parsed any, updates []pii.ScannedText) {
|
||||
req, ok := parsed.(*schema.OllamaGenerateRequest)
|
||||
if !ok || req == nil {
|
||||
return
|
||||
}
|
||||
for _, u := range updates {
|
||||
switch u.Index {
|
||||
case ollamaGenPrompt:
|
||||
req.Prompt = u.Text
|
||||
case ollamaGenSystem:
|
||||
req.System = u.Text
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Field selectors for OllamaEmbed (Input + its Prompt alias). Reuses the
|
||||
// shared encField/decField packing.
|
||||
const (
|
||||
ollamaEmbInput = iota
|
||||
ollamaEmbPrompt
|
||||
)
|
||||
|
||||
// OllamaEmbed returns a pii.Adapter for *schema.OllamaEmbedRequest (POST
|
||||
// /api/embed, /api/embeddings). Input and its Prompt alias may be a string or
|
||||
// a []any of strings; non-string elements are skipped.
|
||||
func OllamaEmbed() pii.Adapter {
|
||||
return pii.Adapter{
|
||||
Scan: func(parsed any) []pii.ScannedText {
|
||||
req, ok := parsed.(*schema.OllamaEmbedRequest)
|
||||
if !ok || req == nil {
|
||||
return nil
|
||||
}
|
||||
var out []pii.ScannedText
|
||||
scanAnyText(ollamaEmbInput, req.Input, &out)
|
||||
scanAnyText(ollamaEmbPrompt, req.Prompt, &out)
|
||||
return out
|
||||
},
|
||||
Apply: func(parsed any, updates []pii.ScannedText) {
|
||||
req, ok := parsed.(*schema.OllamaEmbedRequest)
|
||||
if !ok || req == nil {
|
||||
return
|
||||
}
|
||||
for _, u := range updates {
|
||||
field, elem := decField(u.Index)
|
||||
switch field {
|
||||
case ollamaEmbInput:
|
||||
req.Input = applyAnyText(req.Input, elem, u.Text)
|
||||
case ollamaEmbPrompt:
|
||||
req.Prompt = applyAnyText(req.Prompt, elem, u.Text)
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
46
core/services/routing/piiadapter/ollama_test.go
Normal file
46
core/services/routing/piiadapter/ollama_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Ollama adapters", func() {
|
||||
It("OllamaChat scans and rewrites message content", func() {
|
||||
req := &schema.OllamaChatRequest{Messages: []schema.OllamaMessage{
|
||||
{Role: "user", Content: "I'm alice@example.com"},
|
||||
{Role: "assistant", Content: ""},
|
||||
}}
|
||||
a := OllamaChat()
|
||||
Expect(a.Scan(req)).To(HaveLen(1))
|
||||
applyAll(a, req, func(string) string { return "X" })
|
||||
Expect(req.Messages[0].Content).To(Equal("X"))
|
||||
Expect(req.Messages[1].Content).To(Equal(""))
|
||||
})
|
||||
|
||||
It("OllamaGenerate scans Prompt and System", func() {
|
||||
req := &schema.OllamaGenerateRequest{Prompt: "ssn 123", System: "be terse"}
|
||||
a := OllamaGenerate()
|
||||
Expect(a.Scan(req)).To(HaveLen(2))
|
||||
applyAll(a, req, func(string) string { return "Y" })
|
||||
Expect(req.Prompt).To(Equal("Y"))
|
||||
Expect(req.System).To(Equal("Y"))
|
||||
})
|
||||
|
||||
It("OllamaEmbed scans string and array Input, skipping non-strings", func() {
|
||||
a := OllamaEmbed()
|
||||
|
||||
s := &schema.OllamaEmbedRequest{Input: "secret email"}
|
||||
Expect(a.Scan(s)).To(HaveLen(1))
|
||||
applyAll(a, s, func(string) string { return "Z" })
|
||||
Expect(s.Input).To(Equal("Z"))
|
||||
|
||||
arr := &schema.OllamaEmbedRequest{Input: []any{"a secret", float64(1)}}
|
||||
Expect(a.Scan(arr)).To(HaveLen(1))
|
||||
applyAll(a, arr, func(string) string { return "Z" })
|
||||
got, _ := arr.Input.([]any)
|
||||
Expect(got).To(Equal([]any{"Z", float64(1)}))
|
||||
})
|
||||
})
|
||||
@@ -6,6 +6,8 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
@@ -74,17 +76,35 @@ func OpenAI() pii.Adapter {
|
||||
}
|
||||
msg := &req.Messages[msgIdx]
|
||||
if blockIdx < 0 {
|
||||
// Whole-string content.
|
||||
// Whole-string content. Write BOTH the serializable
|
||||
// Content and the StringContent staging buffer: the
|
||||
// rendered-template path (evaluator.TemplateMessages,
|
||||
// taken whenever use_tokenizer_template is off — e.g.
|
||||
// cloud-proxy translate and Go-templated chat models)
|
||||
// reads StringContent, not Content. Masking only Content
|
||||
// would leave the original in StringContent and leak it
|
||||
// to the backend/upstream.
|
||||
msg.Content = u.Text
|
||||
msg.StringContent = u.Text
|
||||
continue
|
||||
}
|
||||
blocks, ok := msg.Content.([]any)
|
||||
if !ok || blockIdx >= len(blocks) {
|
||||
continue
|
||||
}
|
||||
if blockMap, ok := blocks[blockIdx].(map[string]any); ok {
|
||||
blockMap["text"] = u.Text
|
||||
blockMap, ok := blocks[blockIdx].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// Keep the StringContent projection in sync. For multimodal
|
||||
// messages StringContent is the text blocks flattened with
|
||||
// media markers injected (see middleware/request.go), so we
|
||||
// can't just overwrite it — replace this block's original text
|
||||
// run in place, preserving the markers around it.
|
||||
if orig, ok := blockMap["text"].(string); ok && orig != "" && msg.StringContent != "" {
|
||||
msg.StringContent = strings.Replace(msg.StringContent, orig, u.Text, 1)
|
||||
}
|
||||
blockMap["text"] = u.Text
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
91
core/services/routing/piiadapter/openai_completion.go
Normal file
91
core/services/routing/piiadapter/openai_completion.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
)
|
||||
|
||||
// Field selectors for the prompt-style OpenAI requests (/v1/completions,
|
||||
// /v1/embeddings, /v1/edits), which carry user text in Prompt / Input /
|
||||
// Instruction rather than Messages.
|
||||
const (
|
||||
fldPrompt = iota
|
||||
fldInput
|
||||
fldInstruction
|
||||
)
|
||||
|
||||
// encField packs (field, element) into one ScannedText.Index. element=-1
|
||||
// means the field is a whole string; element>=0 indexes into a []any value.
|
||||
// Stored as element+1 so -1 maps to 0, with the field in the high bits.
|
||||
func encField(field, elem int) int { return (field << 24) | (elem + 1) }
|
||||
func decField(p int) (field, elem int) { return p >> 24, (p & 0xFFFFFF) - 1 }
|
||||
|
||||
// scanAnyText appends scannable strings from a string-or-[]any field. Non-string
|
||||
// array elements (token-id arrays, numbers) are skipped — only human text is
|
||||
// redacted.
|
||||
func scanAnyText(field int, v any, out *[]pii.ScannedText) {
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
if t != "" {
|
||||
*out = append(*out, pii.ScannedText{Index: encField(field, -1), Text: t})
|
||||
}
|
||||
case []any:
|
||||
for k, e := range t {
|
||||
if s, ok := e.(string); ok && s != "" {
|
||||
*out = append(*out, pii.ScannedText{Index: encField(field, k), Text: s})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// applyAnyText writes redacted text back to a string-or-[]any field, returning
|
||||
// the (possibly replaced) value to assign back to the struct field.
|
||||
func applyAnyText(v any, elem int, text string) any {
|
||||
if elem < 0 {
|
||||
return text
|
||||
}
|
||||
if arr, ok := v.([]any); ok && elem >= 0 && elem < len(arr) {
|
||||
arr[elem] = text
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// OpenAICompletion returns a pii.Adapter for the prompt-style OpenAI requests
|
||||
// (completions, embeddings, edits) on *schema.OpenAIRequest. It scans Prompt,
|
||||
// Input and Instruction — the string form and the string elements of an array
|
||||
// form — and writes redacted text back. Chat uses the separate OpenAI()
|
||||
// adapter (Messages); these endpoints leave Messages empty and vice versa.
|
||||
func OpenAICompletion() pii.Adapter {
|
||||
return pii.Adapter{
|
||||
Scan: func(parsed any) []pii.ScannedText {
|
||||
req, ok := parsed.(*schema.OpenAIRequest)
|
||||
if !ok || req == nil {
|
||||
return nil
|
||||
}
|
||||
var out []pii.ScannedText
|
||||
scanAnyText(fldPrompt, req.Prompt, &out)
|
||||
scanAnyText(fldInput, req.Input, &out)
|
||||
if req.Instruction != "" {
|
||||
out = append(out, pii.ScannedText{Index: encField(fldInstruction, -1), Text: req.Instruction})
|
||||
}
|
||||
return out
|
||||
},
|
||||
Apply: func(parsed any, updates []pii.ScannedText) {
|
||||
req, ok := parsed.(*schema.OpenAIRequest)
|
||||
if !ok || req == nil {
|
||||
return
|
||||
}
|
||||
for _, u := range updates {
|
||||
field, elem := decField(u.Index)
|
||||
switch field {
|
||||
case fldPrompt:
|
||||
req.Prompt = applyAnyText(req.Prompt, elem, u.Text)
|
||||
case fldInput:
|
||||
req.Input = applyAnyText(req.Input, elem, u.Text)
|
||||
case fldInstruction:
|
||||
req.Instruction = u.Text
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
59
core/services/routing/piiadapter/openai_completion_test.go
Normal file
59
core/services/routing/piiadapter/openai_completion_test.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package piiadapter
|
||||
|
||||
import (
|
||||
"github.com/mudler/LocalAI/core/schema"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
// applyAll feeds every scanned span back through Apply with the text
|
||||
// transformed by fn — the shape the middleware uses (scan, redact, apply).
|
||||
func applyAll(a pii.Adapter, parsed any, fn func(string) string) {
|
||||
scanned := a.Scan(parsed)
|
||||
updates := make([]pii.ScannedText, 0, len(scanned))
|
||||
for _, s := range scanned {
|
||||
updates = append(updates, pii.ScannedText{Index: s.Index, Text: fn(s.Text)})
|
||||
}
|
||||
a.Apply(parsed, updates)
|
||||
}
|
||||
|
||||
var _ = Describe("OpenAICompletion adapter", func() {
|
||||
a := OpenAICompletion()
|
||||
|
||||
It("scans and rewrites a string prompt", func() {
|
||||
req := &schema.OpenAIRequest{}
|
||||
req.Prompt = "contact alice@example.com"
|
||||
got := a.Scan(req)
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Text).To(Equal("contact alice@example.com"))
|
||||
applyAll(a, req, func(string) string { return "REDACTED" })
|
||||
Expect(req.Prompt).To(Equal("REDACTED"))
|
||||
})
|
||||
|
||||
It("scans array prompt elements and skips non-strings (token ids)", func() {
|
||||
req := &schema.OpenAIRequest{}
|
||||
req.Prompt = []any{"first secret", float64(42), "second secret"}
|
||||
got := a.Scan(req)
|
||||
Expect(got).To(HaveLen(2))
|
||||
applyAll(a, req, func(s string) string { return "[X]" })
|
||||
arr, _ := req.Prompt.([]any)
|
||||
Expect(arr).To(Equal([]any{"[X]", float64(42), "[X]"}))
|
||||
})
|
||||
|
||||
It("scans Input and Instruction (the edit/embeddings shape)", func() {
|
||||
req := &schema.OpenAIRequest{Instruction: "fix the SSN 123-45-6789"}
|
||||
req.Input = "my email is bob@example.com"
|
||||
got := a.Scan(req)
|
||||
Expect(got).To(HaveLen(2))
|
||||
applyAll(a, req, func(string) string { return "*" })
|
||||
Expect(req.Input).To(Equal("*"))
|
||||
Expect(req.Instruction).To(Equal("*"))
|
||||
})
|
||||
|
||||
It("returns nothing for an empty / non-matching request", func() {
|
||||
Expect(a.Scan(&schema.OpenAIRequest{})).To(BeEmpty())
|
||||
Expect(a.Scan(nil)).To(BeNil())
|
||||
})
|
||||
})
|
||||
@@ -54,6 +54,55 @@ var _ = Describe("OpenAI adapter", func() {
|
||||
Expect(req.Messages[1].Content.(string)).To(Equal("REDACTED-1"))
|
||||
})
|
||||
|
||||
It("Apply keeps StringContent in sync for string content", func() {
|
||||
// Regression: the request middleware fills StringContent from Content
|
||||
// at parse time, and the rendered-template path (TemplateMessages)
|
||||
// reads StringContent, not Content. Apply must redact both or the
|
||||
// original leaks to the backend/upstream (e.g. cloud-proxy translate).
|
||||
req := &schema.OpenAIRequest{
|
||||
Messages: []schema.Message{
|
||||
{Role: "user", Content: "my key is sk-secret", StringContent: "my key is sk-secret"},
|
||||
},
|
||||
}
|
||||
adapter := OpenAI()
|
||||
scans := adapter.Scan(req)
|
||||
Expect(scans).To(HaveLen(1))
|
||||
scans[0].Text = "my key is [REDACTED]"
|
||||
adapter.Apply(req, scans)
|
||||
|
||||
Expect(req.Messages[0].Content.(string)).To(Equal("my key is [REDACTED]"))
|
||||
Expect(req.Messages[0].StringContent).To(Equal("my key is [REDACTED]"),
|
||||
"StringContent (what TemplateMessages renders) must be redacted too")
|
||||
})
|
||||
|
||||
It("Apply keeps StringContent in sync for content blocks, preserving media markers", func() {
|
||||
// For multimodal content StringContent is the flattened text with
|
||||
// media markers injected (request.go), so Apply must redact the text
|
||||
// run in place rather than clobber the whole buffer.
|
||||
req := &schema.OpenAIRequest{
|
||||
Messages: []schema.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: []any{
|
||||
map[string]any{"type": "text", "text": "leak sk-secret here"},
|
||||
map[string]any{"type": "image_url", "image_url": map[string]any{"url": "data:image/png;base64,xyz"}},
|
||||
},
|
||||
StringContent: "leak sk-secret here<__media__>",
|
||||
},
|
||||
},
|
||||
}
|
||||
adapter := OpenAI()
|
||||
scans := adapter.Scan(req)
|
||||
Expect(scans).To(HaveLen(1))
|
||||
scans[0].Text = "leak [REDACTED] here"
|
||||
adapter.Apply(req, scans)
|
||||
|
||||
blocks := req.Messages[0].Content.([]any)
|
||||
Expect(blocks[0].(map[string]any)["text"]).To(Equal("leak [REDACTED] here"))
|
||||
Expect(req.Messages[0].StringContent).To(Equal("leak [REDACTED] here<__media__>"),
|
||||
"StringContent must be redacted in place, keeping the media marker")
|
||||
})
|
||||
|
||||
It("Apply mutates content block selectively", func() {
|
||||
req := &schema.OpenAIRequest{
|
||||
Messages: []schema.Message{
|
||||
|
||||
86
core/services/routing/piidetector/detector.go
Normal file
86
core/services/routing/piidetector/detector.go
Normal file
@@ -0,0 +1,86 @@
|
||||
// Package piidetector adapts the core/backend token-classification
|
||||
// wrapper to the PII redactor's pii.NERDetector seam. It lives outside
|
||||
// the pii package so pii stays free of core/backend imports (the
|
||||
// redactor is unit-tested with stub detectors). The dependency runs one
|
||||
// way: piidetector -> {core/backend, pii}.
|
||||
package piidetector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/mudler/xlog"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
// New builds a pii.NERDetector backed by the token-classification model
|
||||
// in modelConfig. Phase 0: the Python `transformers` backend loaded with
|
||||
// Type=TokenClassification; Phase 2: the GGML privacy-filter backend —
|
||||
// both speak the same gRPC TokenClassify contract, so this adapter is
|
||||
// unchanged across the swap. The model is resolved lazily on first
|
||||
// Detect, so building a detector for a not-yet-loaded model is cheap and
|
||||
// never blocks startup.
|
||||
func New(loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) pii.NERDetector {
|
||||
return &nerDetector{
|
||||
classifier: backend.NewTokenClassifier(loader, modelConfig, appConfig, backend.TokenClassifyOptions{}),
|
||||
modelName: modelConfig.Name,
|
||||
}
|
||||
}
|
||||
|
||||
type nerDetector struct {
|
||||
classifier backend.TokenClassifier
|
||||
modelName string
|
||||
}
|
||||
|
||||
// Detect runs the model and maps its spans onto pii.NEREntity. Offsets
|
||||
// pass through as BYTE offsets per the TokenClassify proto contract.
|
||||
// Spans whose offsets fall outside the text or land off a UTF-8 rune
|
||||
// boundary are dropped: a bad offset must never reach the redactor,
|
||||
// which splices text[Start:End] and would otherwise corrupt output or
|
||||
// panic. The redactor applies NERConfig.MinScore and the entity->action
|
||||
// map itself, so we deliberately return every (validated) span here.
|
||||
//
|
||||
// CONTRACT NOTE: the proto defines start/end as UTF-8 byte offsets. The
|
||||
// Python transformers backend converts HuggingFace's codepoint offsets to
|
||||
// bytes before responding (see TokenClassify in backend.py), and the GGML
|
||||
// privacy-filter backend will emit bytes natively. The boundary check
|
||||
// below is defense-in-depth against a backend that regresses to codepoint
|
||||
// offsets: it downgrades the bug from "corrupted redaction / panic" to
|
||||
// "dropped span + warning" rather than trusting the wire blindly.
|
||||
func (d *nerDetector) Detect(ctx context.Context, text string) ([]pii.NEREntity, error) {
|
||||
ents, err := d.classifier.TokenClassify(ctx, text)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n := len(text)
|
||||
out := make([]pii.NEREntity, 0, len(ents))
|
||||
for _, e := range ents {
|
||||
if e.Group == "" || e.Start < 0 || e.Start >= e.End || e.End > n {
|
||||
xlog.Warn("pii NER: dropping span with invalid byte range",
|
||||
"model", d.modelName, "group", e.Group, "start", e.Start, "end", e.End, "len", n)
|
||||
continue
|
||||
}
|
||||
// text[e.Start] is safe (Start < End <= n => Start < n). End is
|
||||
// exclusive: when End < n, text[End] is the first byte past the
|
||||
// span and must itself start a rune. Off-boundary offsets are the
|
||||
// signature of codepoint-vs-byte offset confusion.
|
||||
if !utf8.RuneStart(text[e.Start]) || (e.End < n && !utf8.RuneStart(text[e.End])) {
|
||||
xlog.Warn("pii NER: dropping span off UTF-8 boundary (offset units mismatch?)",
|
||||
"model", d.modelName, "group", e.Group, "start", e.Start, "end", e.End)
|
||||
continue
|
||||
}
|
||||
out = append(out, pii.NEREntity{
|
||||
Group: e.Group,
|
||||
Start: e.Start,
|
||||
End: e.End,
|
||||
Score: e.Score,
|
||||
Text: e.Text,
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
80
core/services/routing/piidetector/pattern.go
Normal file
80
core/services/routing/piidetector/pattern.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package piidetector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/mudler/LocalAI/core/backend"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piipattern"
|
||||
"github.com/mudler/LocalAI/core/trace"
|
||||
)
|
||||
|
||||
// NewPattern builds a pii.NERDetector that matches secrets with the restricted
|
||||
// regex tier (built-ins + operator-defined patterns) instead of a neural model.
|
||||
// It runs entirely in-process — no backend, GGUF, or VRAM — and the patterns
|
||||
// compile once here, so an invalid pattern is reported now (the resolver fails
|
||||
// closed) rather than per request. Matches are reported under their group with
|
||||
// a deterministic Score of 1.0.
|
||||
func NewPattern(modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (pii.NERDetector, error) {
|
||||
custom := make([]piipattern.Pattern, 0, len(modelConfig.PIIDetection.Patterns))
|
||||
for _, p := range modelConfig.PIIDetection.Patterns {
|
||||
custom = append(custom, piipattern.Pattern{Group: p.Name, Pattern: p.Match, MinLen: p.MinLen})
|
||||
}
|
||||
m, err := piipattern.NewMatcher(modelConfig.PIIDetection.Builtins, custom)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &patternDetector{matcher: m, modelName: modelConfig.Name, appConfig: appConfig}, nil
|
||||
}
|
||||
|
||||
type patternDetector struct {
|
||||
matcher *piipattern.Matcher
|
||||
modelName string
|
||||
appConfig *config.ApplicationConfig
|
||||
}
|
||||
|
||||
// Detect runs the compiled patterns and maps each match onto a pii.NEREntity.
|
||||
// When tracing is enabled it records a pattern_pii BackendTrace so the matches
|
||||
// (group, byte range, text) show in the Traces UI alongside NER detections.
|
||||
func (d *patternDetector) Detect(_ context.Context, text string) ([]pii.NEREntity, error) {
|
||||
var start time.Time
|
||||
if d.appConfig != nil && d.appConfig.EnableTracing {
|
||||
trace.InitBackendTracingIfEnabled(d.appConfig.TracingMaxItems, d.appConfig.TracingMaxBodyBytes)
|
||||
start = time.Now()
|
||||
}
|
||||
|
||||
matches := d.matcher.Find(text)
|
||||
out := make([]pii.NEREntity, 0, len(matches))
|
||||
var traceEnts []backend.TokenEntity
|
||||
for _, mt := range matches {
|
||||
out = append(out, pii.NEREntity{Group: mt.Group, Start: mt.Start, End: mt.End, Score: 1.0, Text: mt.Text})
|
||||
if d.appConfig != nil && d.appConfig.EnableTracing {
|
||||
traceEnts = append(traceEnts, backend.TokenEntity{Group: mt.Group, Start: mt.Start, End: mt.End, Score: 1.0, Text: mt.Text})
|
||||
}
|
||||
}
|
||||
|
||||
if d.appConfig != nil && d.appConfig.EnableTracing {
|
||||
trace.RecordBackendTrace(patternPIITrace(d.modelName, text, traceEnts, start))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// patternPIITrace assembles the Traces-UI row for one pattern-detector run.
|
||||
// Split out so the Data assembly is unit-testable without a request.
|
||||
func patternPIITrace(modelName, text string, entities []backend.TokenEntity, start time.Time) trace.BackendTrace {
|
||||
return trace.BackendTrace{
|
||||
Timestamp: start,
|
||||
Duration: time.Since(start),
|
||||
Type: trace.BackendTracePatternPII,
|
||||
ModelName: modelName,
|
||||
Backend: "pattern",
|
||||
Summary: trace.TruncateString(text, 200),
|
||||
Data: map[string]any{
|
||||
"input_chars": len(text),
|
||||
"matches": len(entities),
|
||||
"entities": entities,
|
||||
},
|
||||
}
|
||||
}
|
||||
61
core/services/routing/piidetector/pattern_test.go
Normal file
61
core/services/routing/piidetector/pattern_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package piidetector_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/core/services/routing/pii"
|
||||
"github.com/mudler/LocalAI/core/services/routing/piidetector"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestPiidetector(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "piidetector suite")
|
||||
}
|
||||
|
||||
func patternModel() config.ModelConfig {
|
||||
c := config.ModelConfig{Name: "secret-filter", Backend: "pattern"}
|
||||
c.PIIDetection.Builtins = []string{"anthropic_api_key"}
|
||||
c.PIIDetection.Patterns = []config.PIIPattern{{Name: "INTERNAL_TOKEN", Match: `tok-[A-Za-z0-9]{8,}`}}
|
||||
return c
|
||||
}
|
||||
|
||||
var _ = Describe("pattern detector", func() {
|
||||
It("matches built-in and custom secrets as whole-span deterministic hits", func() {
|
||||
det, err := piidetector.NewPattern(patternModel(), &config.ApplicationConfig{})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
ents, err := det.Detect(context.Background(), "use sk-ant-api03-AAAABBBBCCCCDDDDEEEE and tok-ABCD1234 ok")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
byGroup := map[string]pii.NEREntity{}
|
||||
for _, e := range ents {
|
||||
byGroup[e.Group] = e
|
||||
Expect(e.Score).To(BeEquivalentTo(float32(1.0)), "pattern matches are deterministic")
|
||||
}
|
||||
Expect(byGroup).To(HaveKey("ANTHROPIC_KEY"))
|
||||
Expect(byGroup["INTERNAL_TOKEN"].Text).To(Equal("tok-ABCD1234"))
|
||||
})
|
||||
|
||||
It("still detects (and exercises the trace path) with tracing enabled", func() {
|
||||
det, err := piidetector.NewPattern(patternModel(), &config.ApplicationConfig{
|
||||
EnableTracing: true, TracingMaxItems: 8,
|
||||
})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
ents, err := det.Detect(context.Background(), "sk-ant-api03-AAAABBBBCCCCDDDDEEEE")
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(ents).To(HaveLen(1))
|
||||
Expect(ents[0].Group).To(Equal("ANTHROPIC_KEY"))
|
||||
})
|
||||
|
||||
It("fails to build on an invalid (unanchored) custom pattern", func() {
|
||||
c := config.ModelConfig{Name: "bad", Backend: "pattern"}
|
||||
c.PIIDetection.Patterns = []config.PIIPattern{{Name: "X", Match: `.*`}}
|
||||
_, err := piidetector.NewPattern(c, &config.ApplicationConfig{})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
61
core/services/routing/piipattern/builtins.go
Normal file
61
core/services/routing/piipattern/builtins.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package piipattern
|
||||
|
||||
import "sort"
|
||||
|
||||
// Builtin is a named, ready-made secret pattern. Group is the uppercase entity
|
||||
// label a match is reported under (so it keys into a detector model's
|
||||
// pii_detection.entity_actions, exactly like an NER group). Every Builtin
|
||||
// pattern is written in the restricted subset and is verified at test time to
|
||||
// pass ValidatePattern and compile.
|
||||
type Builtin struct {
|
||||
Name string
|
||||
Group string
|
||||
Pattern string
|
||||
Description string
|
||||
}
|
||||
|
||||
// builtins is the curated catalogue. Patterns intentionally anchor on each
|
||||
// provider's fixed prefix and require a long high-entropy tail, so they fire on
|
||||
// real credentials and not on ordinary prose. Names are stable identifiers
|
||||
// referenced from a model config's pii_detection.builtins list.
|
||||
var builtins = []Builtin{
|
||||
{"anthropic_api_key", "ANTHROPIC_KEY", `sk-ant-[A-Za-z0-9_-]{20,}`, "Anthropic API key (sk-ant-…)"},
|
||||
{"openai_api_key", "OPENAI_KEY", `sk-(?:proj-)?[A-Za-z0-9_-]{20,}`, "OpenAI API key (sk-… / sk-proj-…)"},
|
||||
{"github_token", "GITHUB_TOKEN", `(?:ghp|gho|ghs|ghr|ghu)_[A-Za-z0-9]{36,}`, "GitHub access token (ghp_/gho_/ghs_/ghr_/ghu_)"},
|
||||
{"github_pat", "GITHUB_TOKEN", `github_pat_[A-Za-z0-9_]{20,}`, "GitHub fine-grained personal access token"},
|
||||
{"aws_access_key", "AWS_ACCESS_KEY", `AKIA[0-9A-Z]{16}`, "AWS access key ID (AKIA…)"},
|
||||
{"google_api_key", "GOOGLE_API_KEY", `AIza[0-9A-Za-z_-]{35}`, "Google API key (AIza…)"},
|
||||
{"slack_token", "SLACK_TOKEN", `xox[baprs]-[0-9A-Za-z-]{10,}`, "Slack token (xoxb-/xoxa-/xoxp-/xoxr-/xoxs-)"},
|
||||
{"stripe_key", "STRIPE_KEY", `(?:sk|rk)_live_[0-9A-Za-z]{16,}`, "Stripe live secret/restricted key"},
|
||||
{"jwt", "JWT", `eyJ[A-Za-z0-9_-]{10,}\.eyJ[A-Za-z0-9_-]{10,}\.[A-Za-z0-9_-]{10,}`, "JSON Web Token (eyJ….eyJ….…)"},
|
||||
{"private_key_block", "PRIVATE_KEY", `-----BEGIN [A-Z ]*PRIVATE KEY-----`, "PEM private-key header"},
|
||||
}
|
||||
|
||||
// BuiltinCatalogue returns the built-in patterns sorted by name. Used by the
|
||||
// config-metadata registry to populate the editor's builtins checklist.
|
||||
func BuiltinCatalogue() []Builtin {
|
||||
out := make([]Builtin, len(builtins))
|
||||
copy(out, builtins)
|
||||
sort.Slice(out, func(i, j int) bool { return out[i].Name < out[j].Name })
|
||||
return out
|
||||
}
|
||||
|
||||
// BuiltinNames returns the built-in pattern names, sorted.
|
||||
func BuiltinNames() []string {
|
||||
out := make([]string, 0, len(builtins))
|
||||
for _, b := range builtins {
|
||||
out = append(out, b.Name)
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// LookupBuiltin finds a built-in by name.
|
||||
func LookupBuiltin(name string) (Builtin, bool) {
|
||||
for _, b := range builtins {
|
||||
if b.Name == name {
|
||||
return b, true
|
||||
}
|
||||
}
|
||||
return Builtin{}, false
|
||||
}
|
||||
20
core/services/routing/piipattern/compile.go
Normal file
20
core/services/routing/piipattern/compile.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package piipattern
|
||||
|
||||
import "regexp"
|
||||
|
||||
// Compile validates src against the restricted grammar and, if it passes,
|
||||
// compiles it to an RE2 program set to leftmost-longest matching so a hit grabs
|
||||
// the whole secret (the entire key) rather than the shortest prefix.
|
||||
func Compile(src string) (*regexp.Regexp, error) {
|
||||
if err := ValidatePattern(src); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
re, err := regexp.Compile(src)
|
||||
if err != nil {
|
||||
// ValidatePattern already parsed with the same flags, so this is
|
||||
// effectively unreachable, but surface it rather than panic.
|
||||
return nil, err
|
||||
}
|
||||
re.Longest()
|
||||
return re, nil
|
||||
}
|
||||
163
core/services/routing/piipattern/grammar.go
Normal file
163
core/services/routing/piipattern/grammar.go
Normal file
@@ -0,0 +1,163 @@
|
||||
// Package piipattern is a bounded, restricted-regex matcher for high-entropy,
|
||||
// highly-regular secrets (API keys, tokens, private-key blocks) that the NER
|
||||
// PII tier cannot catch — it has no credential class, so it fragments a key
|
||||
// into the nearest-looking trained categories and may leave the secret part
|
||||
// exposed.
|
||||
//
|
||||
// The language is a deliberately restricted subset of regular expressions
|
||||
// compiled to Go's RE2 engine (regexp), which is linear-time with no
|
||||
// backtracking — there is no ReDoS class of failure. On top of RE2 we cap the
|
||||
// pattern source length, the {n,m} expansion bound, the pattern count, and the
|
||||
// scanned input, and we require every pattern to carry a fixed literal
|
||||
// "anchor". The anchor rule is what admits `sk-ant-…` / `ghp_…` style keys
|
||||
// while rejecting open-ended shapes like an email address or a bare `\w+`
|
||||
// (which would match almost anything) — those stay with the NER tier.
|
||||
//
|
||||
// This package is a leaf: it imports only the standard library, so both
|
||||
// core/config (validation at load) and core/application (the resolver) can use
|
||||
// it without an import cycle.
|
||||
package piipattern
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp/syntax"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxPatternLen caps the source length of a single pattern. Generous for a
|
||||
// credential shape, small enough that the compiled program stays tiny.
|
||||
MaxPatternLen = 256
|
||||
// MaxQuantifier caps an explicit {n,m} upper bound. RE2 expands a bounded
|
||||
// repeat into that many copies, so an uncapped {0,1000000} would blow up
|
||||
// the compiled program's memory. Unbounded {n,} (no upper) is a loop, not
|
||||
// an expansion, and is allowed.
|
||||
MaxQuantifier = 4096
|
||||
// MaxAlternation caps the arms of a single `a|b|c` alternation.
|
||||
MaxAlternation = 64
|
||||
// MaxAST bounds recursion depth so a pathologically nested pattern can't
|
||||
// blow the stack during validation.
|
||||
MaxAST = 64
|
||||
// MinAnchorLen is the shortest fixed literal run a pattern must contain to
|
||||
// be considered "anchored" to a recognisable secret prefix/shape.
|
||||
MinAnchorLen = 3
|
||||
)
|
||||
|
||||
// parseFlags enables Perl character classes (\w \d \s) and word boundaries,
|
||||
// matching what regexp.Compile uses, so validation and compilation agree.
|
||||
const parseFlags = syntax.Perl
|
||||
|
||||
// ValidatePattern reports whether src is an acceptable restricted-subset
|
||||
// pattern. It returns a descriptive error naming the offending construct so an
|
||||
// operator editing a model config gets actionable feedback (the error is
|
||||
// surfaced by config Validate at load and by the resolver, which fails closed).
|
||||
func ValidatePattern(src string) error {
|
||||
if src == "" {
|
||||
return fmt.Errorf("pattern is empty")
|
||||
}
|
||||
if len(src) > MaxPatternLen {
|
||||
return fmt.Errorf("pattern is too long (%d chars; max %d)", len(src), MaxPatternLen)
|
||||
}
|
||||
re, err := syntax.Parse(src, parseFlags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid pattern: %w", err)
|
||||
}
|
||||
if err := walk(re, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
if anchorLen(re) < MinAnchorLen {
|
||||
return fmt.Errorf("pattern must contain a fixed literal run of at least %d characters "+
|
||||
"(e.g. \"sk-ant-\", \"ghp_\", \"AKIA\") so it is anchored to a recognisable secret; "+
|
||||
"open-ended shapes like emails or bare \\w+ belong to the NER tier", MinAnchorLen)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// walk enforces the allow-list of regex constructs.
|
||||
func walk(re *syntax.Regexp, depth int) error {
|
||||
if depth > MaxAST {
|
||||
return fmt.Errorf("pattern is too deeply nested")
|
||||
}
|
||||
switch re.Op {
|
||||
case syntax.OpAnyChar, syntax.OpAnyCharNotNL:
|
||||
return fmt.Errorf("'.' (any character) is not allowed; use an explicit class like [A-Za-z0-9]")
|
||||
case syntax.OpCapture:
|
||||
return fmt.Errorf("capturing groups are not allowed; use a non-capturing group (?:…) if you need grouping")
|
||||
case syntax.OpRepeat:
|
||||
if re.Min > MaxQuantifier || (re.Max >= 0 && re.Max > MaxQuantifier) {
|
||||
return fmt.Errorf("{n,m} bound is too large (max %d)", MaxQuantifier)
|
||||
}
|
||||
case syntax.OpAlternate:
|
||||
if len(re.Sub) > MaxAlternation {
|
||||
return fmt.Errorf("too many alternation arms (%d; max %d)", len(re.Sub), MaxAlternation)
|
||||
}
|
||||
case syntax.OpLiteral, syntax.OpCharClass, syntax.OpConcat,
|
||||
syntax.OpStar, syntax.OpPlus, syntax.OpQuest,
|
||||
syntax.OpEmptyMatch,
|
||||
syntax.OpBeginLine, syntax.OpEndLine, syntax.OpBeginText, syntax.OpEndText,
|
||||
syntax.OpWordBoundary, syntax.OpNoWordBoundary:
|
||||
// allowed
|
||||
default:
|
||||
return fmt.Errorf("unsupported construct in pattern")
|
||||
}
|
||||
for _, sub := range re.Sub {
|
||||
if err := walk(sub, depth+1); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// anchorLen returns the number of fixed (non-space) literal characters every
|
||||
// match of re is guaranteed to contain — the pattern's "anchor strength".
|
||||
// Concatenation sums its parts; alternation takes the min (every arm must
|
||||
// carry the anchor); a `+`/{n,} with n>=1 contributes its body's literal once;
|
||||
// `*`, `?`, {0,m} and char classes/anchors contribute 0 (they may be absent).
|
||||
//
|
||||
// We sum rather than measure the longest contiguous run because RE2 factors
|
||||
// common prefixes — `(?:ghp|gho|ghs)_…` parses to `gh[ops]_…`, whose longest
|
||||
// contiguous literal is only "gh" (2) but whose guaranteed literals are
|
||||
// "gh"+"_" (3). Summing keeps such real key prefixes admissible while still
|
||||
// rejecting open-ended shapes: an email `[\w.]+@[\w.]+\.\w+` guarantees only
|
||||
// `@` and `.` (2 < MinAnchorLen).
|
||||
func anchorLen(re *syntax.Regexp) int {
|
||||
switch re.Op {
|
||||
case syntax.OpLiteral:
|
||||
n := 0
|
||||
for _, r := range re.Rune {
|
||||
if r != ' ' && r != '\t' && r != '\n' && r != '\r' {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
case syntax.OpConcat:
|
||||
sum := 0
|
||||
for _, sub := range re.Sub {
|
||||
sum += anchorLen(sub)
|
||||
}
|
||||
return sum
|
||||
case syntax.OpAlternate:
|
||||
if len(re.Sub) == 0 {
|
||||
return 0
|
||||
}
|
||||
min := -1
|
||||
for _, sub := range re.Sub {
|
||||
if a := anchorLen(sub); min < 0 || a < min {
|
||||
min = a
|
||||
}
|
||||
}
|
||||
return min
|
||||
case syntax.OpPlus:
|
||||
if len(re.Sub) == 1 {
|
||||
return anchorLen(re.Sub[0])
|
||||
}
|
||||
return 0
|
||||
case syntax.OpRepeat:
|
||||
if re.Min >= 1 && len(re.Sub) == 1 {
|
||||
return anchorLen(re.Sub[0])
|
||||
}
|
||||
return 0
|
||||
default:
|
||||
// char classes, anchors, OpStar, OpQuest carry no guaranteed literal.
|
||||
return 0
|
||||
}
|
||||
}
|
||||
100
core/services/routing/piipattern/matcher.go
Normal file
100
core/services/routing/piipattern/matcher.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package piipattern
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxPatternsPerMatcher bounds how many patterns one detector may hold.
|
||||
MaxPatternsPerMatcher = 128
|
||||
// MaxMatchesPerPattern bounds matches emitted per pattern per call, so a
|
||||
// pathological input can't produce an unbounded result set.
|
||||
MaxMatchesPerPattern = 1000
|
||||
)
|
||||
|
||||
// Pattern is one compiled-ready rule: matches are reported under Group, and a
|
||||
// match shorter than MinLen bytes is dropped (0 = no floor).
|
||||
type Pattern struct {
|
||||
Group string
|
||||
Pattern string
|
||||
MinLen int
|
||||
}
|
||||
|
||||
// Match is one detected span: a half-open byte range [Start,End) into the
|
||||
// scanned text, the matched text, and the reporting Group.
|
||||
type Match struct {
|
||||
Group string
|
||||
Start int
|
||||
End int
|
||||
Text string
|
||||
}
|
||||
|
||||
type compiled struct {
|
||||
group string
|
||||
re *regexp.Regexp
|
||||
minLen int
|
||||
}
|
||||
|
||||
// Matcher holds a set of compiled patterns and scans text for all of them.
|
||||
type Matcher struct {
|
||||
pats []compiled
|
||||
}
|
||||
|
||||
// NewMatcher compiles the named built-ins plus the custom patterns into a
|
||||
// Matcher. Unknown built-in names and patterns that fail the restricted grammar
|
||||
// are reported as errors (the caller fails closed). Built-in and custom counts
|
||||
// together may not exceed MaxPatternsPerMatcher.
|
||||
func NewMatcher(builtinNames []string, custom []Pattern) (*Matcher, error) {
|
||||
if len(builtinNames)+len(custom) > MaxPatternsPerMatcher {
|
||||
return nil, fmt.Errorf("too many patterns (%d; max %d)", len(builtinNames)+len(custom), MaxPatternsPerMatcher)
|
||||
}
|
||||
m := &Matcher{}
|
||||
for _, name := range builtinNames {
|
||||
b, ok := LookupBuiltin(name)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown built-in pattern %q", name)
|
||||
}
|
||||
re, err := Compile(b.Pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("built-in %q: %w", name, err)
|
||||
}
|
||||
m.pats = append(m.pats, compiled{group: b.Group, re: re})
|
||||
}
|
||||
for _, p := range custom {
|
||||
if p.Group == "" {
|
||||
return nil, fmt.Errorf("custom pattern is missing a name/group")
|
||||
}
|
||||
re, err := Compile(p.Pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pattern %q: %w", p.Group, err)
|
||||
}
|
||||
m.pats = append(m.pats, compiled{group: p.Group, re: re, minLen: p.MinLen})
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Find returns every match of every pattern over text. Spans from different
|
||||
// patterns may overlap; the caller (the redactor) unions and resolves them.
|
||||
func (m *Matcher) Find(text string) []Match {
|
||||
if m == nil || text == "" {
|
||||
return nil
|
||||
}
|
||||
var out []Match
|
||||
for _, p := range m.pats {
|
||||
locs := p.re.FindAllStringIndex(text, MaxMatchesPerPattern)
|
||||
for _, loc := range locs {
|
||||
start, end := loc[0], loc[1]
|
||||
if end-start < p.minLen {
|
||||
continue
|
||||
}
|
||||
out = append(out, Match{
|
||||
Group: p.group,
|
||||
Start: start,
|
||||
End: end,
|
||||
Text: text[start:end],
|
||||
})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
105
core/services/routing/piipattern/piipattern_test.go
Normal file
105
core/services/routing/piipattern/piipattern_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package piipattern
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
func TestPiipattern(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "piipattern suite")
|
||||
}
|
||||
|
||||
var _ = Describe("ValidatePattern", func() {
|
||||
DescribeTable("accepts anchored, bounded patterns",
|
||||
func(src string) { Expect(ValidatePattern(src)).To(Succeed()) },
|
||||
Entry("anthropic", `sk-ant-[A-Za-z0-9_-]{20,200}`),
|
||||
Entry("github via alternation", `(?:ghp|gho|ghs)_[A-Za-z0-9]{36,}`),
|
||||
Entry("custom token", `tok-\w{32,64}`),
|
||||
Entry("aws", `AKIA[0-9A-Z]{16}`),
|
||||
Entry("anchored by mid-literal", `(?:sk|rk)_live_[0-9A-Za-z]{16,}`),
|
||||
)
|
||||
|
||||
DescribeTable("rejects unanchored or unsafe patterns",
|
||||
func(src string) { Expect(ValidatePattern(src)).NotTo(Succeed()) },
|
||||
Entry("email (no fixed anchor)", `[\w.]+@[\w.]+\.\w+`),
|
||||
Entry("bare word run", `\w+`),
|
||||
Entry("any-char greedy", `sk-.*`),
|
||||
Entry("capturing group", `(sk-ant-[A-Za-z0-9]+)`),
|
||||
Entry("two fixed chars only", `ab[0-9]{8,}`),
|
||||
Entry("over-long source", "sk-ant-"+strings.Repeat("a", MaxPatternLen)),
|
||||
Entry("huge bounded repeat", `sk-ant-[A-Za-z0-9]{5000}`),
|
||||
Entry("empty", ``),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("Compile", func() {
|
||||
It("compiles a valid pattern with leftmost-longest semantics", func() {
|
||||
re, err := Compile(`sk-ant-[A-Za-z0-9_-]{4,}`)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
// Longest() makes the match span the whole key, not a shorter prefix.
|
||||
loc := re.FindString("key sk-ant-AAAA1111bbbb end")
|
||||
Expect(loc).To(Equal("sk-ant-AAAA1111bbbb"))
|
||||
})
|
||||
It("refuses an invalid pattern", func() {
|
||||
_, err := Compile(`.*`)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
|
||||
var _ = Describe("builtins", func() {
|
||||
It("every built-in validates, compiles, and is uniquely named", func() {
|
||||
seen := map[string]bool{}
|
||||
for _, b := range BuiltinCatalogue() {
|
||||
Expect(seen[b.Name]).To(BeFalse(), "duplicate builtin %s", b.Name)
|
||||
seen[b.Name] = true
|
||||
Expect(ValidatePattern(b.Pattern)).To(Succeed(), "builtin %s pattern %q", b.Name, b.Pattern)
|
||||
}
|
||||
})
|
||||
|
||||
DescribeTable("matches a real sample and not a decoy",
|
||||
func(name, sample, decoy string) {
|
||||
b, ok := LookupBuiltin(name)
|
||||
Expect(ok).To(BeTrue())
|
||||
re, err := Compile(b.Pattern)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(re.MatchString(sample)).To(BeTrue(), "should match %q", sample)
|
||||
Expect(re.MatchString(decoy)).To(BeFalse(), "should not match %q", decoy)
|
||||
},
|
||||
Entry("anthropic", "anthropic_api_key", "sk-ant-api03-AbCdEf012345_-AbCdEf012345", "sk-ant-short"),
|
||||
Entry("aws", "aws_access_key", "AKIAIOSFODNN7EXAMPLE", "AKIAshort"),
|
||||
Entry("github", "github_token", "ghp_"+strings.Repeat("a", 36), "ghp_short"),
|
||||
)
|
||||
})
|
||||
|
||||
var _ = Describe("Matcher", func() {
|
||||
It("reports the whole key as one span under its group", func() {
|
||||
m, err := NewMatcher([]string{"anthropic_api_key"}, nil)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
got := m.Find("my key is sk-ant-api03-AbCdEf012345AbCdEf012345 thanks")
|
||||
Expect(got).To(HaveLen(1))
|
||||
Expect(got[0].Group).To(Equal("ANTHROPIC_KEY"))
|
||||
Expect(got[0].Text).To(Equal("sk-ant-api03-AbCdEf012345AbCdEf012345"))
|
||||
})
|
||||
|
||||
It("compiles custom patterns and honours MinLen", func() {
|
||||
m, err := NewMatcher(nil, []Pattern{{Group: "INTERNAL", Pattern: `tok-[A-Za-z0-9]{4,}`, MinLen: 12}})
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
// "tok-AAAA" (8 bytes) is below MinLen 12 and is dropped.
|
||||
Expect(m.Find("tok-AAAA")).To(BeEmpty())
|
||||
Expect(m.Find("tok-AAAABBBBCCCC")).To(HaveLen(1))
|
||||
})
|
||||
|
||||
It("fails closed on an unknown built-in", func() {
|
||||
_, err := NewMatcher([]string{"nope"}, nil)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("rejects an invalid custom pattern", func() {
|
||||
_, err := NewMatcher(nil, []Pattern{{Group: "X", Pattern: `.*`}})
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
})
|
||||
@@ -34,6 +34,8 @@ const (
|
||||
BackendTraceAudioTransform BackendTraceType = "audio_transform"
|
||||
BackendTraceModelLoad BackendTraceType = "model_load"
|
||||
BackendTraceScore BackendTraceType = "score"
|
||||
BackendTraceTokenClassify BackendTraceType = "token_classify"
|
||||
BackendTracePatternPII BackendTraceType = "pattern_pii"
|
||||
BackendTraceVectorStore BackendTraceType = "vector_store"
|
||||
)
|
||||
|
||||
@@ -59,10 +61,12 @@ type BackendTrace struct {
|
||||
// runaway buffer when a caller streams MB-scale payloads.
|
||||
const MaxTraceBodyBytes = 1 << 20
|
||||
|
||||
var backendTraceBuffer *circularbuffer.Queue[*BackendTrace]
|
||||
var backendMu sync.Mutex
|
||||
var backendLogChan = make(chan *BackendTrace, 100)
|
||||
var backendInitOnce sync.Once
|
||||
var (
|
||||
backendTraceBuffer *circularbuffer.Queue[*BackendTrace]
|
||||
backendMu sync.Mutex
|
||||
backendLogChan = make(chan *BackendTrace, 100)
|
||||
backendInitOnce sync.Once
|
||||
)
|
||||
|
||||
// backendMaxBodyBytes caps each captured string value in a BackendTrace.Data
|
||||
// field to keep the /api/backend-traces JSON small enough for the admin UI to
|
||||
|
||||
Reference in New Issue
Block a user