Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-29 00:01:15 -05:00)

Compare commits: aiohttp...sami/fix-u (1 commit, 87fb5febe1)
```diff
@@ -5,7 +5,7 @@
   <img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
 </picture>
 
-exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
+exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
 
 <p align="center">
   <a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
```
```diff
@@ -225,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
         }
     }
 
-    private func showNotification(title: String, body: String) {
+    nonisolated private func showNotification(title: String, body: String) {
        let center = UNUserNotificationCenter.current()
        let content = UNMutableNotificationContent()
        content.title = title
```
```diff
@@ -241,11 +241,11 @@ enum NetworkSetupHelper {
     rm -f "$LOG_OUT" "$LOG_ERR"
 
     # Switch back to Automatic network location
-    networksetup -switchtolocation Automatic 2>/dev/null || true
+    networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
 
     # Delete the exo network location if it exists
-    networksetup -listlocations | grep -q '^exo$' && {
-        networksetup -deletelocation exo 2>/dev/null || true
+    networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
+        networksetup -deletelocation exo >/dev/null 2>&1 || true
     } || true
 
     # Re-enable any Thunderbolt Bridge service if it exists
```
```diff
@@ -255,12 +255,12 @@ enum NetworkSetupHelper {
     tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
         /^Hardware Port:/ { port = tolower(substr($0, 16)) }
         /^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
-    ')
+    ') || true
     [ -z "$tb_devices" ] && return 0
 
     # For each bridge device, check if it contains Thunderbolt interfaces
     for bridge in bridge0 bridge1 bridge2; do
-        members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
+        members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
         [ -z "$members" ] && continue
 
         for tb_dev in $tb_devices; do
```
```diff
@@ -269,7 +269,7 @@ enum NetworkSetupHelper {
     service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
         /^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
         /Device:/ && $0 ~ dev { print svc; exit }
-    ')
+    ') || true
     if [ -n "$service_name" ]; then
         networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
         return 0
```
```diff
@@ -277,8 +277,9 @@ enum NetworkSetupHelper {
             fi
         done
     done
+    return 0
 }
-find_and_enable_thunderbolt_bridge
+find_and_enable_thunderbolt_bridge || true
 
 echo "EXO network components removed successfully"
 """
```
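Note: the pattern these shell hunks standardize, `cmd >/dev/null 2>&1 || true`, silences both output streams and swallows the exit status so the cleanup script can never abort partway. For illustration only (not from the commit), a rough Python analogue of the same idea:

```python
import subprocess

# Illustration only: silence stdout and stderr and ignore the exit code,
# the Python counterpart of `cmd >/dev/null 2>&1 || true`.
def run_quietly(*cmd: str) -> None:
    subprocess.run(
        cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        check=False,  # like `|| true`: never raise on a non-zero exit
    )

run_quietly("networksetup", "-switchtolocation", "Automatic")
```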
```diff
@@ -127,21 +127,24 @@ final class ThunderboltBridgeService: ObservableObject {
 
         // 2. Request specific network configuration rights
         let rightName = "system.services.systemconfiguration.network"
-        var item = AuthorizationItem(
-            name: rightName,
-            valueLength: 0,
-            value: nil,
-            flags: 0
-        )
-        var rights = AuthorizationRights(count: 1, items: &item)
-
-        status = AuthorizationCopyRights(
-            authRef,
-            &rights,
-            nil,
-            [.extendRights, .interactionAllowed],
-            nil
-        )
+        status = rightName.withCString { nameCString in
+            var item = AuthorizationItem(
+                name: nameCString,
+                valueLength: 0,
+                value: nil,
+                flags: 0
+            )
+            return withUnsafeMutablePointer(to: &item) { itemPointer in
+                var rights = AuthorizationRights(count: 1, items: itemPointer)
+                return AuthorizationCopyRights(
+                    authRef,
+                    &rights,
+                    nil,
+                    [.extendRights, .interactionAllowed],
+                    nil
+                )
+            }
+        }
         guard status == errAuthorizationSuccess else {
             if status == errAuthorizationCanceled {
                 throw ThunderboltBridgeError.authorizationCanceled
```
```diff
@@ -6,6 +6,8 @@ readme = "README.md"
 requires-python = ">=3.13"
 dependencies = [
     "aiofiles>=24.1.0",
+    "aiohttp>=3.12.14",
+    "types-aiofiles>=24.1.0.20250708",
     "pydantic>=2.11.7",
     "fastapi>=0.116.1",
     "filelock>=3.18.0",
```
```diff
@@ -15,9 +17,9 @@ dependencies = [
     "loguru>=0.7.3",
     "exo_pyo3_bindings", # rust bindings
     "anyio==4.11.0",
-    "mlx==0.30.4; sys_platform == 'darwin'",
-    "mlx[cpu]==0.30.4; sys_platform == 'linux'",
-    "mlx-lm",
+    "mlx==0.30.3; sys_platform == 'darwin'",
+    "mlx[cpu]==0.30.3; sys_platform == 'linux'",
+    "mlx-lm==0.30.5",
     "tiktoken>=0.12.0", # required for kimi k2 tokenizer
     "hypercorn>=0.18.0",
     "openai-harmony>=0.0.8",
```
```diff
@@ -61,7 +63,6 @@ members = [
 
 [tool.uv.sources]
 exo_pyo3_bindings = { workspace = true }
-mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
 # Uncomment to use local mlx/mlx-lm development versions:
 # mlx = { path = "/Users/Shared/mlx", editable=true }
 # mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
```
```diff
@@ -121,7 +121,6 @@ class DownloadCoordinator:
     def _start_download_task(
         self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
     ) -> None:
-        logger.warning("starting download for {shard}")
         model_id = shard.model_card.model_id
 
         # Emit ongoing status
```
```diff
@@ -8,13 +8,13 @@ import traceback
 from collections.abc import Awaitable
 from datetime import timedelta
 from pathlib import Path
-from typing import Callable, Literal, cast
+from typing import Callable, Literal
 from urllib.parse import urljoin
 
 import aiofiles
 import aiofiles.os as aios
+import aiohttp
 import certifi
-import httpx
 from huggingface_hub import (
     snapshot_download,  # pyright: ignore[reportUnknownVariableType]
 )
```
```diff
@@ -176,7 +176,7 @@ async def fetch_file_list_with_cache(
     # Fetch failed - try cache fallback
     if await aios.path.exists(cache_file):
         logger.warning(
-            f"{type(e).__name__}: Failed to fetch file list for {model_id}, using cached data"
+            f"Failed to fetch file list for {model_id}, using cached data: {e}"
         )
         async with aiofiles.open(cache_file, "r") as f:
             return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
```
```diff
@@ -196,7 +196,7 @@ async def fetch_file_list_with_retry(
         except Exception as e:
             if attempt == n_attempts - 1:
                 raise e
-            await asyncio.sleep(min(16, 0.5 * float(2.0 ** int(attempt))))
+            await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
     raise Exception(
         f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
     )
```
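This hunk, and the matching one in `download_file_with_retry` further down, tightens the retry schedule: base delay 0.5s down to 0.1s, cap 16s down to 8s. A standalone sketch of the new schedule (not repo code):

```python
import asyncio

# Sketch of the new retry delay: 0.1 * 2**attempt seconds, capped at 8s,
# versus the previous 0.5 * 2**attempt capped at 16s.
async def sleep_backoff(attempt: int, base: float = 0.1, cap: float = 8.0) -> None:
    await asyncio.sleep(min(cap, base * (2.0 ** attempt)))

# attempts 0,1,2,... now wait 0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 8.0, 8.0, ... seconds
```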
```diff
@@ -211,25 +211,26 @@ async def _fetch_file_list(
     headers = await get_download_headers()
     async with (
         create_http_session(timeout_profile="short") as session,
+        session.get(url, headers=headers) as response,
     ):
-        response = await session.get(url, headers=headers)
-        if response.status_code in [401, 403]:
-            msg = await _build_auth_error_message(response.status_code, model_id)
+        if response.status in [401, 403]:
+            msg = await _build_auth_error_message(response.status, model_id)
             raise HuggingFaceAuthenticationError(msg)
-        if response.status_code != 200:
-            raise Exception(f"Failed to fetch file list: {response.status_code}")
-
-        data = TypeAdapter(list[FileListEntry]).validate_json(response.text)
-        files: list[FileListEntry] = []
-        for item in data:
-            if item.type == "file":
-                files.append(FileListEntry.model_validate(item))
-            elif item.type == "directory" and recursive:
-                subfiles = await _fetch_file_list(
-                    model_id, revision, item.path, recursive
-                )
-                files.extend(subfiles)
-        return files
+        if response.status == 200:
+            data_json = await response.text()
+            data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
+            files: list[FileListEntry] = []
+            for item in data:
+                if item.type == "file":
+                    files.append(FileListEntry.model_validate(item))
+                elif item.type == "directory" and recursive:
+                    subfiles = await _fetch_file_list(
+                        model_id, revision, item.path, recursive
+                    )
+                    files.extend(subfiles)
+            return files
+        else:
+            raise Exception(f"Failed to fetch file list: {response.status}")
 
 
 async def get_download_headers() -> dict[str, str]:
```
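The httpx-to-aiohttp migration changes three things in this function: the response is entered as an async context manager alongside the session, the status lives on `response.status` rather than `response.status_code`, and the body accessor `text()` is a coroutine rather than a property. A minimal, self-contained sketch of the new pattern (placeholder URL, simplified error types):

```python
import aiohttp

# Minimal sketch of the aiohttp request pattern the new code adopts.
async def fetch_text(url: str) -> str:
    async with aiohttp.ClientSession() as session, session.get(url) as response:
        if response.status in (401, 403):
            raise PermissionError(f"auth failed with status {response.status}")
        if response.status != 200:
            raise RuntimeError(f"unexpected status {response.status}")
        # aiohttp's text() is a coroutine, unlike httpx's .text property
        return await response.text()
```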
```diff
@@ -237,29 +238,34 @@ async def get_download_headers() -> dict[str, str]:
 
 
 def create_http_session(
+    auto_decompress: bool = False,
     timeout_profile: Literal["short", "long"] = "long",
-) -> httpx.AsyncClient:
+) -> aiohttp.ClientSession:
     if timeout_profile == "short":
         total_timeout = 30
         connect_timeout = 10
-        read_timeout = 30
+        sock_read_timeout = 30
+        sock_connect_timeout = 10
     else:
         total_timeout = 1800
         connect_timeout = 60
-        read_timeout = 1800
+        sock_read_timeout = 1800
+        sock_connect_timeout = 60
 
     ssl_context = ssl.create_default_context(
         cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
     )
+    connector = aiohttp.TCPConnector(ssl=ssl_context)
 
     # default here is to load env vars
-    return httpx.AsyncClient(
-        verify=ssl_context,
-        timeout=httpx.Timeout(
+    return aiohttp.ClientSession(
+        auto_decompress=auto_decompress,
+        connector=connector,
+        proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
+        timeout=aiohttp.ClientTimeout(
             total=total_timeout,
             connect=connect_timeout,
-            read=read_timeout,
-            write=total_timeout,
-            pool=total_timeout,
+            sock_read=sock_read_timeout,
+            sock_connect=sock_connect_timeout,
         ),
     )
```
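aiohttp splits timeout budgets differently from httpx: `total` bounds the whole request, `connect` covers pool acquisition plus connecting, and `sock_read`/`sock_connect` bound individual socket operations; httpx's `read`/`write`/`pool` knobs have no direct equivalents. A condensed sketch of the factory above with the "short" profile inlined (function name assumed):

```python
import ssl

import aiohttp
import certifi

# Condensed sketch of the session factory shape in the new hunk.
def make_short_session() -> aiohttp.ClientSession:
    ssl_context = ssl.create_default_context(cafile=certifi.where())
    return aiohttp.ClientSession(
        connector=aiohttp.TCPConnector(ssl=ssl_context),
        timeout=aiohttp.ClientTimeout(
            total=30,          # whole-request budget
            connect=10,        # pool acquire + connect
            sock_read=30,      # max gap between reads
            sock_connect=10,   # TCP connect only
        ),
    )
```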
```diff
@@ -286,28 +292,26 @@ async def file_meta(
     headers = await get_download_headers()
     async with (
         create_http_session(timeout_profile="short") as session,
-        session.stream("HEAD", url, headers=headers) as r,
+        session.head(url, headers=headers) as r,
     ):
-        if r.status_code == 307:
+        if r.status == 307:
             # On redirect, only trust Hugging Face's x-linked-* headers.
-            x_linked_size = cast(str | None, r.headers.get("x-linked-size"))
-            x_linked_etag = cast(str | None, r.headers.get("x-linked-etag"))
+            x_linked_size = r.headers.get("x-linked-size")
+            x_linked_etag = r.headers.get("x-linked-etag")
             if x_linked_size and x_linked_etag:
                 content_length = int(x_linked_size)
                 etag = trim_etag(x_linked_etag)
                 return content_length, etag
             # Otherwise, follow the redirect to get authoritative size/hash
-            redirected_location = cast(str | None, r.headers.get("location"))
+            redirected_location = r.headers.get("location")
             return await file_meta(model_id, revision, path, redirected_location)
-        if r.status_code in [401, 403]:
-            msg = await _build_auth_error_message(r.status_code, model_id)
+        if r.status in [401, 403]:
+            msg = await _build_auth_error_message(r.status, model_id)
             raise HuggingFaceAuthenticationError(msg)
-        content_length = cast(
-            str | None,
-            r.headers.get("x-linked-size") or r.headers.get("content-length"),
+        content_length = int(
+            r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
         )
-        content_length = 0 if content_length is None else int(content_length)
-        etag = cast(str | None, r.headers.get("x-linked-etag") or r.headers.get("etag"))
+        etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
         assert content_length > 0, f"No content length for {url}"
         assert etag is not None, f"No remote hash for {url}"
         etag = trim_etag(etag)
```
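With the `cast` calls gone, aiohttp's `headers.get` already returns `str | None`, so size and etag fall out of plain `or`-chains. Note also that aiohttp's `head()` does not follow redirects by default, which is why the 307 branch above sees the redirect itself. An illustrative (non-repo) sketch of the header-precedence rule `file_meta` applies:

```python
import aiohttp

# Illustrative only: x-linked-size / x-linked-etag are Hugging Face CDN
# headers; the standard Content-Length / ETag are the fallback.
async def head_meta(session: aiohttp.ClientSession, url: str) -> tuple[int, str]:
    async with session.head(url) as r:
        size = int(r.headers.get("x-linked-size") or r.headers.get("content-length") or 0)
        etag = r.headers.get("x-linked-etag") or r.headers.get("etag")
        assert size > 0 and etag is not None, f"no usable metadata for {url}"
        return size, etag.strip('"')  # assumed: trim_etag strips surrounding quotes
```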
```diff
@@ -336,7 +340,7 @@ async def download_file_with_retry(
                 f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
             )
             logger.error(traceback.format_exc())
-            await asyncio.sleep(min(16, 0.5 * (2.0**attempt)))
+            await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
     raise Exception(
         f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
     )
```
```diff
@@ -349,7 +353,6 @@ async def _download_file(
     target_dir: Path,
     on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
 ) -> Path:
-    logger.warning(f"downloading {path} from {model_id} to {target_dir}")
     target_path = target_dir / path
 
     if await aios.path.exists(target_path):
```
```diff
@@ -389,20 +392,20 @@ async def _download_file(
     n_read = resume_byte_pos or 0
     async with (
         create_http_session(timeout_profile="long") as session,
-        session.stream("GET", url, headers=headers, follow_redirects=True) as r,
+        session.get(url, headers=headers) as r,
     ):
-        if r.status_code == 404:
+        if r.status == 404:
             raise FileNotFoundError(f"File not found: {url}")
-        if r.status_code in [401, 403]:
-            msg = await _build_auth_error_message(r.status_code, model_id)
+        if r.status in [401, 403]:
+            msg = await _build_auth_error_message(r.status, model_id)
             raise HuggingFaceAuthenticationError(msg)
-        assert r.status_code in [200, 206], (
-            f"Failed to download {path} from {url}: {r.status_code}"
+        assert r.status in [200, 206], (
+            f"Failed to download {path} from {url}: {r.status}"
         )
         async with aiofiles.open(
             partial_path, "ab" if resume_byte_pos else "wb"
         ) as f:
-            async for chunk in r.aiter_bytes(8 * 1024 * 1024):
+            while chunk := await r.content.read(8 * 1024 * 1024):
                 n_read = n_read + (await f.write(chunk))
                 on_progress(n_read, length, False)
```
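On the streaming side, httpx's `aiter_bytes()` is replaced by fixed-size reads from aiohttp's `StreamReader` in a walrus loop; `read()` returns `b""` at end of stream, which terminates the loop. An isolated sketch of the pattern, with resume handling simplified (the Range header below is an assumption, not shown in the hunk):

```python
import aiofiles
import aiohttp

# Sketch of the new download loop: 8 MiB chunks appended to a partial file.
async def stream_to_file(
    session: aiohttp.ClientSession, url: str, partial_path: str, resume_from: int = 0
) -> int:
    headers = {"Range": f"bytes={resume_from}-"} if resume_from else {}
    n_read = resume_from
    async with session.get(url, headers=headers) as r:
        assert r.status in (200, 206), f"unexpected status {r.status}"
        async with aiofiles.open(partial_path, "ab" if resume_from else "wb") as f:
            # r.content is a StreamReader; read() returns b"" at EOF
            while chunk := await r.content.read(8 * 1024 * 1024):
                n_read += await f.write(chunk)
    return n_read
```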
```diff
@@ -168,8 +168,7 @@ class ResumableShardDownloader(ShardDownloader):
                 yield await task
-        # TODO: except Exception
         except Exception as e:
             task.cancel()
-            logger.opt(exception=e).error("Error downloading shard")
+            logger.error("Error downloading shard:", e)
 
     async def get_shard_download_status_for_shard(
         self, shard: ShardMetadata
```
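Worth flagging in review: with loguru, positional arguments are interpolated into `{}` placeholders, so a message without a placeholder silently drops the extra argument, and only `opt(exception=...)` attaches a traceback to the record. A small demonstration of the difference between the calls in this hunk:

```python
from loguru import logger

try:
    raise ValueError("boom")
except Exception as e:
    logger.error("Error downloading shard:", e)               # `e` is ignored: no {} placeholder
    logger.error("Error downloading shard: {}", e)            # message includes the exception
    logger.opt(exception=e).error("Error downloading shard")  # message plus full traceback
```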
```diff
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Annotated, Any
+from typing import Annotated
 
 import aiofiles
 import aiofiles.os as aios
```
```diff
@@ -7,14 +7,7 @@ import tomlkit
 from anyio import Path, open_file
 from huggingface_hub import model_info
 from loguru import logger
-from pydantic import (
-    AliasChoices,
-    BaseModel,
-    Field,
-    PositiveInt,
-    field_validator,
-    model_validator,
-)
+from pydantic import BaseModel, Field, PositiveInt, field_validator
 
 from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
 from exo.shared.types.common import ModelId
```
```diff
@@ -128,14 +121,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
         supports_tensor=True,
         tasks=[ModelTask.TextGeneration],
     ),
-    "kimi-k2.5": ModelCard(
-        model_id=ModelId("mlx-community/Kimi-K2.5"),
-        storage_size=Memory.from_gb(617),
-        n_layers=61,
-        hidden_size=7168,
-        supports_tensor=True,
-        tasks=[ModelTask.TextGeneration],
-    ),
     # llama-3.1
     "llama-3.1-8b": ModelCard(
         model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
```
```diff
@@ -718,18 +703,15 @@ if EXO_ENABLE_IMAGE_MODELS:
 class ConfigData(BaseModel):
     model_config = {"extra": "ignore"}  # Allow unknown fields
 
-    architectures: list[str] | None = None
+    # Common field names for number of layers across different architectures
+    num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
+    num_layers: Annotated[int, Field(ge=0)] | None = None
+    n_layer: Annotated[int, Field(ge=0)] | None = None
+    n_layers: Annotated[int, Field(ge=0)] | None = None  # Sometimes used
+    num_decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Transformer models
+    decoder_layers: Annotated[int, Field(ge=0)] | None = None  # Some architectures
     hidden_size: Annotated[int, Field(ge=0)] | None = None
-    layer_count: int = Field(
-        validation_alias=AliasChoices(
-            "num_hidden_layers",
-            "num_layers",
-            "n_layer",
-            "n_layers",
-            "num_decoder_layers",
-            "decoder_layers",
-        )
-    )
+    architectures: list[str] | None = None
 
     @property
     def supports_tensor(self) -> bool:
```
```diff
@@ -744,27 +726,25 @@ class ConfigData(BaseModel):
             ["GptOssForCausalLM"],
         ]
 
-    @model_validator(mode="before")
-    @classmethod
-    def defer_to_text_config(cls, data: dict[str, Any]):
-        text_config = data.get("text_config")
-        if text_config is None:
-            return data
+    @property
+    def layer_count(self) -> int:
+        # Check common field names for layer count
+        layer_fields = [
+            self.num_hidden_layers,
+            self.num_layers,
+            self.n_layer,
+            self.n_layers,
+            self.num_decoder_layers,
+            self.decoder_layers,
+        ]
 
-        for field in [
-            "architectures",
-            "hidden_size",
-            "num_hidden_layers",
-            "num_layers",
-            "n_layer",
-            "n_layers",
-            "num_decoder_layers",
-            "decoder_layers",
-        ]:
-            if (val := text_config.get(field)) is not None:  # pyright: ignore[reportAny]
-                data[field] = val
+        for layer_count in layer_fields:
+            if layer_count is not None:
+                return layer_count
 
-        return data
+        raise ValueError(
+            f"No layer count found in config.json: {self.model_dump_json()}"
+        )
 
 
 async def get_config_data(model_id: ModelId) -> ConfigData:
```
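The two sides solve the same mapping problem differently: the removed code let pydantic pick the layer count from whichever alias is present via `validation_alias=AliasChoices(...)`, plus a `model_validator` that hoists `text_config` fields, while the kept code declares every candidate field optional and resolves them in a `layer_count` property. A minimal demonstration of the AliasChoices pattern (the model name here is illustrative, not from the repo):

```python
from pydantic import AliasChoices, BaseModel, Field

# One required field, populated from whichever alias appears in the input.
class LayerInfo(BaseModel):
    layer_count: int = Field(
        validation_alias=AliasChoices("num_hidden_layers", "num_layers", "n_layer")
    )

print(LayerInfo.model_validate({"num_hidden_layers": 32}).layer_count)  # 32
print(LayerInfo.model_validate({"n_layer": 24}).layer_count)            # 24
```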
```diff
@@ -1,8 +1,10 @@
 import time
-from collections.abc import Hashable
+from typing import Generic, TypeVar
+
+K = TypeVar("K")
 
 
-class KeyedBackoff[K: Hashable]:
+class KeyedBackoff(Generic[K]):
     """Tracks exponential backoff state per key."""
 
     def __init__(self, base: float = 0.5, cap: float = 10.0):
```
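`class KeyedBackoff[K: Hashable]` is PEP 695 syntax (Python 3.12+); the replacement spells the same generic with a module-level `TypeVar`, which older tooling handles better, and the two are equivalent at runtime. Only the class name, docstring, and `__init__` signature below come from the diff; the method bodies are a guess at what a per-key backoff tracker of this shape could look like:

```python
from typing import Generic, TypeVar

K = TypeVar("K")

class KeyedBackoff(Generic[K]):
    """Tracks exponential backoff state per key."""

    def __init__(self, base: float = 0.5, cap: float = 10.0):
        self.base = base
        self.cap = cap
        self._failures: dict[K, int] = {}

    def next_delay(self, key: K) -> float:
        """Record a failure for `key` and return the delay before the next try."""
        n = self._failures.get(key, 0)
        self._failures[key] = n + 1
        return min(self.cap, self.base * (2.0 ** n))

    def reset(self, key: K) -> None:
        self._failures.pop(key, None)
```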
```diff
@@ -23,7 +23,6 @@ from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
 from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
 from mlx_lm.models.gpt_oss import GptOssMoeModel
 from mlx_lm.models.gpt_oss import Model as GptOssModel
-from mlx_lm.models.kimi_k25 import Model as KimiK25Model
 from mlx_lm.models.llama import Model as LlamaModel
 from mlx_lm.models.minimax import Model as MiniMaxModel
 from mlx_lm.models.ministral3 import Model as Ministral3Model
```
```diff
@@ -345,7 +344,7 @@ def tensor_auto_parallel(
             all_to_sharded_linear_in_place,
             sharded_to_all_linear_in_place,
         )
-    elif isinstance(model, (DeepseekV3Model, DeepseekV32Model, KimiK25Model)):
+    elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
         tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
             group,
             all_to_sharded_linear,
```
```diff
@@ -454,7 +453,7 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
 
     # Update DeepSeek V3 specific parameters when layers are shrunk
     if isinstance(
-        model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel, KimiK25Model)
+        model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
     ) and hasattr(inner_model_instance, "num_layers"):
         logger.info(
             f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
```
```diff
@@ -165,11 +165,12 @@ def mlx_distributed_init(
 
     jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
 
+    # TODO: update once upstream fixes
     logger.info(
-        f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
+        f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
     )
     logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
-    os.environ["MLX_JACCL_DEVICES"] = coordination_file
+    os.environ["MLX_IBV_DEVICES"] = coordination_file
     os.environ["MLX_RANK"] = str(rank)
     os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
     group = mx.distributed.init(backend="jaccl", strict=True)
```
```diff
@@ -258,10 +259,10 @@ def shard_and_load(
 
     logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
 
-    # Estimate timeout based on model size (5x default for large queued workloads)
-    base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
+    # Estimate timeout based on model size
+    base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
     model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
-    timeout_seconds = base_timeout + model_size_gb
+    timeout_seconds = base_timeout + model_size_gb / 5
     logger.info(
         f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
         f"(model size: {model_size_gb:.1f}GB)"
```
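The load-timeout heuristic is cut substantially: a 60s base (down from 300s) plus one second per 5 GB of weights (down from one second per GB). As a worked example:

```python
import os

# The revised heuristic from this hunk, as a standalone function.
def load_timeout_seconds(model_size_gb: float) -> float:
    base = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
    return base + model_size_gb / 5

# e.g. a 600 GB model: 60 + 600 / 5 = 180 seconds (previously 300 + 600 = 900)
```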
```diff
@@ -338,35 +339,8 @@ def load_tokenizer_for_model_id(
 
     # Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
     if "kimi-k2" in model_id_lower:
-        import importlib.util
-        import types
-
         sys.path.insert(0, str(model_path))
-
-        # Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
-        tool_decl_path = model_path / "tool_declaration_ts.py"
-        if tool_decl_path.exists():
-            spec = importlib.util.spec_from_file_location(
-                "tool_declaration_ts", tool_decl_path
-            )
-            if spec and spec.loader:
-                tool_decl_module = importlib.util.module_from_spec(spec)
-                sys.modules["tool_declaration_ts"] = tool_decl_module
-                spec.loader.exec_module(tool_decl_module)
-
-        # Load tokenization_kimi with patched source (convert relative to absolute import)
-        tok_path = model_path / "tokenization_kimi.py"
-        source = tok_path.read_text()
-        source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
-        spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
-        if spec:
-            tok_module = types.ModuleType("tokenization_kimi")
-            tok_module.__file__ = str(tok_path)
-            sys.modules["tokenization_kimi"] = tok_module
-            exec(compile(source, tok_path, "exec"), tok_module.__dict__)  # noqa: S102
-            TikTokenTokenizer = tok_module.TikTokenTokenizer  # type: ignore[attr-defined]  # noqa: N806
-        else:
-            from tokenization_kimi import TikTokenTokenizer  # type: ignore[import-not-found]  # noqa: I001
+        from tokenization_kimi import TikTokenTokenizer  # type: ignore[import-not-found]  # noqa: I001
 
         hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
```