diff --git a/backend/python/vllm/backend.py b/backend/python/vllm/backend.py index 07323c424..14d12af1e 100644 --- a/backend/python/vllm/backend.py +++ b/backend/python/vllm/backend.py @@ -5,6 +5,9 @@ import argparse import signal import sys import os +import json +import time +import gc from typing import List from PIL import Image @@ -26,6 +29,25 @@ from vllm.assets.video import VideoAsset import base64 import io +# Version-compat imports — wrap in try/except for older vLLM versions +try: + from vllm.tool_parsers import ToolParserManager + HAS_TOOL_PARSERS = True +except ImportError: + HAS_TOOL_PARSERS = False + +try: + from vllm.reasoning import ReasoningParserManager + HAS_REASONING_PARSERS = True +except ImportError: + HAS_REASONING_PARSERS = False + +try: + from vllm.sampling_params import GuidedDecodingParams + HAS_GUIDED_DECODING = True +except ImportError: + HAS_GUIDED_DECODING = False + _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 @@ -69,6 +91,35 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): break return decoded_text + def _parse_options(self, options_list): + """Parse Options[] key:value string list into a dict.""" + opts = {} + for opt in options_list: + if ":" not in opt: + continue + key, value = opt.split(":", 1) + opts[key.strip()] = value.strip() + return opts + + def _messages_to_dicts(self, messages): + """Convert proto Messages to list of dicts suitable for apply_chat_template().""" + result = [] + for msg in messages: + d = {"role": msg.role, "content": msg.content or ""} + if msg.name: + d["name"] = msg.name + if msg.tool_call_id: + d["tool_call_id"] = msg.tool_call_id + if msg.reasoning_content: + d["reasoning_content"] = msg.reasoning_content + if msg.tool_calls: + try: + d["tool_calls"] = json.loads(msg.tool_calls) + except json.JSONDecodeError: + pass + result.append(d) + return result + def Health(self, request, context): """ Returns a health check 
async def TokenizeString(self, request, context):
    """gRPC handler: tokenize ``request.Prompt`` with the loaded tokenizer.

    Returns a TokenizationResponse carrying the token ids and their count.
    Signals FAILED_PRECONDITION when no tokenizer has been loaded and
    INTERNAL when encoding itself raises.
    """
    tokenizer = getattr(self, 'tokenizer', None)
    if tokenizer is None:
        context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
        context.set_details("Model/tokenizer not loaded")
        return backend_pb2.TokenizationResponse()
    try:
        token_ids = tokenizer.encode(request.Prompt)
    except Exception as e:
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(str(e))
        return backend_pb2.TokenizationResponse()
    return backend_pb2.TokenizationResponse(length=len(token_ids), tokens=token_ids)

async def Free(self, request, context):
    """gRPC handler: release the loaded model, tokenizer and parser classes.

    Drops the strong references, forces a GC pass and, when torch with CUDA
    is available, empties the CUDA allocator cache. Always answers with a
    Result proto; failures are reported in ``message`` rather than raised.
    """
    try:
        for attr in ('llm', 'tokenizer'):
            if hasattr(self, attr):
                delattr(self, attr)
        self.tool_parser_cls = None
        self.reasoning_parser_cls = None
        gc.collect()
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            # torch is optional here; CPU-only deployments simply skip the cache flush.
            pass
        return backend_pb2.Result(success=True, message="Model freed")
    except Exception as e:
        return backend_pb2.Result(success=False, message=str(e))
