Compare commits


1 Commit

Author: Ryuichi Leo Takashige
SHA1: e4834dc615
Message: Add support for Ollama API
Date: 2026-02-20 00:17:15 +00:00
3 changed files with 790 additions and 0 deletions

View File

@@ -0,0 +1,456 @@
from __future__ import annotations
import json
from collections.abc import AsyncGenerator
from typing import Any
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaDoneReason,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaMessage,
OllamaToolCall,
OllamaToolFunction,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _map_done_reason(
finish_reason: str | None,
) -> OllamaDoneReason | None:
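    """Map an exo finish_reason onto the closest Ollama done_reason label."""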
if finish_reason is None:
return None
if finish_reason == "stop":
return "stop"
if finish_reason == "length":
return "length"
if finish_reason in ("tool_calls", "function_call"):
return "tool_call"
if finish_reason == "error":
return "error"
return "stop"
def _try_parse_json(value: str) -> dict[str, Any] | str:
    try:
        parsed = json.loads(value)
    except json.JSONDecodeError:
        return value
    # Only JSON objects are useful downstream; keep anything else (list, number, ...) as the raw string.
    return parsed if isinstance(parsed, dict) else value
def _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:
tool_calls: list[OllamaToolCall] = []
for index, tool in enumerate(chunk.tool_calls):
# tool.arguments is always str; try to parse as JSON dict for Ollama format
arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)
tool_calls.append(
OllamaToolCall(
id=tool.id,
type="function",
function=OllamaToolFunction(
name=tool.name, arguments=arguments, index=index
),
)
)
return tool_calls
def _get_usage(
chunk: TokenChunk | ToolCallChunk,
) -> tuple[int | None, int | None]:
"""Extract (prompt_eval_count, eval_count) from a chunk."""
if chunk.usage is not None:
return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
if chunk.stats is not None:
return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)
return (None, None)
def ollama_request_to_text_generation(
request: OllamaChatRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama chat request to exo's internal text generation format."""
instructions: str | None = None
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
tool_message_index = 0
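    # Build two parallel views of the conversation: a simplified `input` list and
    # full-fidelity `chat_template_messages` for the model's chat template.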
for msg in request.messages:
content = msg.content or ""
if msg.role == "system":
if instructions is None:
instructions = content
else:
instructions = f"{instructions}\n{content}"
chat_template_messages.append({"role": "system", "content": content})
continue
if msg.role in ("user", "assistant") and (
msg.content is not None or msg.thinking is not None or msg.tool_calls
):
input_messages.append(InputMessage(role=msg.role, content=content))
dumped: dict[str, Any] = {"role": msg.role, "content": content}
if msg.thinking is not None:
dumped["thinking"] = msg.thinking
if msg.tool_calls is not None:
tool_calls_list: list[dict[str, Any]] = []
for tc in msg.tool_calls:
function: dict[str, Any] = {
"name": tc.function.name,
"arguments": (
json.dumps(tc.function.arguments)
if isinstance(tc.function.arguments, dict)
else tc.function.arguments
),
}
if tc.function.index is not None:
function["index"] = tc.function.index
tool_call: dict[str, Any] = {"function": function}
if tc.id is not None:
tool_call["id"] = tc.id
if tc.type is not None:
tool_call["type"] = tc.type
tool_calls_list.append(tool_call)
dumped["tool_calls"] = tool_calls_list
if msg.name is not None:
dumped["name"] = msg.name
if msg.role == "tool":
tool_message_index += 1
tool_call_id = msg.tool_name or msg.name or f"tool_{tool_message_index}"
dumped["tool_call_id"] = tool_call_id
if msg.tool_name is not None:
dumped["tool_name"] = msg.tool_name
chat_template_messages.append(dumped)
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
tools=request.tools,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_chat_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses in Ollama format (newline-delimited JSON)."""
thinking_parts: list[str] = []
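    # Thinking text is accumulated so a terminal tool-call message can carry it in full.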
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
error_response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content=chunk.error_message
),
done=True,
done_reason="error",
)
yield f"{error_response.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content="",
tool_calls=_build_tool_calls(chunk),
thinking="".join(thinking_parts) if thinking_parts else None,
),
done=True,
done_reason="tool_call",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant", content="", thinking=chunk.text
),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(
role="assistant",
content=chunk.text,
),
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
else:
response = OllamaChatResponse(
model=str(chunk.model),
message=OllamaMessage(role="assistant", content=chunk.text),
done=False,
)
yield f"{response.model_dump_json(exclude_none=True)}\n"
if done:
return
async def collect_ollama_chat_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect streaming chunks into a single non-streaming Ollama response.
Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI
StreamingResponse cancellation handling.
"""
text_parts: list[str] = []
thinking_parts: list[str] = []
tool_calls: list[OllamaToolCall] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
tool_calls.extend(_build_tool_calls(chunk))
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
combined_text = "".join(text_parts)
combined_thinking = "".join(thinking_parts) if thinking_parts else None
assert model is not None
yield OllamaChatResponse(
model=model,
message=OllamaMessage(
role="assistant",
content=combined_text,
thinking=combined_thinking,
tool_calls=tool_calls if tool_calls else None,
),
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return
# ── /api/generate ──
def ollama_generate_request_to_text_generation(
request: OllamaGenerateRequest,
) -> TextGenerationTaskParams:
"""Convert Ollama generate request to exo's internal text generation format."""
chat_template_messages: list[dict[str, Any]] = []
if request.system:
chat_template_messages.append({"role": "system", "content": request.system})
chat_template_messages.append({"role": "user", "content": request.prompt})
options = request.options
return TextGenerationTaskParams(
model=request.model,
input=[InputMessage(role="user", content=request.prompt)],
instructions=request.system,
max_output_tokens=options.num_predict if options else None,
temperature=options.temperature if options else None,
top_p=options.top_p if options else None,
top_k=options.top_k if options else None,
stop=options.stop if options else None,
seed=options.seed if options else None,
stream=request.stream,
enable_thinking=request.think,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def generate_ollama_generate_stream(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str, None]:
"""Generate streaming responses for /api/generate in Ollama NDJSON format."""
thinking_parts: list[str] = []
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="error",
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case ToolCallChunk():
# generate endpoint doesn't support tools; emit as done
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
done=True,
done_reason="stop",
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
return
case TokenChunk():
done = chunk.finish_reason is not None
if chunk.is_thinking:
thinking_parts.append(chunk.text)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response="",
thinking=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
elif done:
prompt_eval, eval_count = _get_usage(chunk)
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=True,
done_reason=_map_done_reason(chunk.finish_reason),
prompt_eval_count=prompt_eval,
eval_count=eval_count,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
else:
resp = OllamaGenerateResponse(
model=str(chunk.model),
response=chunk.text,
done=False,
)
yield f"{resp.model_dump_json(exclude_none=True)}\n"
if done:
return
async def collect_ollama_generate_response(
_command_id: CommandId,
chunk_stream: AsyncGenerator[
ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
],
) -> AsyncGenerator[str]:
"""Collect chunks into a single non-streaming /api/generate response."""
text_parts: list[str] = []
thinking_parts: list[str] = []
model: str | None = None
finish_reason: str | None = None
prompt_eval_count: int | None = None
eval_count: int | None = None
async for chunk in chunk_stream:
match chunk:
case PrefillProgressChunk():
continue
case ErrorChunk():
raise ValueError(chunk.error_message or "Internal server error")
case TokenChunk():
if model is None:
model = str(chunk.model)
if chunk.is_thinking:
thinking_parts.append(chunk.text)
else:
text_parts.append(chunk.text)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
case ToolCallChunk():
if model is None:
model = str(chunk.model)
finish_reason = chunk.finish_reason
prompt_eval_count, eval_count = _get_usage(chunk)
assert model is not None
yield OllamaGenerateResponse(
model=model,
response="".join(text_parts),
thinking="".join(thinking_parts) if thinking_parts else None,
done=True,
done_reason=_map_done_reason(finish_reason),
prompt_eval_count=prompt_eval_count,
eval_count=eval_count,
).model_dump_json(exclude_none=True)
return
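
Reviewer note: below is a minimal sketch (not part of this change) of how the chat-request conversion above behaves, assuming the adapter is importable as exo.master.adapters.ollama, matching the import added in the next file. The model id is hypothetical.

from exo.master.adapters.ollama import ollama_request_to_text_generation
from exo.shared.types.ollama_api import OllamaChatRequest

# Mirrors what /api/chat receives on the wire; "llama-3.2-1b" is a hypothetical id.
payload = OllamaChatRequest.model_validate(
    {
        "model": "llama-3.2-1b",
        "messages": [
            {"role": "system", "content": "You are terse."},
            {"role": "user", "content": "Why is the sky blue?"},
        ],
        "options": {"temperature": 0.2, "num_predict": 128},
        "stream": True,
    }
)

params = ollama_request_to_text_generation(payload)
# The system message becomes `instructions`, options map onto sampling parameters,
# and every message is preserved in `chat_template_messages` for templating.
assert params.instructions == "You are terse."
assert params.max_output_tokens == 128
assert params.temperature == 0.2
assert params.chat_template_messages is not None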

View File

@@ -32,6 +32,14 @@ from exo.master.adapters.claude import (
collect_claude_response,
generate_claude_stream,
)
from exo.master.adapters.ollama import (
collect_ollama_chat_response,
collect_ollama_generate_response,
generate_ollama_chat_stream,
generate_ollama_generate_stream,
ollama_generate_request_to_text_generation,
ollama_request_to_text_generation,
)
from exo.master.adapters.responses import (
collect_responses_response,
generate_responses_stream,
@@ -141,6 +149,19 @@ from exo.shared.types.events import (
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.ollama_api import (
OllamaChatRequest,
OllamaChatResponse,
OllamaGenerateRequest,
OllamaGenerateResponse,
OllamaModelDetails,
OllamaModelTag,
OllamaPsModel,
OllamaPsResponse,
OllamaShowRequest,
OllamaShowResponse,
OllamaTagsResponse,
)
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
@@ -300,6 +321,20 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
# Ollama API — health checks (must be before static files mount)
self.app.head("/")(self._ollama_root)
self.app.head("/api/version")(self.ollama_version)
# Ollama API
self.app.post("/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/api/api/chat", response_model=None)(self.ollama_chat)
self.app.post("/api/v1/chat", response_model=None)(self.ollama_chat)
self.app.post("/api/generate", response_model=None)(self.ollama_generate)
self.app.get("/api/tags")(self.ollama_tags)
self.app.get("/api/api/tags")(self.ollama_tags)
self.app.get("/api/v1/tags")(self.ollama_tags)
self.app.post("/api/show")(self.ollama_show)
self.app.get("/api/ps")(self.ollama_ps)
self.app.get("/api/version")(self.ollama_version)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(self.stream_events)
self.app.post("/download/start")(self.start_download)
@@ -1293,6 +1328,158 @@ class API:
media_type="application/json",
)
async def _ollama_root(self) -> JSONResponse:
"""Respond to HEAD / from Ollama CLI connectivity checks."""
return JSONResponse(content="Ollama is running")
async def ollama_chat(
self, request: Request
) -> OllamaChatResponse | StreamingResponse:
"""Ollama Chat API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaChatRequest.model_validate_json(body)
task_params = ollama_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
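        # Non-streaming responses are also served via StreamingResponse so FastAPI's
        # cancellation handling applies (see the collect_* docstrings in the adapter).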
if payload.stream:
return StreamingResponse(
generate_ollama_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_generate(
self, request: Request
) -> OllamaGenerateResponse | StreamingResponse:
"""Ollama Generate API — accepts JSON regardless of Content-Type."""
body = await request.body()
payload = OllamaGenerateRequest.model_validate_json(body)
task_params = ollama_generate_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_ollama_generate_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/x-ndjson",
headers={
"Cache-Control": "no-cache",
"Connection": "close",
"X-Accel-Buffering": "no",
},
)
else:
return StreamingResponse(
collect_ollama_generate_response(
command.command_id,
self._token_chunk_stream(command.command_id),
),
media_type="application/json",
)
async def ollama_tags(self) -> OllamaTagsResponse:
"""Returns list of models in Ollama tags format. We return the downloaded ones only."""
def none_if_empty(value: str) -> str | None:
return value or None
downloaded_model_ids: set[str] = set()
for node_downloads in self.state.downloads.values():
for dl in node_downloads:
if isinstance(dl, DownloadCompleted):
downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)
cards = [
c for c in await get_model_cards() if c.model_id in downloaded_model_ids
]
now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
return OllamaTagsResponse(
models=[
OllamaModelTag(
name=str(card.model_id),
model=str(card.model_id),
modified_at=now,
size=card.storage_size.in_bytes,
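                    # Placeholder digest: exo has no Ollama-style content digest to report.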
digest="sha256:000000000000",
details=OllamaModelDetails(
family=none_if_empty(card.family),
quantization_level=none_if_empty(card.quantization),
),
)
for card in cards
]
)
async def ollama_show(self, request: Request) -> OllamaShowResponse:
"""Returns model information in Ollama show format."""
body = await request.body()
payload = OllamaShowRequest.model_validate_json(body)
try:
card = await ModelCard.load(ModelId(payload.name))
except Exception as exc:
raise HTTPException(
status_code=404, detail=f"Model not found: {payload.name}"
) from exc
return OllamaShowResponse(
details=OllamaModelDetails(
family=card.family or None,
quantization_level=card.quantization or None,
),
)
async def ollama_ps(self) -> OllamaPsResponse:
"""Returns list of running models (active instances)."""
models: list[OllamaPsModel] = []
seen: set[str] = set()
for instance in self.state.instances.values():
model_id = str(instance.shard_assignments.model_id)
if model_id in seen:
continue
seen.add(model_id)
models.append(
OllamaPsModel(
name=model_id,
model=model_id,
size=0,
)
)
return OllamaPsResponse(models=models)
async def ollama_version(self) -> dict[str, str]:
"""Returns version information for Ollama API compatibility."""
return {"version": "exo v1.0"}
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
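
Reviewer note: a client-side sketch (not part of this change) of hitting the new /api/chat route the way an Ollama client would. The base URL and model id are assumptions; point them at a running exo API and a downloaded model.

import json

import requests

BASE_URL = "http://localhost:11434"  # assumption: exo listening where an Ollama client expects it

with requests.post(
    f"{BASE_URL}/api/chat",
    json={
        "model": "llama-3.2-1b",  # hypothetical model id
        "messages": [{"role": "user", "content": "Why is the sky blue?"}],
        "stream": True,
    },
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():  # application/x-ndjson: one JSON object per line
        if not line:
            continue
        chunk = json.loads(line)
        print(chunk["message"]["content"], end="", flush=True)
        if chunk.get("done"):
            break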

View File

@@ -0,0 +1,147 @@
from __future__ import annotations
import time
from typing import Any, Literal
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import ModelId
# https://github.com/ollama/ollama/blob/main/docs/api.md
OllamaRole = Literal["system", "user", "assistant", "tool"]
OllamaDoneReason = Literal["stop", "length", "tool_call", "error"]
class OllamaToolFunction(BaseModel, frozen=True):
name: str
arguments: dict[str, Any] | str
index: int | None = None
class OllamaToolCall(BaseModel, frozen=True):
id: str | None = None
type: Literal["function"] | None = None
function: OllamaToolFunction
class OllamaMessage(BaseModel, frozen=True):
role: OllamaRole
content: str | None = None
thinking: str | None = None
tool_calls: list[OllamaToolCall] | None = None
name: str | None = None
tool_name: str | None = None
images: list[str] | None = None
class OllamaOptions(BaseModel, frozen=True):
num_predict: int | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None
class OllamaChatRequest(BaseModel, frozen=True):
model: ModelId
messages: list[OllamaMessage]
stream: bool = True
options: OllamaOptions | None = None
tools: list[dict[str, Any]] | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None
class OllamaGenerateRequest(BaseModel, frozen=True):
model: ModelId
prompt: str = ""
system: str | None = None
stream: bool = True
options: OllamaOptions | None = None
format: Literal["json"] | dict[str, Any] | None = None
keep_alive: str | int | None = None
think: bool | None = None
raw: bool = False
class OllamaGenerateResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
response: str
thinking: str | None = None
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None
class OllamaShowRequest(BaseModel, frozen=True):
name: str
verbose: bool | None = None
class OllamaChatResponse(BaseModel, frozen=True, strict=True):
model: str
created_at: str = Field(
default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
)
message: OllamaMessage
done: bool
done_reason: OllamaDoneReason | None = None
total_duration: int | None = None
load_duration: int | None = None
prompt_eval_count: int | None = None
prompt_eval_duration: int | None = None
eval_count: int | None = None
eval_duration: int | None = None
class OllamaModelDetails(BaseModel, frozen=True, strict=True):
format: str | None = None
family: str | None = None
parameter_size: str | None = None
quantization_level: str | None = None
class OllamaModelTag(BaseModel, frozen=True, strict=True):
name: str
model: str | None = None
modified_at: str | None = None
size: int | None = None
digest: str | None = None
details: OllamaModelDetails | None = None
class OllamaTagsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaModelTag]
class OllamaShowResponse(BaseModel, frozen=True, strict=True):
modelfile: str | None = None
parameters: str | None = None
template: str | None = None
details: OllamaModelDetails | None = None
model_info: dict[str, Any] | None = None
class OllamaPsModel(BaseModel, frozen=True, strict=True):
name: str
model: str
size: int
digest: str | None = None
details: OllamaModelDetails | None = None
expires_at: str | None = None
size_vram: int | None = None
class OllamaPsResponse(BaseModel, frozen=True, strict=True):
models: list[OllamaPsModel]
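
Reviewer note: for reference, a minimal sketch of the shape of one streamed NDJSON line, built the same way the adapters build it (model_dump_json with exclude_none=True); the model id is hypothetical.

from exo.shared.types.ollama_api import OllamaChatResponse, OllamaMessage

line = OllamaChatResponse(
    model="llama-3.2-1b",  # hypothetical model id
    message=OllamaMessage(role="assistant", content="The sky "),
    done=False,
).model_dump_json(exclude_none=True)
print(line)
# e.g. {"model":"llama-3.2-1b","created_at":"2026-02-20T00:17:15Z",
#       "message":{"role":"assistant","content":"The sky "},"done":false}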