Mirror of https://github.com/exo-explore/exo.git (synced 2026-03-06 07:06:28 -05:00)

Compare commits: main ... leo/prepar (7 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 3239c55e40 | |
| | 725264cc33 | |
| | 401ccfbd30 | |
| | 06beffe0e2 | |
| | e9193581bc | |
| | 69628383c5 | |
| | f77a672126 | |
@@ -1,16 +1,20 @@
 import asyncio

 import pytest
-from exo_pyo3_bindings import Keypair, NetworkingHandle, NoPeersSubscribedToTopicError
+from exo_pyo3_bindings import (
+    Keypair,
+    NetworkingHandle,
+    NoPeersSubscribedToTopicError,
+    PyFromSwarm,
+)


 @pytest.mark.asyncio
 async def test_sleep_on_multiple_items() -> None:
     print("PYTHON: starting handle")
-    h = NetworkingHandle(Keypair.generate_ed25519())
+    h = NetworkingHandle(Keypair.generate())

-    ct = asyncio.create_task(_await_cons(h))
-    mt = asyncio.create_task(_await_msg(h))
+    rt = asyncio.create_task(_await_recv(h))

     # sleep for 4 ticks
     for i in range(4):
@@ -22,13 +26,11 @@ async def test_sleep_on_multiple_items() -> None:
         print("caught it", e)


-async def _await_cons(h: NetworkingHandle):
+async def _await_recv(h: NetworkingHandle):
     while True:
-        c = await h.connection_update_recv()
-        print(f"PYTHON: connection update: {c}")
-
-
-async def _await_msg(h: NetworkingHandle):
-    while True:
-        m = await h.gossipsub_recv()
-        print(f"PYTHON: message: {m}")
+        event = await h.recv()
+        match event:
+            case PyFromSwarm.Connection() as c:
+                print(f"PYTHON: connection update: {c}")
+            case PyFromSwarm.Message() as m:
+                print(f"PYTHON: message: {m}")
@@ -258,6 +258,6 @@ def get_node_id_keypair(

     # if no valid credentials, create new ones and persist
     with open(path, "w+b") as f:
-        keypair = Keypair.generate_ed25519()
+        keypair = Keypair.generate()
         f.write(keypair.to_bytes())
         return keypair
@@ -437,6 +437,7 @@ def mlx_generate(
     group: mx.distributed.Group | None,
     on_prefill_progress: Callable[[int, int], None] | None = None,
     distributed_prompt_progress_callback: Callable[[], None] | None = None,
+    on_generation_token: Callable[[], None] | None = None,
 ) -> Generator[GenerationResponse]:
     # Ensure that generation stats only contains peak memory for this generation
     mx.reset_peak_memory()
@@ -644,6 +645,9 @@ def mlx_generate(
             full_prompt_tokens, caches, cache_snapshots
         )

+        if on_generation_token is not None:
+            on_generation_token()
+
         yield GenerationResponse(
             text=text,
             token=out.token,
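The new `on_generation_token` hook gives callers a per-token entry point during decoding. A minimal sketch of the intended usage pattern, inferred from the `batch_generator` changes later in this diff rather than copied from the exo implementation: count tokens and only run an expensive check, such as polling for cancellation, every N tokens (the names `fake_generate` and `make_cancel_poller` below are illustrative).

```python
from collections.abc import Callable, Generator


def fake_generate(
    n_tokens: int, on_generation_token: Callable[[], None] | None = None
) -> Generator[int, None, None]:
    # stand-in for mlx_generate: fire the hook once per decoded token
    for tok in range(n_tokens):
        if on_generation_token is not None:
            on_generation_token()
        yield tok


def make_cancel_poller(
    check_every: int, poll_cancelled: Callable[[], bool]
) -> Callable[[], None]:
    # only run the (potentially expensive) cancellation poll every `check_every` tokens
    seen = 0

    def on_token() -> None:
        nonlocal seen
        seen += 1
        if seen >= check_every:
            seen = 0
            if poll_cancelled():
                raise RuntimeError("generation cancelled")

    return on_token


tokens = list(fake_generate(10, make_cancel_poller(4, poll_cancelled=lambda: False)))
assert len(tokens) == 10
```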
@@ -41,6 +41,7 @@ from exo.download.download_utils import build_model_path
 from exo.shared.types.common import Host
 from exo.shared.types.memory import Memory
 from exo.shared.types.mlx import Model
+from exo.shared.types.tasks import TaskId, TextGeneration
 from exo.shared.types.text_generation import TextGenerationTaskParams
 from exo.shared.types.worker.instances import (
     BoundInstance,
@@ -748,3 +749,55 @@ def _parse_kimi_tool_calls(text: str):
         return [_parse_single_tool(match) for match in tool_matches]  # pyright: ignore[reportAny]
     else:
         return [_parse_single_tool(text)]
+
+
+def mx_all_gather_tasks(
+    tasks: list[TextGeneration],
+    group: mx.distributed.Group | None,
+) -> tuple[list[TextGeneration], list[TextGeneration]]:
+    def encode_task_id(task_id: TaskId) -> list[int]:
+        utf8_task_id = task_id.encode()
+        return [
+            int.from_bytes(utf8_task_id[i : i + 1]) for i in range(len(utf8_task_id))
+        ]
+
+    def decode_task_id(encoded_task_id: list[int]) -> TaskId:
+        return TaskId(
+            bytes.decode(b"".join((x).to_bytes(length=1) for x in encoded_task_id))
+        )
+
+    uuid_byte_length = 36
+
+    n_tasks = len(tasks)
+    all_counts = cast(
+        list[int],
+        mx.distributed.all_gather(mx.array([n_tasks]), group=group).tolist(),
+    )
+    max_tasks = max(all_counts)
+    world_size: int = 1 if group is None else group.size()
+
+    if max_tasks == 0:
+        return [], []
+
+    padded = [encode_task_id(task.task_id) for task in tasks] + [
+        [0] * uuid_byte_length
+    ] * (max_tasks - n_tasks)
+    gathered = cast(
+        list[list[list[int]]],
+        mx.distributed.all_gather(mx.array(padded), group=group)
+        .reshape(world_size, max_tasks, -1)
+        .tolist(),
+    )
+    all_task_ids: list[list[TaskId]] = [
+        [decode_task_id(encoded_task_id) for encoded_task_id in rank_tasks[:count]]
+        for rank_tasks, count in zip(gathered, all_counts, strict=True)
+    ]
+
+    agreed_ids: set[TaskId] = set(all_task_ids[0])
+    for rank_tasks in all_task_ids[1:]:
+        agreed_ids &= set(rank_tasks)
+
+    local_tasks = {task.task_id: task for task in tasks}
+    agreed = [local_tasks[tid] for tid in sorted(agreed_ids)]
+    different = [task for task in tasks if task.task_id not in agreed_ids]
+    return agreed, different
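A toy illustration of what `mx_all_gather_tasks` computes, with plain Python lists standing in for `mx.distributed.all_gather` (assumption: task IDs are 36-character UUID-style strings, as implied by `uuid_byte_length = 36`; the sample ID below is made up). IDs are encoded to fixed-width byte lists so every rank can gather them, and only IDs that every rank has seen are agreed; the rest wait for a later round.

```python
def encode_task_id(task_id: str) -> list[int]:
    # one int per UTF-8 byte, so every ID becomes a fixed-width row once padded
    data = task_id.encode()
    return [int.from_bytes(data[i : i + 1]) for i in range(len(data))]


def decode_task_id(encoded: list[int]) -> str:
    return bytes.decode(b"".join(x.to_bytes(length=1) for x in encoded))


def agree(gathered_ids: list[list[str]]) -> set[str]:
    # keep only the task IDs that every rank reported
    agreed = set(gathered_ids[0])
    for rank_ids in gathered_ids[1:]:
        agreed &= set(rank_ids)
    return agreed


task_id = "0f2b7c1e-aaaa-bbbb-cccc-000000000000"  # made-up 36-char UUID-style ID
assert decode_task_id(encode_task_id(task_id)) == task_id

# rank 0 has seen both tasks, rank 1 has only seen "task-b" so far
print(agree([["task-a", "task-b"], ["task-b"]]))  # {'task-b'}; "task-a" waits for a later round
```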
@@ -297,10 +297,10 @@ def _pending_tasks(
         # the task status _should_ be set to completed by the LAST runner
         # it is currently set by the first
         # this is definitely a hack
-        if task.task_id in runner.completed:
+        if task.task_id in runner.completed or task.task_id in runner.in_progress:
             continue

-        if isinstance(runner.status, RunnerReady) and all(
+        if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
             isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
             for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
         ):
@@ -32,11 +32,19 @@ def entrypoint(
     # Import main after setting global logger - this lets us just import logger from this module
     try:
         if bound_instance.is_image_model:
-            from exo.worker.runner.image_models.runner import main
-        else:
-            from exo.worker.runner.llm_inference.runner import main
-
-        main(bound_instance, event_sender, task_receiver, cancel_receiver)
+            from exo.worker.runner.image_models.runner import Runner as ImageRunner
+
+            runner = ImageRunner(
+                bound_instance, event_sender, task_receiver, cancel_receiver
+            )
+            runner.main()
+        else:
+            from exo.worker.runner.llm_inference.runner import Runner
+
+            runner = Runner(
+                bound_instance, event_sender, task_receiver, cancel_receiver
+            )
+            runner.main()

     except ClosedResourceError:
         logger.warning("Runner communication closed unexpectedly")
@@ -182,272 +182,266 @@ def _send_image_chunk(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(
|
class Runner:
|
||||||
bound_instance: BoundInstance,
|
def __init__(
|
||||||
event_sender: MpSender[Event],
|
self,
|
||||||
task_receiver: MpReceiver[Task],
|
bound_instance: BoundInstance,
|
||||||
cancel_receiver: MpReceiver[TaskId],
|
event_sender: MpSender[Event],
|
||||||
):
|
task_receiver: MpReceiver[Task],
|
||||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
cancel_receiver: MpReceiver[TaskId],
|
||||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
):
|
||||||
|
self.event_sender = event_sender
|
||||||
|
self.task_receiver = task_receiver
|
||||||
|
self.cancel_receiver = cancel_receiver
|
||||||
|
self.bound_instance = bound_instance
|
||||||
|
|
||||||
instance, runner_id, shard_metadata = (
|
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||||
bound_instance.instance,
|
resource.setrlimit(resource.RLIMIT_NOFILE, (min(max(soft, 2048), hard), hard))
|
||||||
bound_instance.bound_runner_id,
|
|
||||||
bound_instance.bound_shard,
|
|
||||||
)
|
|
||||||
device_rank = shard_metadata.device_rank
|
|
||||||
logger.info("hello from the runner")
|
|
||||||
if getattr(shard_metadata, "immediate_exception", False):
|
|
||||||
raise Exception("Fake exception - runner failed to spin up.")
|
|
||||||
if timeout := getattr(shard_metadata, "should_timeout", 0):
|
|
||||||
time.sleep(timeout)
|
|
||||||
|
|
||||||
setup_start_time = time.time()
|
self.instance, self.runner_id, self.shard_metadata = (
|
||||||
cancelled_tasks = set[TaskId]()
|
bound_instance.instance,
|
||||||
|
bound_instance.bound_runner_id,
|
||||||
|
bound_instance.bound_shard,
|
||||||
|
)
|
||||||
|
self.device_rank = self.shard_metadata.device_rank
|
||||||
|
|
||||||
image_model: DistributedImageModel | None = None
|
logger.info("hello from the runner")
|
||||||
group = None
|
if getattr(self.shard_metadata, "immediate_exception", False):
|
||||||
|
raise Exception("Fake exception - runner failed to spin up.")
|
||||||
|
if timeout := getattr(self.shard_metadata, "should_timeout", 0):
|
||||||
|
time.sleep(timeout)
|
||||||
|
|
||||||
current_status: RunnerStatus = RunnerIdle()
|
self.setup_start_time = time.time()
|
||||||
logger.info("runner created")
|
self.cancelled_tasks = set[TaskId]()
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
self.image_model: DistributedImageModel | None = None
|
||||||
)
|
self.group = None
|
||||||
seen = set[TaskId]()
|
|
||||||
with task_receiver as tasks:
|
self.current_status: RunnerStatus = RunnerIdle()
|
||||||
for task in tasks:
|
logger.info("runner created")
|
||||||
if task.task_id in seen:
|
self.update_status(RunnerIdle())
|
||||||
logger.warning("repeat task - potential error")
|
self.seen = set[TaskId]()
|
||||||
seen.add(task.task_id)
|
|
||||||
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
def update_status(self, status: RunnerStatus):
|
||||||
event_sender.send(
|
self.current_status = status
|
||||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
self.event_sender.send(
|
||||||
|
RunnerStatusUpdated(
|
||||||
|
runner_id=self.runner_id, runner_status=self.current_status
|
||||||
)
|
)
|
||||||
match task:
|
)
|
||||||
case ConnectToGroup() if isinstance(
|
|
||||||
current_status, (RunnerIdle, RunnerFailed)
|
|
||||||
):
|
|
||||||
logger.info("runner connecting")
|
|
||||||
current_status = RunnerConnecting()
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
group = initialize_mlx(bound_instance)
|
|
||||||
|
|
||||||
logger.info("runner connected")
|
def send_task_status(self, task: Task, status: TaskStatus):
|
||||||
current_status = RunnerConnected()
|
self.event_sender.send(
|
||||||
|
TaskStatusUpdated(task_id=task.task_id, task_status=status)
|
||||||
|
)
|
||||||
|
|
||||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
def acknowledge_task(self, task: Task):
|
||||||
case LoadModel() if (
|
self.event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||||
isinstance(current_status, RunnerConnected) and group is not None
|
|
||||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
|
||||||
current_status = RunnerLoading()
|
|
||||||
logger.info("runner loading")
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
assert (
|
def main(self):
|
||||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
with self.task_receiver as tasks:
|
||||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
for task in tasks:
|
||||||
), f"Incorrect model task(s): {shard_metadata.model_card.tasks}"
|
if task.task_id in self.seen:
|
||||||
|
logger.warning("repeat task - potential error")
|
||||||
image_model = initialize_image_model(bound_instance)
|
self.seen.add(task.task_id)
|
||||||
current_status = RunnerLoaded()
|
self.cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||||
logger.info("runner loaded")
|
self.send_task_status(task, TaskStatus.Running)
|
||||||
|
self.handle_task(task)
|
||||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
was_cancelled = (task.task_id in self.cancelled_tasks) or (
|
||||||
current_status = RunnerWarmingUp()
|
TaskId("CANCEL_CURRENT_TASK") in self.cancelled_tasks
|
||||||
logger.info("runner warming up")
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
logger.info(f"warming up inference for instance: {instance}")
|
|
||||||
|
|
||||||
assert image_model
|
|
||||||
image = warmup_image_generator(model=image_model)
|
|
||||||
if image is not None:
|
|
||||||
logger.info(f"warmed up by generating {image.size} image")
|
|
||||||
else:
|
|
||||||
logger.info("warmup completed (non-primary node)")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
|
||||||
)
|
|
||||||
|
|
||||||
current_status = RunnerReady()
|
|
||||||
logger.info("runner ready")
|
|
||||||
|
|
||||||
case ImageGeneration(
|
|
||||||
task_params=task_params, command_id=command_id
|
|
||||||
) if isinstance(current_status, RunnerReady):
|
|
||||||
assert image_model
|
|
||||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
|
||||||
current_status = RunnerRunning()
|
|
||||||
logger.info("runner running")
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
try:
|
|
||||||
image_index = 0
|
|
||||||
for response in generate_image(
|
|
||||||
model=image_model, task=task_params
|
|
||||||
):
|
|
||||||
is_primary_output = _is_primary_output_node(shard_metadata)
|
|
||||||
|
|
||||||
if is_primary_output:
|
|
||||||
match response:
|
|
||||||
case PartialImageResponse():
|
|
||||||
logger.info(
|
|
||||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
|
||||||
)
|
|
||||||
_process_image_response(
|
|
||||||
response,
|
|
||||||
command_id,
|
|
||||||
shard_metadata,
|
|
||||||
event_sender,
|
|
||||||
image_index,
|
|
||||||
)
|
|
||||||
case ImageGenerationResponse():
|
|
||||||
logger.info("sending final ImageChunk")
|
|
||||||
_process_image_response(
|
|
||||||
response,
|
|
||||||
command_id,
|
|
||||||
shard_metadata,
|
|
||||||
event_sender,
|
|
||||||
image_index,
|
|
||||||
)
|
|
||||||
image_index += 1
|
|
||||||
# can we make this more explicit?
|
|
||||||
except Exception as e:
|
|
||||||
if _is_primary_output_node(shard_metadata):
|
|
||||||
event_sender.send(
|
|
||||||
ChunkGenerated(
|
|
||||||
command_id=command_id,
|
|
||||||
chunk=ErrorChunk(
|
|
||||||
model=shard_metadata.model_card.model_id,
|
|
||||||
finish_reason="error",
|
|
||||||
error_message=str(e),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
_send_traces_if_enabled(event_sender, task.task_id, device_rank)
|
|
||||||
|
|
||||||
current_status = RunnerReady()
|
|
||||||
logger.info("runner ready")
|
|
||||||
|
|
||||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
|
||||||
isinstance(current_status, RunnerReady)
|
|
||||||
):
|
|
||||||
assert image_model
|
|
||||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
|
||||||
current_status = RunnerRunning()
|
|
||||||
logger.info("runner running")
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
try:
|
|
||||||
image_index = 0
|
|
||||||
for response in generate_image(
|
|
||||||
model=image_model, task=task_params
|
|
||||||
):
|
|
||||||
if _is_primary_output_node(shard_metadata):
|
|
||||||
match response:
|
|
||||||
case PartialImageResponse():
|
|
||||||
logger.info(
|
|
||||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
|
||||||
)
|
|
||||||
_process_image_response(
|
|
||||||
response,
|
|
||||||
command_id,
|
|
||||||
shard_metadata,
|
|
||||||
event_sender,
|
|
||||||
image_index,
|
|
||||||
)
|
|
||||||
case ImageGenerationResponse():
|
|
||||||
logger.info("sending final ImageChunk")
|
|
||||||
_process_image_response(
|
|
||||||
response,
|
|
||||||
command_id,
|
|
||||||
shard_metadata,
|
|
||||||
event_sender,
|
|
||||||
image_index,
|
|
||||||
)
|
|
||||||
image_index += 1
|
|
||||||
except Exception as e:
|
|
||||||
if _is_primary_output_node(shard_metadata):
|
|
||||||
event_sender.send(
|
|
||||||
ChunkGenerated(
|
|
||||||
command_id=command_id,
|
|
||||||
chunk=ErrorChunk(
|
|
||||||
model=shard_metadata.model_card.model_id,
|
|
||||||
finish_reason="error",
|
|
||||||
error_message=str(e),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
_send_traces_if_enabled(event_sender, task.task_id, device_rank)
|
|
||||||
|
|
||||||
current_status = RunnerReady()
|
|
||||||
logger.info("runner ready")
|
|
||||||
|
|
||||||
case Shutdown():
|
|
||||||
current_status = RunnerShuttingDown()
|
|
||||||
logger.info("runner shutting down")
|
|
||||||
if not TYPE_CHECKING:
|
|
||||||
del image_model, group
|
|
||||||
mx.clear_cache()
|
|
||||||
import gc
|
|
||||||
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
event_sender.send(
|
|
||||||
RunnerStatusUpdated(
|
|
||||||
runner_id=runner_id, runner_status=current_status
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
|
||||||
|
|
||||||
current_status = RunnerShutdown()
|
|
||||||
case _:
|
|
||||||
raise ValueError(
|
|
||||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
|
||||||
)
|
|
||||||
was_cancelled = (task.task_id in cancelled_tasks) or (
|
|
||||||
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
|
|
||||||
)
|
|
||||||
if not was_cancelled:
|
|
||||||
event_sender.send(
|
|
||||||
TaskStatusUpdated(
|
|
||||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
event_sender.send(
|
if not was_cancelled:
|
||||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
self.send_task_status(task, TaskStatus.Complete)
|
||||||
)
|
self.update_status(self.current_status)
|
||||||
|
|
||||||
if isinstance(current_status, RunnerShutdown):
|
if isinstance(self.current_status, RunnerShutdown):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def handle_task(self, task: Task):
|
||||||
|
match task:
|
||||||
|
case ConnectToGroup() if isinstance(
|
||||||
|
self.current_status, (RunnerIdle, RunnerFailed)
|
||||||
|
):
|
||||||
|
logger.info("runner connecting")
|
||||||
|
self.update_status(RunnerConnecting())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
self.group = initialize_mlx(self.bound_instance)
|
||||||
|
|
||||||
|
logger.info("runner connected")
|
||||||
|
self.current_status = RunnerConnected()
|
||||||
|
|
||||||
|
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||||
|
case LoadModel() if (
|
||||||
|
isinstance(self.current_status, RunnerConnected)
|
||||||
|
and self.group is not None
|
||||||
|
) or (isinstance(self.current_status, RunnerIdle) and self.group is None):
|
||||||
|
logger.info("runner loading")
|
||||||
|
self.update_status(RunnerLoading())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
ModelTask.TextToImage in self.shard_metadata.model_card.tasks
|
||||||
|
or ModelTask.ImageToImage in self.shard_metadata.model_card.tasks
|
||||||
|
), f"Incorrect model task(s): {self.shard_metadata.model_card.tasks}"
|
||||||
|
|
||||||
|
self.image_model = initialize_image_model(self.bound_instance)
|
||||||
|
self.current_status = RunnerLoaded()
|
||||||
|
logger.info("runner loaded")
|
||||||
|
|
||||||
|
case StartWarmup() if isinstance(self.current_status, RunnerLoaded):
|
||||||
|
logger.info("runner warming up")
|
||||||
|
self.update_status(RunnerWarmingUp())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
logger.info(f"warming up inference for instance: {self.instance}")
|
||||||
|
|
||||||
|
assert self.image_model
|
||||||
|
image = warmup_image_generator(model=self.image_model)
|
||||||
|
if image is not None:
|
||||||
|
logger.info(f"warmed up by generating {image.size} image")
|
||||||
|
else:
|
||||||
|
logger.info("warmup completed (non-primary node)")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"runner initialized in {time.time() - self.setup_start_time} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.current_status = RunnerReady()
|
||||||
|
logger.info("runner ready")
|
||||||
|
|
||||||
|
case ImageGeneration(task_params=task_params, command_id=command_id) if (
|
||||||
|
isinstance(self.current_status, RunnerReady)
|
||||||
|
):
|
||||||
|
assert self.image_model
|
||||||
|
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||||
|
logger.info("runner running")
|
||||||
|
self.update_status(RunnerRunning())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_index = 0
|
||||||
|
for response in generate_image(
|
||||||
|
model=self.image_model, task=task_params
|
||||||
|
):
|
||||||
|
is_primary_output = _is_primary_output_node(self.shard_metadata)
|
||||||
|
|
||||||
|
if is_primary_output:
|
||||||
|
match response:
|
||||||
|
case PartialImageResponse():
|
||||||
|
logger.info(
|
||||||
|
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||||
|
)
|
||||||
|
_process_image_response(
|
||||||
|
response,
|
||||||
|
command_id,
|
||||||
|
self.shard_metadata,
|
||||||
|
self.event_sender,
|
||||||
|
image_index,
|
||||||
|
)
|
||||||
|
case ImageGenerationResponse():
|
||||||
|
logger.info("sending final ImageChunk")
|
||||||
|
_process_image_response(
|
||||||
|
response,
|
||||||
|
command_id,
|
||||||
|
self.shard_metadata,
|
||||||
|
self.event_sender,
|
||||||
|
image_index,
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
# can we make this more explicit?
|
||||||
|
except Exception as e:
|
||||||
|
if _is_primary_output_node(self.shard_metadata):
|
||||||
|
self.event_sender.send(
|
||||||
|
ChunkGenerated(
|
||||||
|
command_id=command_id,
|
||||||
|
chunk=ErrorChunk(
|
||||||
|
model=self.shard_metadata.model_card.model_id,
|
||||||
|
finish_reason="error",
|
||||||
|
error_message=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
_send_traces_if_enabled(
|
||||||
|
self.event_sender, task.task_id, self.device_rank
|
||||||
|
)
|
||||||
|
|
||||||
|
self.current_status = RunnerReady()
|
||||||
|
logger.info("runner ready")
|
||||||
|
|
||||||
|
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||||
|
isinstance(self.current_status, RunnerReady)
|
||||||
|
):
|
||||||
|
assert self.image_model
|
||||||
|
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||||
|
logger.info("runner running")
|
||||||
|
self.update_status(RunnerRunning())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
try:
|
||||||
|
image_index = 0
|
||||||
|
for response in generate_image(
|
||||||
|
model=self.image_model, task=task_params
|
||||||
|
):
|
||||||
|
if _is_primary_output_node(self.shard_metadata):
|
||||||
|
match response:
|
||||||
|
case PartialImageResponse():
|
||||||
|
logger.info(
|
||||||
|
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||||
|
)
|
||||||
|
_process_image_response(
|
||||||
|
response,
|
||||||
|
command_id,
|
||||||
|
self.shard_metadata,
|
||||||
|
self.event_sender,
|
||||||
|
image_index,
|
||||||
|
)
|
||||||
|
case ImageGenerationResponse():
|
||||||
|
logger.info("sending final ImageChunk")
|
||||||
|
_process_image_response(
|
||||||
|
response,
|
||||||
|
command_id,
|
||||||
|
self.shard_metadata,
|
||||||
|
self.event_sender,
|
||||||
|
image_index,
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
except Exception as e:
|
||||||
|
if _is_primary_output_node(self.shard_metadata):
|
||||||
|
self.event_sender.send(
|
||||||
|
ChunkGenerated(
|
||||||
|
command_id=command_id,
|
||||||
|
chunk=ErrorChunk(
|
||||||
|
model=self.shard_metadata.model_card.model_id,
|
||||||
|
finish_reason="error",
|
||||||
|
error_message=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
_send_traces_if_enabled(
|
||||||
|
self.event_sender, task.task_id, self.device_rank
|
||||||
|
)
|
||||||
|
|
||||||
|
self.current_status = RunnerReady()
|
||||||
|
logger.info("runner ready")
|
||||||
|
|
||||||
|
case Shutdown():
|
||||||
|
logger.info("runner shutting down")
|
||||||
|
if not TYPE_CHECKING:
|
||||||
|
del self.image_model, self.group
|
||||||
|
mx.clear_cache()
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
self.update_status(RunnerShuttingDown())
|
||||||
|
self.acknowledge_task(task)
|
||||||
|
|
||||||
|
self.current_status = RunnerShutdown()
|
||||||
|
case _:
|
||||||
|
raise ValueError(
|
||||||
|
f"Received {task.__class__.__name__} outside of state machine in {self.current_status=}"
|
||||||
|
)
|
||||||
|
|||||||
src/exo/worker/runner/llm_inference/batch_generator.py (new file, 293 lines)
@@ -0,0 +1,293 @@
|
|||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import deque
|
||||||
|
from collections.abc import Generator, Iterable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||||
|
|
||||||
|
from exo.shared.types.chunks import ErrorChunk, PrefillProgressChunk
|
||||||
|
from exo.shared.types.common import ModelId
|
||||||
|
from exo.shared.types.events import ChunkGenerated, Event
|
||||||
|
from exo.shared.types.mlx import Model
|
||||||
|
from exo.shared.types.tasks import TaskId, TextGeneration
|
||||||
|
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||||
|
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||||
|
from exo.utils.channels import MpReceiver, MpSender
|
||||||
|
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||||
|
from exo.worker.engines.mlx.generator.generate import (
|
||||||
|
PrefillCancelled,
|
||||||
|
mlx_generate,
|
||||||
|
warmup_inference,
|
||||||
|
)
|
||||||
|
from exo.worker.engines.mlx.utils_mlx import (
|
||||||
|
apply_chat_template,
|
||||||
|
mx_all_gather_tasks,
|
||||||
|
mx_any,
|
||||||
|
)
|
||||||
|
from exo.worker.runner.bootstrap import logger
|
||||||
|
|
||||||
|
from .model_output_parsers import apply_all_parsers
|
||||||
|
from .tool_parsers import ToolParser
|
||||||
|
|
||||||
|
|
||||||
|
class Cancelled:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Finished:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorQueue[T]:
|
||||||
|
def __init__(self):
|
||||||
|
self._q = deque[T]()
|
||||||
|
|
||||||
|
def push(self, t: T):
|
||||||
|
self._q.append(t)
|
||||||
|
|
||||||
|
def gen(self) -> Generator[T | None]:
|
||||||
|
while True:
|
||||||
|
if len(self._q) == 0:
|
||||||
|
yield None
|
||||||
|
else:
|
||||||
|
yield self._q.popleft()
|
||||||
|
|
||||||
|
|
||||||
|
class InferenceGenerator(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def warmup(self) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def submit(
|
||||||
|
self,
|
||||||
|
task: TextGeneration,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def step(
|
||||||
|
self,
|
||||||
|
) -> Iterable[
|
||||||
|
tuple[TaskId, ToolCallResponse | GenerationResponse | Cancelled | Finished]
|
||||||
|
]: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def close(self) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||||
|
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||||
|
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||||
|
|
||||||
|
|
||||||
|
def _check_for_debug_prompts(task_params: TextGenerationTaskParams) -> None:
|
||||||
|
"""Check for debug prompt triggers in the input."""
|
||||||
|
from exo.worker.engines.mlx.utils_mlx import mlx_force_oom
|
||||||
|
|
||||||
|
if len(task_params.input) == 0:
|
||||||
|
return
|
||||||
|
prompt = task_params.input[0].content
|
||||||
|
if not prompt:
|
||||||
|
return
|
||||||
|
if EXO_RUNNER_MUST_FAIL in prompt:
|
||||||
|
raise Exception("Artificial runner exception - for testing purposes only.")
|
||||||
|
if EXO_RUNNER_MUST_OOM in prompt:
|
||||||
|
mlx_force_oom()
|
||||||
|
if EXO_RUNNER_MUST_TIMEOUT in prompt:
|
||||||
|
time.sleep(100)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class SequentialGenerator(InferenceGenerator):
|
||||||
|
model: Model
|
||||||
|
tokenizer: TokenizerWrapper
|
||||||
|
group: mx.distributed.Group | None
|
||||||
|
kv_prefix_cache: KVPrefixCache | None
|
||||||
|
tool_parser: ToolParser | None
|
||||||
|
model_id: ModelId
|
||||||
|
device_rank: int
|
||||||
|
cancel_receiver: MpReceiver[TaskId]
|
||||||
|
event_sender: MpSender[Event]
|
||||||
|
check_for_cancel_every: int = 50
|
||||||
|
|
||||||
|
_cancelled_tasks: set[TaskId] = field(default_factory=set, init=False)
|
||||||
|
_maybe_queue: list[TextGeneration] = field(default_factory=list, init=False)
|
||||||
|
_queue: deque[TextGeneration] = field(default_factory=deque, init=False)
|
||||||
|
_active: (
|
||||||
|
tuple[
|
||||||
|
TextGeneration,
|
||||||
|
# mlx generator that does work
|
||||||
|
Generator[GenerationResponse],
|
||||||
|
# queue that the 1st generator should push to and 3rd generator should pull from
|
||||||
|
GeneratorQueue[GenerationResponse],
|
||||||
|
# generator to get parsed outputs
|
||||||
|
Generator[GenerationResponse | ToolCallResponse | None],
|
||||||
|
]
|
||||||
|
| None
|
||||||
|
) = field(default=None, init=False)
|
||||||
|
|
||||||
|
def warmup(self):
|
||||||
|
logger.info(f"warming up inference for instance: {self.model_id}")
|
||||||
|
|
||||||
|
t = time.monotonic()
|
||||||
|
toks = warmup_inference(
|
||||||
|
model=self.model,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
group=self.group,
|
||||||
|
)
|
||||||
|
logger.info(f"warmed up by generating {toks} tokens")
|
||||||
|
check_for_cancel_every = min(
|
||||||
|
math.ceil(toks / min(time.monotonic() - t, 0.001)), 100
|
||||||
|
)
|
||||||
|
if self.group is not None:
|
||||||
|
self.check_for_cancel_every = int(
|
||||||
|
mx.max(
|
||||||
|
mx.distributed.all_gather(
|
||||||
|
mx.array([check_for_cancel_every]),
|
||||||
|
group=self.group,
|
||||||
|
)
|
||||||
|
).item()
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"runner checking for cancellation every {check_for_cancel_every} tokens"
|
||||||
|
)
|
||||||
|
|
||||||
|
def submit(
|
||||||
|
self,
|
||||||
|
task: TextGeneration,
|
||||||
|
) -> None:
|
||||||
|
self._cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
|
||||||
|
self._maybe_queue.append(task)
|
||||||
|
|
||||||
|
def agree_on_tasks(self) -> None:
|
||||||
|
"""Agree between all ranks about the task ordering (some may have received in different order or not at all)."""
|
||||||
|
agreed, different = mx_all_gather_tasks(self._maybe_queue, self.group)
|
||||||
|
self._queue.extend(task for task in self._maybe_queue if task in agreed)
|
||||||
|
self._maybe_queue = [task for task in self._maybe_queue if task in different]
|
||||||
|
|
||||||
|
def step(
|
||||||
|
self,
|
||||||
|
) -> Iterable[
|
||||||
|
tuple[TaskId, GenerationResponse | ToolCallResponse | Cancelled | Finished]
|
||||||
|
]:
|
||||||
|
if self._active is None:
|
||||||
|
self.agree_on_tasks()
|
||||||
|
|
||||||
|
if self._queue:
|
||||||
|
self._start_next()
|
||||||
|
else:
|
||||||
|
return map(lambda task: (task, Cancelled()), self._cancelled_tasks)
|
||||||
|
|
||||||
|
assert self._active is not None
|
||||||
|
|
||||||
|
task, mlx_gen, queue, output_generator = self._active
|
||||||
|
response = None
|
||||||
|
try:
|
||||||
|
queue.push(next(mlx_gen))
|
||||||
|
response = next(output_generator)
|
||||||
|
except (StopIteration, PrefillCancelled):
|
||||||
|
response = Finished()
|
||||||
|
self._active = None
|
||||||
|
if self._queue:
|
||||||
|
self._start_next()
|
||||||
|
except Exception as e:
|
||||||
|
self._send_error(task, e)
|
||||||
|
self._active = None
|
||||||
|
raise
|
||||||
|
return itertools.chain(
|
||||||
|
[] if response is None else [(task.task_id, response)],
|
||||||
|
map(lambda task: (task, Cancelled()), self._cancelled_tasks),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _start_next(self) -> None:
|
||||||
|
task = self._queue.popleft()
|
||||||
|
try:
|
||||||
|
mlx_gen = self._build_generator(task)
|
||||||
|
except Exception as e:
|
||||||
|
self._send_error(task, e)
|
||||||
|
raise
|
||||||
|
queue = GeneratorQueue[GenerationResponse]()
|
||||||
|
output_generator = apply_all_parsers(
|
||||||
|
queue.gen(),
|
||||||
|
apply_chat_template(self.tokenizer, task.task_params),
|
||||||
|
self.tool_parser,
|
||||||
|
self.tokenizer,
|
||||||
|
type(self.model),
|
||||||
|
self.model_id,
|
||||||
|
)
|
||||||
|
self._active = (task, mlx_gen, queue, output_generator)
|
||||||
|
|
||||||
|
def _send_error(self, task: TextGeneration, e: Exception) -> None:
|
||||||
|
if self.device_rank == 0:
|
||||||
|
self.event_sender.send(
|
||||||
|
ChunkGenerated(
|
||||||
|
command_id=task.command_id,
|
||||||
|
chunk=ErrorChunk(
|
||||||
|
model=self.model_id,
|
||||||
|
finish_reason="error",
|
||||||
|
error_message=str(e),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_generator(self, task: TextGeneration) -> Generator[GenerationResponse]:
|
||||||
|
_check_for_debug_prompts(task.task_params)
|
||||||
|
prompt = apply_chat_template(self.tokenizer, task.task_params)
|
||||||
|
|
||||||
|
def on_prefill_progress(processed: int, total: int) -> None:
|
||||||
|
if self.device_rank == 0:
|
||||||
|
self.event_sender.send(
|
||||||
|
ChunkGenerated(
|
||||||
|
command_id=task.command_id,
|
||||||
|
chunk=PrefillProgressChunk(
|
||||||
|
model=self.model_id,
|
||||||
|
processed_tokens=processed,
|
||||||
|
total_tokens=total,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def distributed_prompt_progress_callback() -> None:
|
||||||
|
self._cancelled_tasks.update(self.cancel_receiver.collect())
|
||||||
|
want_to_cancel = (task.task_id in self._cancelled_tasks) or (
|
||||||
|
TaskId("CANCEL_CURRENT_TASK") in self._cancelled_tasks
|
||||||
|
)
|
||||||
|
if mx_any(want_to_cancel, self.group):
|
||||||
|
raise PrefillCancelled()
|
||||||
|
|
||||||
|
self.agree_on_tasks()
|
||||||
|
|
||||||
|
tokens_since_cancel_check = self.check_for_cancel_every
|
||||||
|
|
||||||
|
def on_generation_token() -> None:
|
||||||
|
nonlocal tokens_since_cancel_check
|
||||||
|
tokens_since_cancel_check += 1
|
||||||
|
if tokens_since_cancel_check >= self.check_for_cancel_every:
|
||||||
|
tokens_since_cancel_check = 0
|
||||||
|
self._cancelled_tasks.update(self.cancel_receiver.collect())
|
||||||
|
want_to_cancel = (task.task_id in self._cancelled_tasks) or (
|
||||||
|
TaskId("CANCEL_CURRENT_TASK") in self._cancelled_tasks
|
||||||
|
)
|
||||||
|
if mx_any(want_to_cancel, self.group):
|
||||||
|
raise PrefillCancelled()
|
||||||
|
|
||||||
|
self.agree_on_tasks()
|
||||||
|
|
||||||
|
return mlx_generate(
|
||||||
|
model=self.model,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
task=task.task_params,
|
||||||
|
prompt=prompt,
|
||||||
|
kv_prefix_cache=self.kv_prefix_cache,
|
||||||
|
on_prefill_progress=on_prefill_progress,
|
||||||
|
distributed_prompt_progress_callback=distributed_prompt_progress_callback,
|
||||||
|
on_generation_token=on_generation_token,
|
||||||
|
group=self.group,
|
||||||
|
)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
del self.model, self.tokenizer, self.group
|
||||||
src/exo/worker/runner/llm_inference/model_output_parsers.py (new file, 376 lines)
@@ -0,0 +1,376 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||||
|
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||||
|
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||||
|
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||||
|
HarmonyEncodingName,
|
||||||
|
HarmonyError, # pyright: ignore[reportUnknownVariableType]
|
||||||
|
Role,
|
||||||
|
StreamableParser,
|
||||||
|
load_harmony_encoding,
|
||||||
|
)
|
||||||
|
|
||||||
|
from exo.shared.types.api import ToolCallItem
|
||||||
|
from exo.shared.types.common import ModelId
|
||||||
|
from exo.shared.types.mlx import Model
|
||||||
|
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
|
||||||
|
from exo.worker.engines.mlx.utils_mlx import (
|
||||||
|
detect_thinking_prompt_suffix,
|
||||||
|
)
|
||||||
|
from exo.worker.runner.bootstrap import logger
|
||||||
|
|
||||||
|
from .tool_parsers import ToolParser
|
||||||
|
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def get_gpt_oss_encoding():
|
||||||
|
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||||
|
return encoding
|
||||||
|
|
||||||
|
|
||||||
|
def apply_all_parsers(
|
||||||
|
receiver: Generator[GenerationResponse | None],
|
||||||
|
prompt: str,
|
||||||
|
tool_parser: ToolParser | None,
|
||||||
|
tokenizer: TokenizerWrapper,
|
||||||
|
model_type: type[Model],
|
||||||
|
model_id: ModelId,
|
||||||
|
) -> Generator[GenerationResponse | ToolCallResponse | None]:
|
||||||
|
mlx_generator = receiver
|
||||||
|
|
||||||
|
if tokenizer.has_thinking:
|
||||||
|
mlx_generator = parse_thinking_models(
|
||||||
|
mlx_generator,
|
||||||
|
tokenizer,
|
||||||
|
starts_in_thinking=detect_thinking_prompt_suffix(prompt, tokenizer),
|
||||||
|
)
|
||||||
|
|
||||||
|
if issubclass(model_type, GptOssModel):
|
||||||
|
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||||
|
elif (
|
||||||
|
issubclass(model_type, DeepseekV32Model)
|
||||||
|
and "deepseek" in model_id.normalize().lower()
|
||||||
|
):
|
||||||
|
mlx_generator = parse_deepseek_v32(mlx_generator)
|
||||||
|
elif tool_parser:
|
||||||
|
mlx_generator = parse_tool_calls(mlx_generator, tool_parser)
|
||||||
|
|
||||||
|
return mlx_generator
|
||||||
|
|
||||||
|
|
||||||
|
def parse_gpt_oss(
|
||||||
|
responses: Generator[GenerationResponse | None],
|
||||||
|
) -> Generator[GenerationResponse | ToolCallResponse | None]:
|
||||||
|
encoding = get_gpt_oss_encoding()
|
||||||
|
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||||
|
thinking = False
|
||||||
|
current_tool_name: str | None = None
|
||||||
|
tool_arg_parts: list[str] = []
|
||||||
|
|
||||||
|
for response in responses:
|
||||||
|
if response is None:
|
||||||
|
yield None
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
stream.process(response.token)
|
||||||
|
except HarmonyError:
|
||||||
|
logger.error("Encountered critical Harmony Error, returning early")
|
||||||
|
return
|
||||||
|
|
||||||
|
delta = stream.last_content_delta
|
||||||
|
ch = stream.current_channel
|
||||||
|
recipient = stream.current_recipient
|
||||||
|
|
||||||
|
# Debug: log every token with state
|
||||||
|
logger.debug(
|
||||||
|
f"parse_gpt_oss token={response.token} text={response.text!r} "
|
||||||
|
f"recipient={recipient!r} ch={ch!r} delta={delta!r} "
|
||||||
|
f"state={stream.state} current_tool={current_tool_name!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if recipient != current_tool_name:
|
||||||
|
if current_tool_name is not None:
|
||||||
|
prefix = "functions."
|
||||||
|
if current_tool_name.startswith(prefix):
|
||||||
|
current_tool_name = current_tool_name[len(prefix) :]
|
||||||
|
logger.info(
|
||||||
|
f"parse_gpt_oss yielding tool call: name={current_tool_name!r}"
|
||||||
|
)
|
||||||
|
yield ToolCallResponse(
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallItem(
|
||||||
|
name=current_tool_name,
|
||||||
|
arguments="".join(tool_arg_parts).strip(),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
usage=response.usage,
|
||||||
|
)
|
||||||
|
tool_arg_parts = []
|
||||||
|
current_tool_name = recipient
|
||||||
|
|
||||||
|
# If inside a tool call, accumulate arguments
|
||||||
|
if current_tool_name is not None:
|
||||||
|
if delta:
|
||||||
|
tool_arg_parts.append(delta)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ch == "analysis" and not thinking:
|
||||||
|
thinking = True
|
||||||
|
|
||||||
|
if ch != "analysis" and thinking:
|
||||||
|
thinking = False
|
||||||
|
|
||||||
|
if delta:
|
||||||
|
yield response.model_copy(update={"text": delta, "is_thinking": thinking})
|
||||||
|
|
||||||
|
if response.finish_reason is not None:
|
||||||
|
yield response
|
||||||
|
|
||||||
|
|
||||||
|
def parse_deepseek_v32(
|
||||||
|
responses: Generator[GenerationResponse | None],
|
||||||
|
) -> Generator[GenerationResponse | ToolCallResponse | None]:
|
||||||
|
"""Parse DeepSeek V3.2 DSML tool calls from the generation stream.
|
||||||
|
|
||||||
|
Uses accumulated-text matching (not per-token marker checks) because
|
||||||
|
DSML markers like <|DSML|function_calls> may span multiple tokens.
|
||||||
|
Also handles <think>...</think> blocks for thinking mode.
|
||||||
|
"""
|
||||||
|
from exo.worker.engines.mlx.dsml_encoding import (
|
||||||
|
THINKING_END,
|
||||||
|
THINKING_START,
|
||||||
|
TOOL_CALLS_END,
|
||||||
|
TOOL_CALLS_START,
|
||||||
|
parse_dsml_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
accumulated = ""
|
||||||
|
in_tool_call = False
|
||||||
|
thinking = False
|
||||||
|
# Tokens buffered while we detect the start of a DSML block
|
||||||
|
pending_buffer: list[GenerationResponse] = []
|
||||||
|
# Text accumulated during a tool call block
|
||||||
|
tool_call_text = ""
|
||||||
|
|
||||||
|
for response in responses:
|
||||||
|
if response is None:
|
||||||
|
yield None
|
||||||
|
continue
|
||||||
|
|
||||||
|
# ── Handle thinking tags ──
|
||||||
|
if not thinking and THINKING_START in response.text:
|
||||||
|
thinking = True
|
||||||
|
# Yield any text before the <think> tag
|
||||||
|
before = response.text[: response.text.index(THINKING_START)]
|
||||||
|
if before:
|
||||||
|
yield response.model_copy(update={"text": before})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if thinking and THINKING_END in response.text:
|
||||||
|
thinking = False
|
||||||
|
# Yield any text after the </think> tag
|
||||||
|
after = response.text[
|
||||||
|
response.text.index(THINKING_END) + len(THINKING_END) :
|
||||||
|
]
|
||||||
|
if after:
|
||||||
|
yield response.model_copy(update={"text": after, "is_thinking": False})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if thinking:
|
||||||
|
yield response.model_copy(update={"is_thinking": True})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# ── Handle tool call accumulation ──
|
||||||
|
if in_tool_call:
|
||||||
|
tool_call_text += response.text
|
||||||
|
if TOOL_CALLS_END in tool_call_text:
|
||||||
|
# Parse the accumulated DSML block
|
||||||
|
parsed = parse_dsml_output(tool_call_text)
|
||||||
|
if parsed is not None:
|
||||||
|
logger.info(f"parsed DSML tool calls: {parsed}")
|
||||||
|
yield ToolCallResponse(
|
||||||
|
tool_calls=parsed,
|
||||||
|
usage=response.usage,
|
||||||
|
stats=response.stats,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"DSML tool call parsing failed for: {tool_call_text}"
|
||||||
|
)
|
||||||
|
yield response.model_copy(update={"text": tool_call_text})
|
||||||
|
in_tool_call = False
|
||||||
|
tool_call_text = ""
|
||||||
|
continue
|
||||||
|
|
||||||
|
# EOS reached before end marker — yield buffered text as-is
|
||||||
|
if response.finish_reason is not None:
|
||||||
|
logger.info("DSML tool call parsing interrupted by EOS")
|
||||||
|
yield response.model_copy(update={"text": tool_call_text})
|
||||||
|
in_tool_call = False
|
||||||
|
tool_call_text = ""
|
||||||
|
continue
|
||||||
|
|
||||||
|
# ── Detect start of tool call block ──
|
||||||
|
accumulated += response.text
|
||||||
|
|
||||||
|
if TOOL_CALLS_START in accumulated:
|
||||||
|
# The start marker might be split across pending_buffer + current token
|
||||||
|
start_idx = accumulated.index(TOOL_CALLS_START)
|
||||||
|
# Yield any pending tokens that are purely before the marker
|
||||||
|
pre_text = accumulated[:start_idx]
|
||||||
|
if pre_text:
|
||||||
|
# Flush pending buffer tokens that contributed text before the marker
|
||||||
|
for buf_resp in pending_buffer:
|
||||||
|
if pre_text:
|
||||||
|
chunk = buf_resp.text
|
||||||
|
if len(chunk) <= len(pre_text):
|
||||||
|
yield buf_resp
|
||||||
|
pre_text = pre_text[len(chunk) :]
|
||||||
|
else:
|
||||||
|
yield buf_resp.model_copy(update={"text": pre_text})
|
||||||
|
pre_text = ""
|
||||||
|
pending_buffer = []
|
||||||
|
tool_call_text = accumulated[start_idx:]
|
||||||
|
accumulated = ""
|
||||||
|
|
||||||
|
# Check if the end marker is already present (entire tool call in one token)
|
||||||
|
if TOOL_CALLS_END in tool_call_text:
|
||||||
|
parsed = parse_dsml_output(tool_call_text)
|
||||||
|
if parsed is not None:
|
||||||
|
logger.info(f"parsed DSML tool calls: {parsed}")
|
||||||
|
yield ToolCallResponse(
|
||||||
|
tool_calls=parsed,
|
||||||
|
usage=response.usage,
|
||||||
|
stats=response.stats,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"DSML tool call parsing failed for: {tool_call_text}"
|
||||||
|
)
|
||||||
|
yield response.model_copy(update={"text": tool_call_text})
|
||||||
|
tool_call_text = ""
|
||||||
|
else:
|
||||||
|
in_tool_call = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if accumulated text might be the start of a DSML marker
|
||||||
|
# Buffer tokens if we see a partial match at the end
|
||||||
|
if _could_be_dsml_prefix(accumulated):
|
||||||
|
pending_buffer.append(response)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# No partial match — flush all pending tokens and the current one
|
||||||
|
for buf_resp in pending_buffer:
|
||||||
|
yield buf_resp
|
||||||
|
pending_buffer = []
|
||||||
|
accumulated = ""
|
||||||
|
yield response
|
||||||
|
|
||||||
|
# Flush any remaining pending buffer at generator end
|
||||||
|
for buf_resp in pending_buffer:
|
||||||
|
yield buf_resp
|
||||||
|
|
||||||
|
|
||||||
|
def _could_be_dsml_prefix(text: str) -> bool:
|
||||||
|
"""Check if the end of text could be the start of a DSML function_calls marker.
|
||||||
|
|
||||||
|
We look for suffixes of text that are prefixes of the TOOL_CALLS_START pattern.
|
||||||
|
This allows us to buffer tokens until we can determine if a tool call is starting.
|
||||||
|
"""
|
||||||
|
from exo.worker.engines.mlx.dsml_encoding import TOOL_CALLS_START
|
||||||
|
|
||||||
|
# Only check the last portion of text that could overlap with the marker
|
||||||
|
max_check = len(TOOL_CALLS_START)
|
||||||
|
tail = text[-max_check:] if len(text) > max_check else text
|
||||||
|
|
||||||
|
# Check if any suffix of tail is a prefix of TOOL_CALLS_START
|
||||||
|
for i in range(len(tail)):
|
||||||
|
suffix = tail[i:]
|
||||||
|
if TOOL_CALLS_START.startswith(suffix):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_thinking_models(
|
||||||
|
responses: Generator[GenerationResponse | None],
|
||||||
|
tokenizer: TokenizerWrapper,
|
||||||
|
starts_in_thinking: bool = True,
|
||||||
|
) -> Generator[GenerationResponse | None]:
|
||||||
|
"""Route thinking tokens via is_thinking flag.
|
||||||
|
|
||||||
|
Swallows think tag tokens, sets is_thinking on all others.
|
||||||
|
Always yields tokens with finish_reason to avoid hanging the chunk stream.
|
||||||
|
"""
|
||||||
|
in_thinking = starts_in_thinking
|
||||||
|
for response in responses:
|
||||||
|
if response is None:
|
||||||
|
yield None
|
||||||
|
continue
|
||||||
|
|
||||||
|
is_think_tag = (
|
||||||
|
tokenizer.think_end is not None and response.text == tokenizer.think_end
|
||||||
|
) or (
|
||||||
|
tokenizer.think_start is not None and response.text == tokenizer.think_start
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_think_tag:
|
||||||
|
in_thinking = response.text != tokenizer.think_end
|
||||||
|
# Never swallow finish_reason — the chunk stream needs it to terminate.
|
||||||
|
if response.finish_reason is not None:
|
||||||
|
yield response.model_copy(update={"text": "", "is_thinking": False})
|
||||||
|
continue
|
||||||
|
yield response.model_copy(update={"is_thinking": in_thinking})
|
||||||
|
|
||||||
|
|
||||||
|
def parse_tool_calls(
|
||||||
|
responses: Generator[GenerationResponse | None], tool_parser: ToolParser
|
||||||
|
) -> Generator[GenerationResponse | ToolCallResponse | None]:
|
||||||
|
in_tool_call = False
|
||||||
|
tool_call_text_parts: list[str] = []
|
||||||
|
for response in responses:
|
||||||
|
if response is None:
|
||||||
|
yield None
|
||||||
|
continue
|
||||||
|
if not in_tool_call and response.text.startswith(tool_parser.start_parsing):
|
||||||
|
in_tool_call = True
|
||||||
|
|
||||||
|
if in_tool_call:
|
||||||
|
tool_call_text_parts.append(response.text)
|
||||||
|
if response.text.endswith(tool_parser.end_parsing):
|
||||||
|
# parse the actual tool calls from the tool call text
|
||||||
|
parsed = tool_parser.parse_tool_calls(
|
||||||
|
"".join(tool_call_text_parts).strip()
|
||||||
|
)
|
||||||
|
logger.info(f"parsed {tool_call_text_parts=} into {parsed=}")
|
||||||
|
if parsed is not None:
|
||||||
|
yield ToolCallResponse(
|
||||||
|
tool_calls=parsed, usage=response.usage, stats=response.stats
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"tool call parsing failed for text {''.join(tool_call_text_parts)}"
|
||||||
|
)
|
||||||
|
response.text = "".join(tool_call_text_parts)
|
||||||
|
yield response
|
||||||
|
|
||||||
|
in_tool_call = False
|
||||||
|
tool_call_text_parts = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
if response.finish_reason is not None:
|
||||||
|
logger.info(
|
||||||
|
"tool call parsing interrupted, yield partial tool call as text"
|
||||||
|
)
|
||||||
|
response = response.model_copy(
|
||||||
|
update={
|
||||||
|
"text": "".join(tool_call_text_parts),
|
||||||
|
"token": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
yield response
|
||||||
|
|
||||||
|
else:
|
||||||
|
# fallthrough
|
||||||
|
yield response
|
||||||
File diff suppressed because it is too large
@@ -52,6 +52,7 @@ class RunnerSupervisor:
     _tg: TaskGroup = field(default_factory=TaskGroup, init=False)
     status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
     pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
+    in_progress: set[TaskId] = field(default_factory=set, init=False)
     completed: set[TaskId] = field(default_factory=set, init=False)
     cancelled: set[TaskId] = field(default_factory=set, init=False)
     _cancel_watch_runner: anyio.CancelScope = field(
@@ -157,6 +158,7 @@ class RunnerSupervisor:
     async def cancel_task(self, task_id: TaskId):
         if task_id in self.completed:
             logger.info(f"Unable to cancel {task_id} as it has been completed")
+            self.cancelled.add(task_id)
             return
         self.cancelled.add(task_id)
         with anyio.move_on_after(0.5) as scope:
@@ -173,6 +175,7 @@ class RunnerSupervisor:
                 self.status = event.runner_status
             if isinstance(event, TaskAcknowledged):
                 self.pending.pop(event.task_id).set()
+                self.in_progress.add(event.task_id)
                 continue
             if (
                 isinstance(event, TaskStatusUpdated)
@@ -189,6 +192,7 @@ class RunnerSupervisor:
                         RunnerShuttingDown,
                     ),
                 )
+                self.in_progress.discard(event.task_id)
                 self.completed.add(event.task_id)
             await self._event_sender.send(event)
         except (ClosedResourceError, BrokenResourceError) as e:
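The three RunnerSupervisor hunks above add `in_progress` bookkeeping alongside `pending` and `completed`: a task leaves `pending` and enters `in_progress` when the runner acknowledges it, and moves to `completed` on a terminal status update. A simplified sketch of that lifecycle (sets instead of the real `dict[TaskId, anyio.Event]`, event and channel plumbing omitted; `TaskBook` is an illustrative name):

```python
from dataclasses import dataclass, field


@dataclass
class TaskBook:
    pending: set[str] = field(default_factory=set)
    in_progress: set[str] = field(default_factory=set)
    completed: set[str] = field(default_factory=set)

    def on_task_acknowledged(self, task_id: str) -> None:
        # TaskAcknowledged: the runner has picked the task up
        self.pending.discard(task_id)
        self.in_progress.add(task_id)

    def on_task_finished(self, task_id: str) -> None:
        # terminal TaskStatusUpdated: the task is no longer in flight
        self.in_progress.discard(task_id)
        self.completed.add(task_id)


book = TaskBook(pending={"t1"})
book.on_task_acknowledged("t1")
book.on_task_finished("t1")
assert book.completed == {"t1"} and not book.in_progress
```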
@@ -20,6 +20,8 @@ class FakeRunnerSupervisor:
     bound_instance: BoundInstance
     status: RunnerStatus
     completed: set[TaskId] = field(default_factory=set)
+    in_progress: set[TaskId] = field(default_factory=set)
+    pending: dict[TaskId, object] = field(default_factory=dict)


 class OtherTask(BaseTask):
@@ -19,7 +19,7 @@ from exo.worker.engines.mlx.dsml_encoding import (
     encode_messages,
     parse_dsml_output,
 )
-from exo.worker.runner.llm_inference.runner import parse_deepseek_v32
+from exo.worker.runner.llm_inference.model_output_parsers import parse_deepseek_v32

 # ── Shared fixtures ──────────────────────────────────────────────

@@ -6,6 +6,8 @@ from typing import Callable
 import mlx.core as mx
 import pytest

+import exo.worker.runner.llm_inference.batch_generator as mlx_batch_generator
+import exo.worker.runner.llm_inference.model_output_parsers as mlx_model_output_parsers
 import exo.worker.runner.llm_inference.runner as mlx_runner
 from exo.shared.types.chunks import TokenChunk
 from exo.shared.types.events import (
@@ -114,27 +116,41 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
     # initialize_mlx returns a mock group
     monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
     monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
-    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
-    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
-    monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
+    monkeypatch.setattr(mlx_batch_generator, "warmup_inference", make_nothin(1))
+    monkeypatch.setattr(mlx_batch_generator, "_check_for_debug_prompts", nothin)
+    monkeypatch.setattr(mlx_batch_generator, "mx_any", make_nothin(False))
+
+    def fake_all_gather(
+        tasks: list[TextGeneration], group: object
+    ) -> tuple[list[TextGeneration], list[TextGeneration]]:
+        return (tasks, [])
+
+    monkeypatch.setattr(mlx_batch_generator, "mx_all_gather_tasks", fake_all_gather)
     # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
     # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
-    monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
-    monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
+    monkeypatch.setattr(
+        mlx_batch_generator, "apply_chat_template", make_nothin("test prompt")
+    )
+    monkeypatch.setattr(
+        mlx_model_output_parsers, "detect_thinking_prompt_suffix", make_nothin(False)
+    )

     def fake_generate(*_1: object, **_2: object):
         yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)

-    monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
+    monkeypatch.setattr(mlx_batch_generator, "mlx_generate", fake_generate)


 # Use a fake event_sender to remove test flakiness.
 class EventCollector:
-    def __init__(self) -> None:
+    def __init__(self, on_event: Callable[[Event], None] | None = None) -> None:
         self.events: list[Event] = []
+        self._on_event = on_event

     def send(self, event: Event) -> None:
         self.events.append(event)
+        if self._on_event:
+            self._on_event(event)

     def close(self) -> None:
         pass
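The EventCollector used by these tests now takes an optional on_event callback, so a test can react to events while the runner is still executing instead of only inspecting collector.events afterwards. A short usage sketch, assuming the EventCollector defined in the hunk above and using plain strings in place of real Event objects:

seen: list[object] = []
collector = EventCollector(on_event=seen.append)

collector.send("runner-ready")      # appended to collector.events and passed to the callback
collector.send("chunk-generated")

assert collector.events == seen == ["runner-ready", "chunk-generated"]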
@@ -159,7 +175,7 @@ class MockGroup:
         return 1


-def _run(tasks: Iterable[Task]):
+def _run(tasks: Iterable[Task], send_after_ready: list[Task] | None = None):
     bound_instance = get_bound_mlx_ring_instance(
         instance_id=INSTANCE_1_ID,
         model_id=MODEL_A_ID,
@@ -169,7 +185,23 @@ def _run(tasks: Iterable[Task]):

     task_sender, task_receiver = mp_channel[Task]()
     _cancel_sender, cancel_receiver = mp_channel[TaskId]()
-    event_sender = EventCollector()
+
+    on_event: Callable[[Event], None] | None = None
+    if send_after_ready:
+        _saw_running = False
+
+        def _on_event(event: Event) -> None:
+            nonlocal _saw_running
+            if isinstance(event, RunnerStatusUpdated):
+                if isinstance(event.runner_status, RunnerRunning):
+                    _saw_running = True
+                elif _saw_running and isinstance(event.runner_status, RunnerReady):
+                    for t in send_after_ready:
+                        task_sender.send(t)
+
+        on_event = _on_event
+
+    event_sender = EventCollector(on_event=on_event)

     with task_sender:
         for t in tasks:
@@ -183,18 +215,22 @@ def _run(tasks: Iterable[Task]):
         "exo.worker.runner.llm_inference.runner.mx.distributed.all_gather",
         make_nothin(mx.array([1])),
     ):
-        mlx_runner.main(
+        runner = mlx_runner.Runner(
             bound_instance,
             event_sender,  # pyright: ignore[reportArgumentType]
             task_receiver,
             cancel_receiver,
         )
+        runner.main()

     return event_sender.events


 def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
-    events = _run([INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK, SHUTDOWN_TASK])
+    events = _run(
+        [INIT_TASK, LOAD_TASK, WARMUP_TASK, CHAT_TASK],
+        send_after_ready=[SHUTDOWN_TASK],
+    )

     expected_chunk = ChunkGenerated(
         command_id=COMMAND_1_ID,
@@ -4,7 +4,7 @@ from exo.shared.types.worker.runner_response import (
     GenerationResponse,
     ToolCallResponse,
 )
-from exo.worker.runner.llm_inference.runner import parse_gpt_oss
+from exo.worker.runner.llm_inference.model_output_parsers import parse_gpt_oss

 # Token IDs from mlx-community/gpt-oss-20b-MXFP4-Q8 tokenizer.
 # These are stable since they come from the model's vocabulary.
@@ -107,7 +107,7 @@ def _collect(
     def _gen() -> Generator[GenerationResponse, None, None]:
         yield from _make_gen_responses(tokens)

-    return list(parse_gpt_oss(_gen()))
+    return list(x for x in parse_gpt_oss(_gen()) if x is not None)


 def _get_tool_call(
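The updated helper filters None out of the parser output, which implies parse_gpt_oss may now yield None when it has consumed tokens but has nothing to emit yet (for example while buffering a tool call). A generic sketch of that filtering pattern, not the real parser:

from collections.abc import Generator


def fake_parser() -> Generator[str | None, None, None]:
    # Placeholder parser: None means "consumed input, nothing to emit yet".
    yield "hello"
    yield None
    yield "world"


collected = [x for x in fake_parser() if x is not None]
assert collected == ["hello", "world"]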
@@ -4,7 +4,7 @@ from collections.abc import Generator
 from typing import Any

 from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
-from exo.worker.runner.llm_inference.runner import parse_tool_calls
+from exo.worker.runner.llm_inference.model_output_parsers import parse_tool_calls
 from exo.worker.runner.llm_inference.tool_parsers import make_mlx_parser
