feat(middleware): Model routing, PII filtering, Cloud model proxies (#9802)

Add a routing middleware stack and a cloud-proxy backend.

* cloud-proxy: a Go gRPC backend that forwards OpenAI- and
  Anthropic-shaped chat requests to upstream providers, with an
  optional translate mode (OpenAI request -> Anthropic /v1/messages
  -> OpenAI response) and full tool-calling support.

* routing: admission control, content-aware model routing
  (embedding cache + classifier + rerank + Arch-Router score),
  PII detection/redaction (regex + NER) with streaming filter and
  OpenAI/Anthropic adapters, and a per-user/per-key billing recorder
  backed by GORM or in-memory storage.

* middleware: UsageMiddleware records usage via the billing recorder,
  plus admission, route-model, usage-stamp and trace middlewares.

* observability: BackendTrace ring buffer stores full request bodies
  (capped), MITM proxy emits structured trace events, and router
  classifier decisions surface at /api/router/decide.

* gallery: Arch-Router-1.5B (Q4_K_M and Q8_0).

* UI: cloud-proxy model-editor fields, classifier system-prompt and
  score-normalization config, and a Traces page rendering request
  bodies.

Assisted-by: claude-code:claude-opus-4-7 [Read] [Edit] [Bash]

Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
Richard Palethorpe
2026-05-25 08:28:27 +01:00
committed by GitHub
parent 1dcd1ae915
commit 6a80e23733
229 changed files with 26339 additions and 1030 deletions

View File

@@ -26,7 +26,7 @@ import torch.cuda
XPU=os.environ.get("XPU", "0") == "1"
import transformers as transformers_module
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria
from transformers import AutoTokenizer, AutoModel, AutoProcessor, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, pipeline
from scipy.io import wavfile
from sentence_transformers import SentenceTransformer
@@ -200,6 +200,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
autoTokenizer = False
self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode)
self.SentenceTransformer = True
elif request.Type == "TokenClassification":
# NER / PII tagging via HuggingFace's token-classification
# pipeline. aggregation_strategy="simple" merges B-/I- tags
# into single spans and reports start/end offsets as indices
# into the input string (character positions, which differ from
# UTF-8 byte offsets for non-ASCII text). The tokenizer is
# bundled inside the pipeline, so we skip the AutoTokenizer
# load below.
autoTokenizer = False
self.tokenClassifier = pipeline(
"token-classification",
model=model_name,
aggregation_strategy="simple",
device=0 if self.CUDA else -1,
trust_remote_code=request.TrustRemoteCode,
)
self.TokenClassification = True
else:
# Generic: dynamically resolve model class from transformers
model_type = TYPE_ALIASES.get(request.Type, request.Type)
@@ -253,6 +268,39 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
return backend_pb2.Result(message="Model loaded successfully", success=True)
def TokenClassify(self, request, context):
    """Run the token-classification pipeline and return entity spans.

    Calls ``self.tokenClassifier`` on ``request.text`` and converts each
    result dict into a ``TokenClassifyEntity`` (group label, start/end
    offsets, score, matched text). Results scoring below
    ``request.threshold`` are dropped; a non-positive threshold disables
    filtering.

    Returns an empty ``TokenClassifyResponse`` with the gRPC status set to
    FAILED_PRECONDITION when the model was not loaded as
    Type=TokenClassification, or INTERNAL when the pipeline call raises.

    NOTE(review): HuggingFace documents ``start``/``end`` as indices into
    the input sentence, i.e. positions in the Python ``str``. Confirm the
    caller does not slice UTF-8 bytes with them, since the two differ for
    non-ASCII text.
    """
    if not getattr(self, "TokenClassification", False):
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
        context.set_details("model was not loaded as Type=TokenClassification")
        return backend_pb2.TokenClassifyResponse()

    try:
        results = self.tokenClassifier(request.text)
    except Exception as err:
        print("TokenClassify error:", err, file=sys.stderr)
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(f"token-classification failed: {err}")
        return backend_pb2.TokenClassifyResponse()

    # Unset (0) or negative thresholds mean "keep everything".
    threshold = request.threshold if request.threshold > 0 else 0.0

    entities = []
    for r in results:
        score = float(r.get("score", 0.0))
        if score < threshold:
            continue
        entities.append(backend_pb2.TokenClassifyEntity(
            entity_group=str(r.get("entity_group") or r.get("entity") or ""),
            # The pipeline emits the start/end keys with a None value when
            # the tokenizer cannot provide offsets (slow tokenizers), so a
            # .get() default is not enough: int(None) would raise here,
            # outside the try block above. Coerce None to 0 instead.
            start=int(r.get("start") or 0),
            end=int(r.get("end") or 0),
            score=score,
            text=str(r.get("word") or ""),
        ))
    return backend_pb2.TokenClassifyResponse(entities=entities)
def Embedding(self, request, context):
set_seed(request.Seed)
# Tokenize input

View File

