Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-19 23:36:30 -05:00)

Compare commits: main...leo/add-ol, 1 commit (e4834dc615)

src/exo/master/adapters/ollama.py (new file, 456 lines)

@@ -0,0 +1,456 @@
from __future__ import annotations

import json
from collections.abc import AsyncGenerator
from typing import Any

from exo.shared.types.chunks import (
    ErrorChunk,
    PrefillProgressChunk,
    TokenChunk,
    ToolCallChunk,
)
from exo.shared.types.common import CommandId
from exo.shared.types.ollama_api import (
    OllamaChatRequest,
    OllamaChatResponse,
    OllamaDoneReason,
    OllamaGenerateRequest,
    OllamaGenerateResponse,
    OllamaMessage,
    OllamaToolCall,
    OllamaToolFunction,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams


def _map_done_reason(
    finish_reason: str | None,
) -> OllamaDoneReason | None:
    if finish_reason is None:
        return None
    if finish_reason == "stop":
        return "stop"
    if finish_reason == "length":
        return "length"
    if finish_reason in ("tool_calls", "function_call"):
        return "tool_call"
    if finish_reason == "error":
        return "error"
    return "stop"


def _try_parse_json(value: str) -> dict[str, Any] | str:
    try:
        return json.loads(value)  # type: ignore
    except json.JSONDecodeError:
        return value


def _build_tool_calls(chunk: ToolCallChunk) -> list[OllamaToolCall]:
    tool_calls: list[OllamaToolCall] = []
    for index, tool in enumerate(chunk.tool_calls):
        # tool.arguments is always str; try to parse as JSON dict for Ollama format
        arguments: dict[str, Any] | str = _try_parse_json(tool.arguments)
        tool_calls.append(
            OllamaToolCall(
                id=tool.id,
                type="function",
                function=OllamaToolFunction(
                    name=tool.name, arguments=arguments, index=index
                ),
            )
        )
    return tool_calls


def _get_usage(
    chunk: TokenChunk | ToolCallChunk,
) -> tuple[int | None, int | None]:
    """Extract (prompt_eval_count, eval_count) from a chunk."""
    if chunk.usage is not None:
        return (chunk.usage.prompt_tokens, chunk.usage.completion_tokens)
    if chunk.stats is not None:
        return (chunk.stats.prompt_tokens, chunk.stats.generation_tokens)
    return (None, None)


def ollama_request_to_text_generation(
    request: OllamaChatRequest,
) -> TextGenerationTaskParams:
    """Convert Ollama chat request to exo's internal text generation format."""
    instructions: str | None = None
    input_messages: list[InputMessage] = []
    chat_template_messages: list[dict[str, Any]] = []
    tool_message_index = 0

    for msg in request.messages:
        content = msg.content or ""

        if msg.role == "system":
            if instructions is None:
                instructions = content
            else:
                instructions = f"{instructions}\n{content}"
            chat_template_messages.append({"role": "system", "content": content})
            continue

        if msg.role in ("user", "assistant") and (
            msg.content is not None or msg.thinking is not None or msg.tool_calls
        ):
            input_messages.append(InputMessage(role=msg.role, content=content))

        dumped: dict[str, Any] = {"role": msg.role, "content": content}
        if msg.thinking is not None:
            dumped["thinking"] = msg.thinking
        if msg.tool_calls is not None:
            tool_calls_list: list[dict[str, Any]] = []
            for tc in msg.tool_calls:
                function: dict[str, Any] = {
                    "name": tc.function.name,
                    "arguments": (
                        json.dumps(tc.function.arguments)
                        if isinstance(tc.function.arguments, dict)
                        else tc.function.arguments
                    ),
                }
                if tc.function.index is not None:
                    function["index"] = tc.function.index
                tool_call: dict[str, Any] = {"function": function}
                if tc.id is not None:
                    tool_call["id"] = tc.id
                if tc.type is not None:
                    tool_call["type"] = tc.type
                tool_calls_list.append(tool_call)
            dumped["tool_calls"] = tool_calls_list
        if msg.name is not None:
            dumped["name"] = msg.name
        if msg.role == "tool":
            tool_message_index += 1
            tool_call_id = msg.tool_name or msg.name or f"tool_{tool_message_index}"
            dumped["tool_call_id"] = tool_call_id
            if msg.tool_name is not None:
                dumped["tool_name"] = msg.tool_name
        chat_template_messages.append(dumped)

    options = request.options
    return TextGenerationTaskParams(
        model=request.model,
        input=input_messages
        if input_messages
        else [InputMessage(role="user", content="")],
        instructions=instructions,
        max_output_tokens=options.num_predict if options else None,
        temperature=options.temperature if options else None,
        top_p=options.top_p if options else None,
        top_k=options.top_k if options else None,
        stop=options.stop if options else None,
        seed=options.seed if options else None,
        stream=request.stream,
        tools=request.tools,
        enable_thinking=request.think,
        chat_template_messages=chat_template_messages
        if chat_template_messages
        else None,
    )


async def generate_ollama_chat_stream(
    _command_id: CommandId,
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate streaming responses in Ollama format (newline-delimited JSON)."""
    thinking_parts: list[str] = []

    async for chunk in chunk_stream:
        match chunk:
            case PrefillProgressChunk():
                continue

            case ErrorChunk():
                error_response = OllamaChatResponse(
                    model=str(chunk.model),
                    message=OllamaMessage(
                        role="assistant", content=chunk.error_message
                    ),
                    done=True,
                    done_reason="error",
                )
                yield f"{error_response.model_dump_json(exclude_none=True)}\n"
                return

            case ToolCallChunk():
                prompt_eval, eval_count = _get_usage(chunk)
                response = OllamaChatResponse(
                    model=str(chunk.model),
                    message=OllamaMessage(
                        role="assistant",
                        content="",
                        tool_calls=_build_tool_calls(chunk),
                        thinking="".join(thinking_parts) if thinking_parts else None,
                    ),
                    done=True,
                    done_reason="tool_call",
                    prompt_eval_count=prompt_eval,
                    eval_count=eval_count,
                )
                yield f"{response.model_dump_json(exclude_none=True)}\n"
                return

            case TokenChunk():
                done = chunk.finish_reason is not None

                if chunk.is_thinking:
                    thinking_parts.append(chunk.text)
                    response = OllamaChatResponse(
                        model=str(chunk.model),
                        message=OllamaMessage(
                            role="assistant", content="", thinking=chunk.text
                        ),
                        done=False,
                    )
                    yield f"{response.model_dump_json(exclude_none=True)}\n"
                elif done:
                    prompt_eval, eval_count = _get_usage(chunk)
                    response = OllamaChatResponse(
                        model=str(chunk.model),
                        message=OllamaMessage(
                            role="assistant",
                            content=chunk.text,
                        ),
                        done=True,
                        done_reason=_map_done_reason(chunk.finish_reason),
                        prompt_eval_count=prompt_eval,
                        eval_count=eval_count,
                    )
                    yield f"{response.model_dump_json(exclude_none=True)}\n"
                else:
                    response = OllamaChatResponse(
                        model=str(chunk.model),
                        message=OllamaMessage(role="assistant", content=chunk.text),
                        done=False,
                    )
                    yield f"{response.model_dump_json(exclude_none=True)}\n"

                if done:
                    return


async def collect_ollama_chat_response(
    _command_id: CommandId,
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str]:
    """Collect streaming chunks into a single non-streaming Ollama response.

    Returns an AsyncGenerator[str] (single yield) for consistency with FastAPI
    StreamingResponse cancellation handling.
    """
    text_parts: list[str] = []
    thinking_parts: list[str] = []
    tool_calls: list[OllamaToolCall] = []
    model: str | None = None
    finish_reason: str | None = None
    prompt_eval_count: int | None = None
    eval_count: int | None = None

    async for chunk in chunk_stream:
        match chunk:
            case PrefillProgressChunk():
                continue

            case ErrorChunk():
                raise ValueError(chunk.error_message or "Internal server error")

            case TokenChunk():
                if model is None:
                    model = str(chunk.model)
                if chunk.is_thinking:
                    thinking_parts.append(chunk.text)
                else:
                    text_parts.append(chunk.text)
                if chunk.finish_reason is not None:
                    finish_reason = chunk.finish_reason
                    prompt_eval_count, eval_count = _get_usage(chunk)

            case ToolCallChunk():
                if model is None:
                    model = str(chunk.model)
                tool_calls.extend(_build_tool_calls(chunk))
                finish_reason = chunk.finish_reason
                prompt_eval_count, eval_count = _get_usage(chunk)

    combined_text = "".join(text_parts)
    combined_thinking = "".join(thinking_parts) if thinking_parts else None
    assert model is not None

    yield OllamaChatResponse(
        model=model,
        message=OllamaMessage(
            role="assistant",
            content=combined_text,
            thinking=combined_thinking,
            tool_calls=tool_calls if tool_calls else None,
        ),
        done=True,
        done_reason=_map_done_reason(finish_reason),
        prompt_eval_count=prompt_eval_count,
        eval_count=eval_count,
    ).model_dump_json(exclude_none=True)
    return


# ── /api/generate ──


def ollama_generate_request_to_text_generation(
    request: OllamaGenerateRequest,
) -> TextGenerationTaskParams:
    """Convert Ollama generate request to exo's internal text generation format."""
    chat_template_messages: list[dict[str, Any]] = []
    if request.system:
        chat_template_messages.append({"role": "system", "content": request.system})
    chat_template_messages.append({"role": "user", "content": request.prompt})

    options = request.options
    return TextGenerationTaskParams(
        model=request.model,
        input=[InputMessage(role="user", content=request.prompt)],
        instructions=request.system,
        max_output_tokens=options.num_predict if options else None,
        temperature=options.temperature if options else None,
        top_p=options.top_p if options else None,
        top_k=options.top_k if options else None,
        stop=options.stop if options else None,
        seed=options.seed if options else None,
        stream=request.stream,
        enable_thinking=request.think,
        chat_template_messages=chat_template_messages
        if chat_template_messages
        else None,
    )


async def generate_ollama_generate_stream(
    _command_id: CommandId,
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str, None]:
    """Generate streaming responses for /api/generate in Ollama NDJSON format."""
    thinking_parts: list[str] = []

    async for chunk in chunk_stream:
        match chunk:
            case PrefillProgressChunk():
                continue

            case ErrorChunk():
                resp = OllamaGenerateResponse(
                    model=str(chunk.model),
                    response="",
                    done=True,
                    done_reason="error",
                )
                yield f"{resp.model_dump_json(exclude_none=True)}\n"
                return

            case ToolCallChunk():
                # generate endpoint doesn't support tools; emit as done
                prompt_eval, eval_count = _get_usage(chunk)
                resp = OllamaGenerateResponse(
                    model=str(chunk.model),
                    response="",
                    done=True,
                    done_reason="stop",
                    prompt_eval_count=prompt_eval,
                    eval_count=eval_count,
                )
                yield f"{resp.model_dump_json(exclude_none=True)}\n"
                return

            case TokenChunk():
                done = chunk.finish_reason is not None

                if chunk.is_thinking:
                    thinking_parts.append(chunk.text)
                    resp = OllamaGenerateResponse(
                        model=str(chunk.model),
                        response="",
                        thinking=chunk.text,
                        done=False,
                    )
                    yield f"{resp.model_dump_json(exclude_none=True)}\n"
                elif done:
                    prompt_eval, eval_count = _get_usage(chunk)
                    resp = OllamaGenerateResponse(
                        model=str(chunk.model),
                        response=chunk.text,
                        done=True,
                        done_reason=_map_done_reason(chunk.finish_reason),
                        prompt_eval_count=prompt_eval,
                        eval_count=eval_count,
                    )
                    yield f"{resp.model_dump_json(exclude_none=True)}\n"
                else:
                    resp = OllamaGenerateResponse(
                        model=str(chunk.model),
                        response=chunk.text,
                        done=False,
                    )
                    yield f"{resp.model_dump_json(exclude_none=True)}\n"

                if done:
                    return


async def collect_ollama_generate_response(
    _command_id: CommandId,
    chunk_stream: AsyncGenerator[
        ErrorChunk | ToolCallChunk | TokenChunk | PrefillProgressChunk, None
    ],
) -> AsyncGenerator[str]:
    """Collect chunks into a single non-streaming /api/generate response."""
    text_parts: list[str] = []
    thinking_parts: list[str] = []
    model: str | None = None
    finish_reason: str | None = None
    prompt_eval_count: int | None = None
    eval_count: int | None = None

    async for chunk in chunk_stream:
        match chunk:
            case PrefillProgressChunk():
                continue
            case ErrorChunk():
                raise ValueError(chunk.error_message or "Internal server error")
            case TokenChunk():
                if model is None:
                    model = str(chunk.model)
                if chunk.is_thinking:
                    thinking_parts.append(chunk.text)
                else:
                    text_parts.append(chunk.text)
                if chunk.finish_reason is not None:
                    finish_reason = chunk.finish_reason
                    prompt_eval_count, eval_count = _get_usage(chunk)
            case ToolCallChunk():
                if model is None:
                    model = str(chunk.model)
                finish_reason = chunk.finish_reason
                prompt_eval_count, eval_count = _get_usage(chunk)

    assert model is not None
    yield OllamaGenerateResponse(
        model=model,
        response="".join(text_parts),
        thinking="".join(thinking_parts) if thinking_parts else None,
        done=True,
        done_reason=_map_done_reason(finish_reason),
        prompt_eval_count=prompt_eval_count,
        eval_count=eval_count,
    ).model_dump_json(exclude_none=True)
    return
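
A quick illustration of the adapter above, as a sketch that is not part of the commit: it feeds an Ollama-style chat request through `ollama_request_to_text_generation` and checks how the fields land on exo's `TextGenerationTaskParams`. The model id is a placeholder, not a real exo model card.

```python
# Hypothetical usage sketch (not in this commit); the model id is a placeholder.
from exo.master.adapters.ollama import ollama_request_to_text_generation
from exo.shared.types.ollama_api import OllamaChatRequest

request = OllamaChatRequest.model_validate(
    {
        "model": "some-model-id",  # placeholder model id
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "options": {"temperature": 0.2, "num_predict": 128},
        "stream": True,
    }
)

params = ollama_request_to_text_generation(request)
# Per the adapter above: the system message becomes `instructions`, the user
# message becomes the single `input` entry, and options map onto sampling params.
assert params.instructions == "You are a helpful assistant."
assert params.max_output_tokens == 128
assert params.temperature == 0.2
```
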
@@ -32,6 +32,14 @@ from exo.master.adapters.claude import (
    collect_claude_response,
    generate_claude_stream,
)
from exo.master.adapters.ollama import (
    collect_ollama_chat_response,
    collect_ollama_generate_response,
    generate_ollama_chat_stream,
    generate_ollama_generate_stream,
    ollama_generate_request_to_text_generation,
    ollama_request_to_text_generation,
)
from exo.master.adapters.responses import (
    collect_responses_response,
    generate_responses_stream,

@@ -141,6 +149,19 @@ from exo.shared.types.events import (
    TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.ollama_api import (
    OllamaChatRequest,
    OllamaChatResponse,
    OllamaGenerateRequest,
    OllamaGenerateResponse,
    OllamaModelDetails,
    OllamaModelTag,
    OllamaPsModel,
    OllamaPsResponse,
    OllamaShowRequest,
    OllamaShowResponse,
    OllamaTagsResponse,
)
from exo.shared.types.openai_responses import (
    ResponsesRequest,
    ResponsesResponse,
@@ -300,6 +321,20 @@ class API:
        self.app.get("/images/{image_id}")(self.get_image)
        self.app.post("/v1/messages", response_model=None)(self.claude_messages)
        self.app.post("/v1/responses", response_model=None)(self.openai_responses)
        # Ollama API — health checks (must be before static files mount)
        self.app.head("/")(self._ollama_root)
        self.app.head("/api/version")(self.ollama_version)
        # Ollama API
        self.app.post("/api/chat", response_model=None)(self.ollama_chat)
        self.app.post("/api/api/chat", response_model=None)(self.ollama_chat)
        self.app.post("/api/v1/chat", response_model=None)(self.ollama_chat)
        self.app.post("/api/generate", response_model=None)(self.ollama_generate)
        self.app.get("/api/tags")(self.ollama_tags)
        self.app.get("/api/api/tags")(self.ollama_tags)
        self.app.get("/api/v1/tags")(self.ollama_tags)
        self.app.post("/api/show")(self.ollama_show)
        self.app.get("/api/ps")(self.ollama_ps)
        self.app.get("/api/version")(self.ollama_version)
        self.app.get("/state")(lambda: self.state)
        self.app.get("/events")(self.stream_events)
        self.app.post("/download/start")(self.start_download)
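
The routes above line up with the standard Ollama client endpoints (`/api/chat`, `/api/generate`, `/api/tags`, `/api/show`, `/api/ps`, `/api/version`); the extra `/api/api/...` and `/api/v1/...` registrations presumably accommodate clients that join their own path prefix onto the base URL. As a rough client-side sketch, assuming an exo master is reachable at a placeholder address, a streaming chat call reads one JSON object per line from the NDJSON body:

```python
# Hypothetical client sketch -- the address and model id are placeholders,
# not values from this commit.
import json
import urllib.request

payload = {
    "model": "some-model-id",  # placeholder
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": True,
}
req = urllib.request.Request(
    "http://localhost:8000/api/chat",  # placeholder address of the exo master
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    for raw_line in resp:  # one Ollama-style JSON object per line
        chunk = json.loads(raw_line)
        print(chunk["message"]["content"], end="", flush=True)
        if chunk.get("done"):
            break
```
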
@@ -1293,6 +1328,158 @@ class API:
            media_type="application/json",
        )

    async def _ollama_root(self) -> JSONResponse:
        """Respond to HEAD / from Ollama CLI connectivity checks."""
        return JSONResponse(content="Ollama is running")

    async def ollama_chat(
        self, request: Request
    ) -> OllamaChatResponse | StreamingResponse:
        """Ollama Chat API — accepts JSON regardless of Content-Type."""
        body = await request.body()
        payload = OllamaChatRequest.model_validate_json(body)
        task_params = ollama_request_to_text_generation(payload)
        resolved_model = await self._resolve_and_validate_text_model(
            ModelId(task_params.model)
        )
        task_params = task_params.model_copy(update={"model": resolved_model})

        command = TextGeneration(task_params=task_params)
        await self._send(command)

        if payload.stream:
            return StreamingResponse(
                generate_ollama_chat_stream(
                    command.command_id,
                    self._token_chunk_stream(command.command_id),
                ),
                media_type="application/x-ndjson",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "close",
                    "X-Accel-Buffering": "no",
                },
            )
        else:
            return StreamingResponse(
                collect_ollama_chat_response(
                    command.command_id,
                    self._token_chunk_stream(command.command_id),
                ),
                media_type="application/json",
            )

    async def ollama_generate(
        self, request: Request
    ) -> OllamaGenerateResponse | StreamingResponse:
        """Ollama Generate API — accepts JSON regardless of Content-Type."""
        body = await request.body()
        payload = OllamaGenerateRequest.model_validate_json(body)
        task_params = ollama_generate_request_to_text_generation(payload)
        resolved_model = await self._resolve_and_validate_text_model(
            ModelId(task_params.model)
        )
        task_params = task_params.model_copy(update={"model": resolved_model})

        command = TextGeneration(task_params=task_params)
        await self._send(command)

        if payload.stream:
            return StreamingResponse(
                generate_ollama_generate_stream(
                    command.command_id,
                    self._token_chunk_stream(command.command_id),
                ),
                media_type="application/x-ndjson",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "close",
                    "X-Accel-Buffering": "no",
                },
            )
        else:
            return StreamingResponse(
                collect_ollama_generate_response(
                    command.command_id,
                    self._token_chunk_stream(command.command_id),
                ),
                media_type="application/json",
            )

    async def ollama_tags(self) -> OllamaTagsResponse:
        """Returns list of models in Ollama tags format. We return the downloaded ones only."""

        def none_if_empty(value: str) -> str | None:
            return value or None

        downloaded_model_ids: set[str] = set()
        for node_downloads in self.state.downloads.values():
            for dl in node_downloads:
                if isinstance(dl, DownloadCompleted):
                    downloaded_model_ids.add(dl.shard_metadata.model_card.model_id)

        cards = [
            c for c in await get_model_cards() if c.model_id in downloaded_model_ids
        ]

        now = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        return OllamaTagsResponse(
            models=[
                OllamaModelTag(
                    name=str(card.model_id),
                    model=str(card.model_id),
                    modified_at=now,
                    size=card.storage_size.in_bytes,
                    digest="sha256:000000000000",
                    details=OllamaModelDetails(
                        family=none_if_empty(card.family),
                        quantization_level=none_if_empty(card.quantization),
                    ),
                )
                for card in cards
            ]
        )

    async def ollama_show(self, request: Request) -> OllamaShowResponse:
        """Returns model information in Ollama show format."""
        body = await request.body()
        payload = OllamaShowRequest.model_validate_json(body)
        try:
            card = await ModelCard.load(ModelId(payload.name))
        except Exception as exc:
            raise HTTPException(
                status_code=404, detail=f"Model not found: {payload.name}"
            ) from exc

        return OllamaShowResponse(
            details=OllamaModelDetails(
                family=card.family or None,
                quantization_level=card.quantization or None,
            ),
        )

    async def ollama_ps(self) -> OllamaPsResponse:
        """Returns list of running models (active instances)."""
        models: list[OllamaPsModel] = []
        seen: set[str] = set()
        for instance in self.state.instances.values():
            model_id = str(instance.shard_assignments.model_id)
            if model_id in seen:
                continue
            seen.add(model_id)
            models.append(
                OllamaPsModel(
                    name=model_id,
                    model=model_id,
                    size=0,
                )
            )
        return OllamaPsResponse(models=models)

    async def ollama_version(self) -> dict[str, str]:
        """Returns version information for Ollama API compatibility."""
        return {"version": "exo v1.0"}

    def _calculate_total_available_memory(self) -> Memory:
        """Calculate total available memory across all nodes in bytes."""
        total_available = Memory()
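
And a corresponding non-streaming sketch against `/api/generate` (again with placeholder host, port, and model id): with `"stream": false` the handler above returns the single JSON object assembled by `collect_ollama_generate_response`.

```python
# Hypothetical client sketch -- placeholder address and model id.
import json
import urllib.request

payload = {
    "model": "some-model-id",  # placeholder
    "prompt": "Write a haiku about distributed inference.",
    "stream": False,
}
req = urllib.request.Request(
    "http://localhost:8000/api/generate",  # placeholder address of the exo master
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())

print(body["response"])        # the collected generation
print(body.get("eval_count"))  # present when the runner reports usage
```
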
src/exo/shared/types/ollama_api.py (new file, 147 lines)

@@ -0,0 +1,147 @@
from __future__ import annotations

import time
from typing import Any, Literal

from pydantic import BaseModel, Field

from exo.shared.models.model_cards import ModelId

# https://github.com/ollama/ollama/blob/main/docs/api.md

OllamaRole = Literal["system", "user", "assistant", "tool"]
OllamaDoneReason = Literal["stop", "length", "tool_call", "error"]


class OllamaToolFunction(BaseModel, frozen=True):
    name: str
    arguments: dict[str, Any] | str
    index: int | None = None


class OllamaToolCall(BaseModel, frozen=True):
    id: str | None = None
    type: Literal["function"] | None = None
    function: OllamaToolFunction


class OllamaMessage(BaseModel, frozen=True):
    role: OllamaRole
    content: str | None = None
    thinking: str | None = None
    tool_calls: list[OllamaToolCall] | None = None
    name: str | None = None
    tool_name: str | None = None
    images: list[str] | None = None


class OllamaOptions(BaseModel, frozen=True):
    num_predict: int | None = None
    temperature: float | None = None
    top_p: float | None = None
    top_k: int | None = None
    stop: str | list[str] | None = None
    seed: int | None = None


class OllamaChatRequest(BaseModel, frozen=True):
    model: ModelId
    messages: list[OllamaMessage]
    stream: bool = True
    options: OllamaOptions | None = None
    tools: list[dict[str, Any]] | None = None
    format: Literal["json"] | dict[str, Any] | None = None
    keep_alive: str | int | None = None
    think: bool | None = None


class OllamaGenerateRequest(BaseModel, frozen=True):
    model: ModelId
    prompt: str = ""
    system: str | None = None
    stream: bool = True
    options: OllamaOptions | None = None
    format: Literal["json"] | dict[str, Any] | None = None
    keep_alive: str | int | None = None
    think: bool | None = None
    raw: bool = False


class OllamaGenerateResponse(BaseModel, frozen=True, strict=True):
    model: str
    created_at: str = Field(
        default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    )
    response: str
    thinking: str | None = None
    done: bool
    done_reason: OllamaDoneReason | None = None
    total_duration: int | None = None
    load_duration: int | None = None
    prompt_eval_count: int | None = None
    prompt_eval_duration: int | None = None
    eval_count: int | None = None
    eval_duration: int | None = None


class OllamaShowRequest(BaseModel, frozen=True):
    name: str
    verbose: bool | None = None


class OllamaChatResponse(BaseModel, frozen=True, strict=True):
    model: str
    created_at: str = Field(
        default_factory=lambda: time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    )
    message: OllamaMessage
    done: bool
    done_reason: OllamaDoneReason | None = None
    total_duration: int | None = None
    load_duration: int | None = None
    prompt_eval_count: int | None = None
    prompt_eval_duration: int | None = None
    eval_count: int | None = None
    eval_duration: int | None = None


class OllamaModelDetails(BaseModel, frozen=True, strict=True):
    format: str | None = None
    family: str | None = None
    parameter_size: str | None = None
    quantization_level: str | None = None


class OllamaModelTag(BaseModel, frozen=True, strict=True):
    name: str
    model: str | None = None
    modified_at: str | None = None
    size: int | None = None
    digest: str | None = None
    details: OllamaModelDetails | None = None


class OllamaTagsResponse(BaseModel, frozen=True, strict=True):
    models: list[OllamaModelTag]


class OllamaShowResponse(BaseModel, frozen=True, strict=True):
    modelfile: str | None = None
    parameters: str | None = None
    template: str | None = None
    details: OllamaModelDetails | None = None
    model_info: dict[str, Any] | None = None


class OllamaPsModel(BaseModel, frozen=True, strict=True):
    name: str
    model: str
    size: int
    digest: str | None = None
    details: OllamaModelDetails | None = None
    expires_at: str | None = None
    size_vram: int | None = None


class OllamaPsResponse(BaseModel, frozen=True, strict=True):
    models: list[OllamaPsModel]
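
To make the wire format concrete, a small sketch (not part of the commit) that builds one intermediate chat chunk from the models above and serializes it the way the adapters do, with `exclude_none=True` so unset optional fields are dropped:

```python
# Illustrative sketch of one NDJSON line as emitted by the chat stream adapter;
# the model id is a placeholder.
from exo.shared.types.ollama_api import OllamaChatResponse, OllamaMessage

line = OllamaChatResponse(
    model="some-model-id",  # placeholder
    message=OllamaMessage(role="assistant", content="Hel"),
    done=False,
).model_dump_json(exclude_none=True) + "\n"

# Roughly:
# {"model":"some-model-id","created_at":"...","message":{"role":"assistant","content":"Hel"},"done":false}
print(line)
```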