mirror of
https://github.com/exo-explore/exo.git
synced 2026-03-05 14:48:28 -05:00
34 lines
1.5 KiB
Python
34 lines
1.5 KiB
Python
from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
|
from inference.inference_engine import InferenceEngine
|
|
from inference.shard import Shard
|
|
from inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
|
import numpy as np
|
|
|
|
# An inference engine should work the same for any number of Shards, as long as
# the Shards are contiguous (consecutive layer ranges, as in the disabled
# 0-0 / 1-1 split below).
async def test_inference_engine(inference_engine: InferenceEngine, model_id: str, input_data: "np.ndarray | list[int]", *, prompt: str = "In one word, what is the capital of USA? "):
  """Run ``prompt`` through ``inference_engine`` on a single shard and print the result.

  This is a manual smoke test driven by the script entry point in this file,
  not a self-contained automated test: it needs an engine and a model id passed
  in, and on its live code path it only prints the response.

  Args:
    inference_engine: Engine under test. Besides ``infer_prompt``, it must
      expose a ``tokenizer`` attribute with a ``decode`` method (used below to
      print a readable response).
    model_id: Passed through as ``Shard.model_id``.
    input_data: Input for the (currently disabled) two-shard ``infer_tensor``
      comparison below; unused on the live code path. The script entry point
      passes a plain list here, which is why the annotation accepts ``list[int]``
      (``np.array`` is a constructor function, not a type).
    prompt: Prompt text sent to ``infer_prompt``. Keyword-only; defaults to the
      prompt this test has always used.

  Returns:
    The first element of the tuple returned by ``infer_prompt`` (the same value
    that is printed), so callers can make their own assertions on it.
  """
  # inference_engine.reset_shard(Shard("", 0,0,0))

  # Layers 0..1 of a declared 2-layer model, i.e. one shard intended to span the
  # whole model (assumes end_layer is inclusive, as the 0-0 / 1-1 split suggests).
  full_shard = Shard(model_id=model_id, start_layer=0, end_layer=1, n_layers=2)
  resp_full, _ = await inference_engine.infer_prompt(shard=full_shard, prompt=prompt)

  print("resp_full", resp_full)
  print("decoded", inference_engine.tokenizer.decode(resp_full))

  # TODO(review): the sharding-equivalence check this test exists for is
  # commented out, so nothing is asserted. The disabled version runs input_data
  # through layers 0-0 and then 1-1 and compares the result with resp_full.
  # inference_engine.reset_shard(Shard("", 0,0,0))
  # resp1, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=0, end_layer=0, n_layers=2), input_data=input_data)
  # resp2, _ = await inference_engine.infer_tensor(shard=Shard(model_id=model_id, start_layer=1, end_layer=1, n_layers=2), input_data=resp1)
  # assert np.array_equal(resp_full, resp2)

  return resp_full
|
|
|
|
import asyncio


if __name__ == "__main__":
  # Manual entry point. Guarded so that merely importing this module (e.g. by a
  # test runner) does not instantiate an engine and run model inference as an
  # import side effect.
  import sys

  # Alternative backend, kept for reference:
  # asyncio.run(test_inference_engine(
  #   MLXDynamicShardInferenceEngine(),
  #   "mlx-community/Meta-Llama-3-8B-Instruct-4bit",
  #   [1234]
  # ))

  # The default model path is a developer-local tinygrad download cache; pass a
  # different path as the first command-line argument to override it.
  model_path = sys.argv[1] if len(sys.argv) > 1 else "/Users/alex/Library/Caches/tinygrad/downloads/llama3-8b-sfr"

  asyncio.run(test_inference_engine(
    TinygradDynamicShardInferenceEngine(),
    model_path,
    [1234]
  ))