diff --git a/src/exo/shared/types/api.py b/src/exo/shared/types/api.py index 494f02dd..89ec6cdb 100644 --- a/src/exo/shared/types/api.py +++ b/src/exo/shared/types/api.py @@ -239,10 +239,12 @@ class ChatCompletionTaskParams(BaseModel): tool_choice: str | dict[str, Any] | None = None parallel_tool_calls: bool | None = None user: str | None = None + # Internal flag for benchmark mode - set by API, preserved through serialization + bench: bool = False class BenchChatCompletionTaskParams(ChatCompletionTaskParams): - pass + bench: bool = True class PlaceInstanceParams(BaseModel): diff --git a/src/exo/worker/engines/mlx/generator/generate.py b/src/exo/worker/engines/mlx/generator/generate.py index 53b35df4..3e6f9a7a 100644 --- a/src/exo/worker/engines/mlx/generator/generate.py +++ b/src/exo/worker/engines/mlx/generator/generate.py @@ -8,7 +8,6 @@ from mlx_lm.sample_utils import make_sampler from mlx_lm.tokenizer_utils import TokenizerWrapper from exo.shared.types.api import ( - BenchChatCompletionTaskParams, ChatCompletionMessage, FinishReason, GenerationStats, @@ -368,7 +367,7 @@ def mlx_generate( ) -> Generator[GenerationResponse]: # Ensure that generation stats only contains peak memory for this generation mx.reset_peak_memory() - is_bench: bool = isinstance(task, BenchChatCompletionTaskParams) + is_bench: bool = task.bench # Currently we support chat-completion tasks only. logger.debug(f"task_params: {task}")