diff --git a/backend/index.yaml b/backend/index.yaml
index a1f5688a8..d0f75a4ca 100644
--- a/backend/index.yaml
+++ b/backend/index.yaml
@@ -197,6 +197,7 @@
     amd: "rocm-vllm"
     intel: "intel-vllm"
     nvidia-cuda-12: "cuda12-vllm"
+    cpu: "cpu-vllm"
 - &vllm-omni
   name: "vllm-omni"
   license: apache-2.0
@@ -1563,6 +1564,7 @@
     nvidia: "cuda12-vllm-development"
     amd: "rocm-vllm-development"
     intel: "intel-vllm-development"
+    cpu: "cpu-vllm-development"
 - !!merge <<: *vllm
   name: "cuda12-vllm"
   uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-vllm"
@@ -1578,6 +1580,11 @@
   uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-vllm"
   mirrors:
   - localai/localai-backends:latest-gpu-intel-vllm
+- !!merge <<: *vllm
+  name: "cpu-vllm"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-vllm"
+  mirrors:
+  - localai/localai-backends:latest-cpu-vllm
 - !!merge <<: *vllm
   name: "cuda12-vllm-development"
   uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-vllm"
@@ -1593,6 +1600,11 @@
   uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-vllm"
   mirrors:
   - localai/localai-backends:master-gpu-intel-vllm
+- !!merge <<: *vllm
+  name: "cpu-vllm-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-cpu-vllm"
+  mirrors:
+  - localai/localai-backends:master-cpu-vllm
 # vllm-omni
 - !!merge <<: *vllm-omni
   name: "vllm-omni-development"
diff --git a/backend/python/common/vllm_utils.py b/backend/python/common/vllm_utils.py
new file mode 100644
index 000000000..bc0518663
--- /dev/null
+++ b/backend/python/common/vllm_utils.py
@@ -0,0 +1,84 @@
+"""Shared utilities for vLLM-based backends."""
+import json
+import sys
+
+
+def parse_options(options_list):
+    """Parse Options[] list of 'key:value' strings into a dict.
+
+    Supports type inference for common cases (bool, int, float).
+    Used by LoadModel to extract backend-specific options.
+    """
+    opts = {}
+    for opt in options_list:
+        if ":" not in opt:
+            continue
+        key, value = opt.split(":", 1)
+        key = key.strip()
+        value = value.strip()
+        # Try type conversion
+        if value.lower() in ("true", "false"):
+            opts[key] = value.lower() == "true"
+        else:
+            try:
+                opts[key] = int(value)
+            except ValueError:
+                try:
+                    opts[key] = float(value)
+                except ValueError:
+                    opts[key] = value
+    return opts
+
+
+def messages_to_dicts(proto_messages):
+    """Convert proto Message objects to list of dicts for apply_chat_template().
+
+    Handles: role, content, name, tool_call_id, reasoning_content, tool_calls (JSON string -> list).
+    """
+    result = []
+    for msg in proto_messages:
+        d = {"role": msg.role, "content": msg.content or ""}
+        if msg.name:
+            d["name"] = msg.name
+        if msg.tool_call_id:
+            d["tool_call_id"] = msg.tool_call_id
+        if msg.reasoning_content:
+            d["reasoning_content"] = msg.reasoning_content
+        if msg.tool_calls:
+            try:
+                d["tool_calls"] = json.loads(msg.tool_calls)
+            except json.JSONDecodeError:
+                pass
+        result.append(d)
+    return result
+
+
+def setup_parsers(opts):
+    """Return (tool_parser_cls, reasoning_parser_cls) tuple from opts dict.
+
+    Uses vLLM's native ToolParserManager and ReasoningParserManager.
+    Returns (None, None) if vLLM is not installed or parsers not available.
+    """
+    tool_parser_cls = None
+    reasoning_parser_cls = None
+
+    tool_parser_name = opts.get("tool_parser")
+    reasoning_parser_name = opts.get("reasoning_parser")
+
+    if tool_parser_name:
+        try:
+            from vllm.tool_parsers import ToolParserManager
+            tool_parser_cls = ToolParserManager.get_tool_parser(tool_parser_name)
+            print(f"[vllm_utils] Loaded tool_parser: {tool_parser_name}", file=sys.stderr)
+        except Exception as e:
+            print(f"[vllm_utils] Failed to load tool_parser {tool_parser_name}: {e}", file=sys.stderr)
+
+    if reasoning_parser_name:
+        try:
+            from vllm.reasoning import ReasoningParserManager
+            reasoning_parser_cls = ReasoningParserManager.get_reasoning_parser(reasoning_parser_name)
+            print(f"[vllm_utils] Loaded reasoning_parser: {reasoning_parser_name}", file=sys.stderr)
+        except Exception as e:
+            print(f"[vllm_utils] Failed to load reasoning_parser {reasoning_parser_name}: {e}", file=sys.stderr)
+
+    return tool_parser_cls, reasoning_parser_cls
diff --git a/backend/python/vllm-omni/backend.py b/backend/python/vllm-omni/backend.py
index 96eb8a111..646af2a2e 100644
--- a/backend/python/vllm-omni/backend.py
+++ b/backend/python/vllm-omni/backend.py
@@ -17,6 +17,8 @@
 import time
 import os
 import base64
 import io
+import json
+import gc
 from PIL import Image
 import torch
@@ -30,6 +32,7 @@
 import grpc
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
 from grpc_auth import get_auth_interceptors
+from vllm_utils import parse_options, messages_to_dicts, setup_parsers

 from vllm_omni.entrypoints.omni import Omni
@@ -148,23 +151,20 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):


     def LoadModel(self, request, context):
         try:
+            # CPU detection: if no CUDA, default vLLM target device to CPU.
+            try:
+                if not torch.cuda.is_available():
+                    os.environ.setdefault("VLLM_TARGET_DEVICE", "cpu")
+                    os.environ.setdefault("VLLM_CPU_KVCACHE_SPACE", "4")
+            except Exception:
+                pass
+
             print(f"Loading model {request.Model}...", file=sys.stderr)
             print(f"Request {request}", file=sys.stderr)

-            # Parse options from request.Options (key:value pairs)
-            self.options = {}
-            for opt in request.Options:
-                if ":" not in opt:
-                    continue
-                key, value = opt.split(":", 1)
-                # Convert value to appropriate type
-                if is_float(value):
-                    value = float(value)
-                elif is_int(value):
-                    value = int(value)
-                elif value.lower() in ["true", "false"]:
-                    value = value.lower() == "true"
-                self.options[key] = value
+            # Parse options from request.Options using shared helper
+            self.options = parse_options(request.Options)
+            opts = self.options

             print(f"Options: {self.options}", file=sys.stderr)
@@ -244,6 +244,24 @@
                 omni_kwargs["max_model_len"] = request.MaxModelLen

             self.omni = Omni(**omni_kwargs)
+
+            # Load tokenizer for LLM/TTS so chat templates work
+            if self.model_type in ("llm", "tts"):
+                try:
+                    from vllm.transformers_utils.tokenizer import get_tokenizer
+                    self.tokenizer = get_tokenizer(
+                        request.Model,
+                        trust_remote_code=opts.get("trust_remote_code", False),
+                    )
+                except Exception as e:
+                    print(f"Failed to load tokenizer: {e}", file=sys.stderr)
+                    self.tokenizer = None
+            else:
+                self.tokenizer = None
+
+            # Setup optional tool / reasoning parsers
+            self.tool_parser_cls, self.reasoning_parser_cls = setup_parsers(opts)
+
             print("Model loaded successfully", file=sys.stderr)
             return backend_pb2.Result(message="Model loaded successfully", success=True)
@@ -466,14 +484,32 @@
             # Extract prompt
             if request.Prompt:
                 prompt = request.Prompt
-            elif request.Messages and request.UseTokenizerTemplate:
-                # Build prompt from messages (simplified - would need tokenizer for full template)
-                prompt = ""
-                for msg in request.Messages:
-                    role = msg.role
-                    content = msg.content
-                    prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n"
-                prompt += "<|im_start|>assistant\n"
+            elif request.Messages:
+                if getattr(self, "tokenizer", None) is not None:
+                    messages_dicts = messages_to_dicts(request.Messages)
+                    template_kwargs = {"tokenize": False, "add_generation_prompt": True}
+                    if request.Tools:
+                        try:
+                            template_kwargs["tools"] = json.loads(request.Tools)
+                        except json.JSONDecodeError:
+                            pass
+                    try:
+                        if request.Metadata.get("enable_thinking", "").lower() == "true":
+                            template_kwargs["enable_thinking"] = True
+                    except Exception:
+                        pass
+                    try:
+                        prompt = self.tokenizer.apply_chat_template(messages_dicts, **template_kwargs)
+                    except TypeError:
+                        prompt = self.tokenizer.apply_chat_template(
+                            messages_dicts, tokenize=False, add_generation_prompt=True
+                        )
+                else:
+                    # Fallback: basic template
+                    prompt = ""
+                    for msg in request.Messages:
+                        prompt += f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>\n"
+                    prompt += "<|im_start|>assistant\n"
             else:
                 yield backend_pb2.Reply(message=bytes("", 'utf-8'))
                 return
@@ -539,20 +575,79 @@
             # Call omni.generate() (returns generator for LLM mode)
             omni_generator = self.omni.generate([inputs], sampling_params_list)

-            # Extract text from outputs
+            # Extract text from outputs and track token usage
             generated_text = ""
+            prompt_tokens = 0
+            completion_tokens = 0
             for stage_outputs in omni_generator:
                 if stage_outputs.final_output_type == "text":
                     for output in stage_outputs.request_output:
-                        text_output = output.outputs[0].text
+                        completion = output.outputs[0]
+                        text_output = completion.text
+                        # Track tokens when available
+                        try:
+                            if getattr(output, "prompt_token_ids", None) is not None:
+                                prompt_tokens = len(output.prompt_token_ids)
+                            if getattr(completion, "token_ids", None) is not None:
+                                completion_tokens = len(completion.token_ids)
+                        except Exception:
+                            pass
                         if streaming:
                             # Remove already sent text (vllm concatenates)
                             delta_text = text_output.removeprefix(generated_text)
-                            yield backend_pb2.Reply(message=bytes(delta_text, encoding='utf-8'))
+                            yield backend_pb2.Reply(
+                                message=bytes(delta_text, encoding='utf-8'),
+                                tokens=completion_tokens,
+                                prompt_tokens=prompt_tokens,
+                            )
                         generated_text = text_output

             if not streaming:
-                yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
+                # Build optional ChatDelta with parsed reasoning / tool calls
+                chat_deltas = []
+                content_text = generated_text
+                reasoning_text = ""
+                tool_call_deltas = []
+
+                if self.reasoning_parser_cls is not None:
+                    try:
+                        parser = self.reasoning_parser_cls(self.tokenizer) if self.tokenizer else self.reasoning_parser_cls()
+                        reasoning_text, content_text = parser.extract_reasoning_content(content_text, request=None)
+                        reasoning_text = reasoning_text or ""
+                        content_text = content_text or ""
+                    except Exception as e:
+                        print(f"reasoning_parser failed: {e}", file=sys.stderr)
+
+                if self.tool_parser_cls is not None:
+                    try:
+                        parser = self.tool_parser_cls(self.tokenizer) if self.tokenizer else self.tool_parser_cls()
+                        tool_info = parser.extract_tool_calls(content_text, request=None)
+                        if getattr(tool_info, "tools_called", False):
+                            content_text = tool_info.content or ""
+                            for tc in tool_info.tool_calls or []:
+                                fn = getattr(tc, "function", None)
+                                tool_call_deltas.append(backend_pb2.ToolCallDelta(
+                                    index=getattr(tc, "index", 0) or 0,
+                                    id=getattr(tc, "id", "") or "",
+                                    name=getattr(fn, "name", "") if fn else "",
+                                    arguments=getattr(fn, "arguments", "") if fn else "",
+                                ))
+                    except Exception as e:
+                        print(f"tool_parser failed: {e}", file=sys.stderr)
+
+                if self.tool_parser_cls is not None or self.reasoning_parser_cls is not None:
+                    chat_deltas.append(backend_pb2.ChatDelta(
+                        content=content_text,
+                        reasoning_content=reasoning_text,
+                        tool_calls=tool_call_deltas,
+                    ))
+
+                yield backend_pb2.Reply(
+                    message=bytes(generated_text, encoding='utf-8'),
+                    tokens=completion_tokens,
+                    prompt_tokens=prompt_tokens,
+                    chat_deltas=chat_deltas,
+                )

         except Exception as err:
             print(f"Error in Predict: {err}", file=sys.stderr)
@@ -647,6 +742,37 @@
             traceback.print_exc()
             return backend_pb2.Result(success=False, message=f"Error generating TTS: {err}")
+
+    def TokenizeString(self, request, context):
+        if not hasattr(self, 'tokenizer') or self.tokenizer is None:
+            context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
+            context.set_details("Model/tokenizer not loaded")
+            return backend_pb2.TokenizationResponse()
+        try:
+            tokens = self.tokenizer.encode(request.Prompt)
+            return backend_pb2.TokenizationResponse(length=len(tokens), tokens=tokens)
+        except Exception as e:
+            context.set_code(grpc.StatusCode.INTERNAL)
+            context.set_details(str(e))
+            return backend_pb2.TokenizationResponse()
+
+    def Free(self, request, context):
+        try:
+            if hasattr(self, 'omni'):
+                del self.omni
+            if hasattr(self, 'tokenizer'):
+                del self.tokenizer
+            self.tool_parser_cls = None
+            self.reasoning_parser_cls = None
+            gc.collect()
+            try:
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+            except Exception:
+                pass
+            return backend_pb2.Result(success=True, message="Model freed")
+        except Exception as e:
+            return backend_pb2.Result(success=False, message=str(e))


 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
diff --git a/backend/python/vllm/requirements-after.txt b/backend/python/vllm/requirements-after.txt
index 76f11f154..b5000e6ca 100644
--- a/backend/python/vllm/requirements-after.txt
+++ b/backend/python/vllm/requirements-after.txt
@@ -1 +1,2 @@
-vllm
\ No newline at end of file
+# vllm is installed per-acceleration in requirements-{profile}-after.txt
+# (cublas12, hipblas, intel, cpu)
diff --git a/backend/python/vllm/requirements-cpu-after.txt b/backend/python/vllm/requirements-cpu-after.txt
new file mode 100644
index 000000000..20cf3d395
--- /dev/null
+++ b/backend/python/vllm/requirements-cpu-after.txt
@@ -0,0 +1 @@
+https://github.com/vllm-project/vllm/releases/download/v0.8.5/vllm-0.8.5+cpu-cp38-abi3-manylinux_2_35_x86_64.whl
diff --git a/backend/python/vllm/requirements-cpu.txt b/backend/python/vllm/requirements-cpu.txt
index 16c7cbac5..d1e882245 100644
--- a/backend/python/vllm/requirements-cpu.txt
+++ b/backend/python/vllm/requirements-cpu.txt
@@ -1,3 +1,4 @@
 accelerate
-torch==2.7.0
-transformers
\ No newline at end of file
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch==2.7.0+cpu
+transformers
diff --git a/backend/python/vllm/requirements-cublas12-after.txt b/backend/python/vllm/requirements-cublas12-after.txt
index 9251ba608..cab27c888 100644
--- a/backend/python/vllm/requirements-cublas12-after.txt
+++ b/backend/python/vllm/requirements-cublas12-after.txt
@@ -1 +1,2 @@
 https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
+vllm
diff --git a/backend/python/vllm/requirements-hipblas-after.txt b/backend/python/vllm/requirements-hipblas-after.txt
new file mode 100644
index 000000000..e7a6c7781
--- /dev/null
+++ b/backend/python/vllm/requirements-hipblas-after.txt
@@ -0,0 +1 @@
+vllm
diff --git a/backend/python/vllm/requirements-intel-after.txt b/backend/python/vllm/requirements-intel-after.txt
new file mode 100644
index 000000000..e7a6c7781
--- /dev/null
+++ b/backend/python/vllm/requirements-intel-after.txt
@@ -0,0 +1 @@
+vllm
diff --git a/backend/python/vllm/test_cpu_inference.py b/backend/python/vllm/test_cpu_inference.py
new file mode 100644
index 000000000..ff606b5bf
--- /dev/null
+++ b/backend/python/vllm/test_cpu_inference.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+"""End-to-end CPU inference smoke test for the vllm backend.
+
+Spawns the gRPC backend server, loads a small Qwen model, runs Predict,
+TokenizeString, and Free, and verifies non-empty output.
+
+Usage:
+    python test_cpu_inference.py [--model MODEL_ID] [--addr HOST:PORT]
+
+Defaults to Qwen/Qwen2.5-0.5B-Instruct (Qwen3.5-0.6B is not yet published
+on the HuggingFace hub at the time of writing).
+"""
+import argparse
+import os
+import subprocess
+import sys
+import time
+
+import grpc
+
+# Make sibling backend_pb2 importable
+HERE = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, HERE)
+
+import backend_pb2
+import backend_pb2_grpc
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", default=os.environ.get("TEST_MODEL", "Qwen/Qwen2.5-0.5B-Instruct"))
+    parser.add_argument("--addr", default="127.0.0.1:50099")
+    parser.add_argument("--prompt", default="Hello, how are you?")
+    args = parser.parse_args()
+
+    # Force CPU mode for vLLM
+    env = os.environ.copy()
+    env.setdefault("VLLM_TARGET_DEVICE", "cpu")
+    env.setdefault("VLLM_CPU_KVCACHE_SPACE", "4")
+
+    server_proc = subprocess.Popen(
+        [sys.executable, os.path.join(HERE, "backend.py"), "--addr", args.addr],
+        env=env,
+        stdout=sys.stdout,
+        stderr=sys.stderr,
+    )
+
+    try:
+        # Wait for the server to come up
+        deadline = time.time() + 30
+        channel = None
+        while time.time() < deadline:
+            try:
+                channel = grpc.insecure_channel(args.addr)
+                grpc.channel_ready_future(channel).result(timeout=2)
+                break
+            except Exception:
+                time.sleep(0.5)
+        else:  # no break: the channel never became ready before the deadline
+            raise RuntimeError("backend server did not start in time")
+
+        stub = backend_pb2_grpc.BackendStub(channel)
+
+        print(f"[test] LoadModel({args.model})", flush=True)
+        load_resp = stub.LoadModel(backend_pb2.ModelOptions(
+            Model=args.model,
+            ContextSize=2048,
+        ), timeout=900)
+        assert load_resp.success, f"LoadModel failed: {load_resp.message}"
+
+        print(f"[test] Predict prompt={args.prompt!r}", flush=True)
+        reply = stub.Predict(backend_pb2.PredictOptions(
+            Prompt=args.prompt,
+            Tokens=64,
+            Temperature=0.7,
+            TopP=0.9,
+        ), timeout=600)
+        text = reply.message.decode("utf-8")
+        print(f"[test] Predict output: {text!r}", flush=True)
+        assert text.strip(), "Predict returned empty text"
+
+        print("[test] TokenizeString", flush=True)
+        tok_resp = stub.TokenizeString(backend_pb2.PredictOptions(Prompt="hello world"), timeout=30)
+        print(f"[test] TokenizeString length={tok_resp.length}", flush=True)
+        assert tok_resp.length > 0
+
+        print("[test] Free", flush=True)
+        free_resp = stub.Free(backend_pb2.MemoryUsageData(), timeout=30)
+        assert free_resp.success, f"Free failed: {free_resp.message}"
+
+        print("[test] PASS", flush=True)
+    finally:
+        server_proc.terminate()
+        try:
+            server_proc.wait(timeout=10)
+        except subprocess.TimeoutExpired:
+            server_proc.kill()
+
+
+if __name__ == "__main__":
+    main()
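Reviewer note, for reference only (not part of the patch): a minimal sketch of how Options[] entries from a LocalAI model config would flow through the new helpers in backend/python/common/vllm_utils.py above. The option keys and values below are illustrative examples, not defaults shipped by this change.

# Illustrative only: exercises parse_options() / setup_parsers() from vllm_utils.py.
from vllm_utils import parse_options, setup_parsers

opts = parse_options([
    "tool_parser:hermes",            # no bool/int/float match -> kept as str
    "reasoning_parser:deepseek_r1",  # kept as str
    "trust_remote_code:true",        # -> True (bool)
    "max_num_seqs:16",               # -> 16 (int)
    "gpu_memory_utilization:0.85",   # -> 0.85 (float)
])
# opts == {"tool_parser": "hermes", "reasoning_parser": "deepseek_r1",
#          "trust_remote_code": True, "max_num_seqs": 16,
#          "gpu_memory_utilization": 0.85}

# Resolves the parser names against vLLM's registries; returns (None, None)
# when vLLM or the named parsers are unavailable, so model loading never hard-fails.
tool_parser_cls, reasoning_parser_cls = setup_parsers(opts)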