"add_generation_prompt": True} + + # Pass tools for tool calling + if request.Tools: + try: + template_kwargs["tools"] = json.loads(request.Tools) + except json.JSONDecodeError: + pass + + # Enable thinking mode if requested + if request.Metadata.get("enable_thinking", "").lower() == "true": + template_kwargs["enable_thinking"] = True + + try: + prompt = self.tokenizer.apply_chat_template(messages_dicts, **template_kwargs) + except TypeError: + # Some tokenizers don't support tools/enable_thinking kwargs — retry without them + prompt = self.tokenizer.apply_chat_template( + messages_dicts, tokenize=False, add_generation_prompt=True + ) # Generate text using the LLM engine request_id = random_uuid() @@ -265,25 +396,26 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): # Stream the results generated_text = "" + last_output = None try: async for request_output in outputs: iteration_text = request_output.outputs[0].text + last_output = request_output if streaming: # Remove text already sent as vllm concatenates the text from previous yields delta_iteration_text = iteration_text.removeprefix(generated_text) # Send the partial result - yield backend_pb2.Reply(message=bytes(delta_iteration_text, encoding='utf-8')) + yield backend_pb2.Reply( + message=bytes(delta_iteration_text, encoding='utf-8'), + chat_deltas=[backend_pb2.ChatDelta(content=delta_iteration_text)], + ) # Keep track of text generated generated_text = iteration_text finally: await outputs.aclose() - # If streaming, we already sent everything - if streaming: - return - # Remove the image files from /tmp folder for img_path in image_paths: try: @@ -291,8 +423,93 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): except Exception as e: print(f"Error removing image file: {img_path}, {e}", file=sys.stderr) - # Sending the final generated text - yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8')) + # Parse reasoning and tool calls from final text using vLLM's native parsers + 
content = generated_text + reasoning_content = "" + tool_calls_proto = [] + + if self.reasoning_parser_cls: + try: + rp = self.reasoning_parser_cls(self.tokenizer) + r, c = rp.extract_reasoning(generated_text, request=None) + reasoning_content = r or "" + content = c if c is not None else generated_text + except Exception as e: + print(f"Reasoning parser error: {e}", file=sys.stderr) + + if self.tool_parser_cls and request.Tools: + try: + tools = json.loads(request.Tools) + tp = self.tool_parser_cls(self.tokenizer, tools=tools) + info = tp.extract_tool_calls(content, request=None) + if info.tools_called: + content = info.content or "" + for i, tc in enumerate(info.tool_calls): + tool_calls_proto.append(backend_pb2.ToolCallDelta( + index=i, + id=tc.id, + name=tc.function.name, + arguments=tc.function.arguments, + )) + except Exception as e: + print(f"Tool parser error: {e}", file=sys.stderr) + + # Extract token counts + prompt_tokens = 0 + completion_tokens = 0 + if last_output is not None: + try: + prompt_tokens = len(last_output.prompt_token_ids or []) + except Exception: + pass + try: + completion_tokens = len(last_output.outputs[0].token_ids or []) + except Exception: + pass + + # Extract logprobs if requested + logprobs_bytes = b"" + if last_output is not None and request.Logprobs > 0: + try: + lp = last_output.outputs[0].logprobs + if lp: + logprobs_data = {"content": []} + for token_lp_dict in lp: + if token_lp_dict: + first_tok_id, first_lp = next(iter(token_lp_dict.items())) + logprobs_data["content"].append({ + "token": getattr(first_lp, "decoded_token", str(first_tok_id)), + "logprob": first_lp.logprob, + }) + logprobs_bytes = json.dumps(logprobs_data).encode("utf-8") + except Exception as e: + print(f"Logprobs extraction error: {e}", file=sys.stderr) + + chat_delta = backend_pb2.ChatDelta( + content=content, + reasoning_content=reasoning_content, + tool_calls=tool_calls_proto, + ) + + if streaming: + # Final chunk with structured data + yield 
def test_messages_to_dicts(self):
    """_messages_to_dicts converts proto Messages (incl. tool fields) to dicts."""
    import sys, os
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from backend import BackendServicer
    servicer = BackendServicer()
    tool_calls_json = '[{"id":"call_1","type":"function","function":{"name":"foo","arguments":"{}"}}]'
    msgs = [
        backend_pb2.Message(role="user", content="hello"),
        backend_pb2.Message(
            role="assistant",
            content="",
            tool_calls=tool_calls_json,
            reasoning_content="thinking...",
        ),
        backend_pb2.Message(role="tool", content="result", name="foo", tool_call_id="call_1"),
    ]
    converted = servicer._messages_to_dicts(msgs)
    self.assertEqual(len(converted), 3)
    self.assertEqual(converted[0], {"role": "user", "content": "hello"})
    self.assertEqual(converted[1]["reasoning_content"], "thinking...")
    self.assertIsInstance(converted[1]["tool_calls"], list)
    self.assertEqual(converted[1]["tool_calls"][0]["id"], "call_1")
    self.assertEqual(converted[2]["tool_call_id"], "call_1")
    self.assertEqual(converted[2]["name"], "foo")

def test_parse_options(self):
    """_parse_options splits "key:value" strings on the first colon only."""
    import sys, os
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from backend import BackendServicer
    servicer = BackendServicer()
    parsed = servicer._parse_options([
        "tool_parser:hermes",
        "reasoning_parser:deepseek_r1",
        "invalid_no_colon",
        "key_with_colons:a:b:c",
    ])
    self.assertEqual(parsed["tool_parser"], "hermes")
    self.assertEqual(parsed["reasoning_parser"], "deepseek_r1")
    self.assertEqual(parsed["key_with_colons"], "a:b:c")
    self.assertNotIn("invalid_no_colon", parsed)

def test_tokenize_string(self):
    """TokenizeString RPC returns a token list whose length matches `length`."""
    try:
        self.setUp()
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            load_resp = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
            self.assertTrue(load_resp.success)
            tok_resp = stub.TokenizeString(backend_pb2.PredictOptions(Prompt="Hello world"))
            self.assertGreater(tok_resp.length, 0)
            self.assertEqual(len(tok_resp.tokens), tok_resp.length)
    except Exception as err:
        print(err)
        self.fail("TokenizeString service failed")
    finally:
        self.tearDown()

def test_free(self):
    """Free RPC succeeds after a model has been loaded."""
    try:
        self.setUp()
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            load_resp = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
            self.assertTrue(load_resp.success)
            free_resp = stub.Free(backend_pb2.HealthMessage())
            self.assertTrue(free_resp.success)
    except Exception as err:
        print(err)
        self.fail("Free service failed")
    finally:
        self.tearDown()