@@ -356,6 +356,133 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
except Exception as e:
return backend_pb2.Result(success=False, message=str(e))
async def Score(self, request, context):
    """
    Joint log-probability of each candidate continuation given the
    shared prompt. Used by routing-policy multi-label classification
    (read the distribution rather than asking the model to emit a
    single argmax label), reranking, and reward-model scoring.

    Implementation uses vLLM's `prompt_logprobs` to recover the
    per-token log P(token_i | tokens_<i) for the full concatenated
    sequence; the candidate's tokens are the suffix whose logprobs
    get summed. max_tokens=1 because vLLM requires at least one
    generated token; the generated token is discarded.

    Returns a ScoreResponse holding one CandidateScore per entry of
    `request.candidates`, in request order. On failure an empty
    ScoreResponse is returned and the gRPC status is set:
    FAILED_PRECONDITION when the model or tokenizer is missing,
    INVALID_ARGUMENT when no candidates were supplied, INTERNAL when
    vLLM returns no prompt logprobs or any other exception is raised.

    NOTE(review): positions without a logprob entry are skipped, so
    they add nothing to `log_prob` while still being counted in
    `num_tokens` and in the length-normalisation denominator.
    """
    if not hasattr(self, 'llm') or self.llm is None:
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
        context.set_details("Model not loaded")
        return backend_pb2.ScoreResponse()
    if not hasattr(self, 'tokenizer') or self.tokenizer is None:
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
        context.set_details("Tokenizer not available")
        return backend_pb2.ScoreResponse()
    if not request.candidates:
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details("candidates must be non-empty")
        return backend_pb2.ScoreResponse()

    try:
        prompt = request.prompt or ""
        prompt_token_ids = self.tokenizer.encode(prompt)
        prompt_len = len(prompt_token_ids)

        results = []
        for candidate in request.candidates:
            # Tokenise the concatenated sequence. We can't naively
            # use len(prompt_tokens) + len(tokenizer.encode(candidate))
            # because BPE merges at the boundary may produce a
            # different tokenisation. Encoding the joined text and
            # walking to the first differing token is the correct
            # primitive.
            full_text = prompt + candidate
            full_token_ids = self.tokenizer.encode(full_text)

            divergence = prompt_len
            for i, (prompt_tok, full_tok) in enumerate(
                zip(prompt_token_ids, full_token_ids)
            ):
                if prompt_tok != full_tok:
                    divergence = i
                    break

            candidate_token_ids = full_token_ids[divergence:]
            num_candidate_tokens = len(candidate_token_ids)
            if num_candidate_tokens == 0:
                # Nothing to score (e.g. an empty candidate).
                results.append(backend_pb2.CandidateScore(
                    log_prob=0.0,
                    length_normalized_log_prob=0.0,
                    num_tokens=0,
                ))
                continue

            sampling = SamplingParams(
                max_tokens=1,
                temperature=0.0,
                prompt_logprobs=1,
                detokenize=False,
            )
            request_id = random_uuid()
            last_output = None
            outputs_iter = self.llm.generate(
                {"prompt": full_text},
                sampling_params=sampling,
                request_id=request_id,
            )
            try:
                # Only the final output is needed; drain the stream.
                async for out in outputs_iter:
                    last_output = out
            finally:
                # Best-effort cleanup of the async generator; a failure
                # to close must not mask the scoring result or error.
                try:
                    await outputs_iter.aclose()
                except Exception:
                    pass

            if last_output is None or not getattr(last_output, "prompt_logprobs", None):
                context.set_code(grpc.StatusCode.INTERNAL)
                context.set_details("vLLM did not return prompt_logprobs")
                return backend_pb2.ScoreResponse()

            prompt_logprobs = last_output.prompt_logprobs
            total = 0.0
            tokens_proto = []
            for offset, tok_id in enumerate(candidate_token_ids):
                position = divergence + offset
                if position >= len(prompt_logprobs):
                    continue
                entry = prompt_logprobs[position]
                # A None or empty entry carries no logprob to read.
                # Skip it rather than letting min() raise on an empty
                # mapping and abort the whole request.
                if not entry:
                    continue
                lp_obj = entry.get(tok_id)
                if lp_obj is not None:
                    lp = lp_obj.logprob
                else:
                    # Token not in the returned top-K. Fall back to the
                    # lowest logprob in the entry. The true log P of a
                    # token outside the top-K is at or below that value,
                    # so this is an upper bound and slightly favours the
                    # candidate.
                    lp = min(v.logprob for v in entry.values())
                total += lp
                if request.include_token_logprobs:
                    tokens_proto.append(backend_pb2.TokenLogProb(
                        token=self.tokenizer.decode([tok_id]),
                        log_prob=lp,
                    ))

            cs = backend_pb2.CandidateScore(
                log_prob=total,
                num_tokens=num_candidate_tokens,
            )
            # num_candidate_tokens is non-zero here: the empty case
            # was handled by the early `continue` above.
            if request.length_normalize:
                cs.length_normalized_log_prob = total / num_candidate_tokens
            if tokens_proto:
                cs.tokens.extend(tokens_proto)
            results.append(cs)

        return backend_pb2.ScoreResponse(candidates=results)
    except Exception as e:
        print(f"Score error: {e}", file=sys.stderr)
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(str(e))
        return backend_pb2.ScoreResponse()
async def _predict(self, request, context, streaming=False):
# Build the sampling parameters
# NOTE: this must stay in sync with the vllm backend