mirror of
https://github.com/exo-explore/exo.git
synced 2025-12-23 22:27:50 -05:00
move everything under exo module
This commit is contained in:
@@ -2,11 +2,11 @@
|
||||
# They are prompting the cluster to generate a response to a question.
|
||||
# The cluster is given the question, and the user is given the response.
|
||||
|
||||
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
|
||||
from inference.shard import Shard
|
||||
from networking.peer_handle import PeerHandle
|
||||
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
|
||||
from topology.device_capabilities import DeviceCapabilities
|
||||
from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
|
||||
from exo.inference.shard import Shard
|
||||
from exo.networking.peer_handle import PeerHandle
|
||||
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
from typing import List
|
||||
import asyncio
|
||||
import argparse
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# They are prompting the cluster to generate a response to a question.
|
||||
# The cluster is given the question, and the user is given the response.
|
||||
|
||||
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
|
||||
from inference.shard import Shard
|
||||
from networking.peer_handle import PeerHandle
|
||||
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
|
||||
from topology.device_capabilities import DeviceCapabilities
|
||||
from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
|
||||
from exo.inference.shard import Shard
|
||||
from exo.networking.peer_handle import PeerHandle
|
||||
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
from typing import List
|
||||
import asyncio
|
||||
import argparse
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import mlx.core as mx
|
||||
from inference.mlx.sharded_model import StatefulShardedModel
|
||||
from inference.mlx.sharded_utils import load_shard
|
||||
from inference.shard import Shard
|
||||
from exo.inference.mlx.sharded_model import StatefulShardedModel
|
||||
from exo.inference.mlx.sharded_utils import load_shard
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
shard_full = Shard("llama", 0, 31, 32)
|
||||
shard1 = Shard("llama", 0, 12, 32)
|
||||
@@ -1,5 +1,5 @@
|
||||
from inference.shard import Shard
|
||||
from inference.mlx.sharded_model import StatefulShardedModel
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.mlx.sharded_model import StatefulShardedModel
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from typing import Optional
|
||||
@@ -1,7 +1,7 @@
|
||||
from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from inference.inference_engine import InferenceEngine
|
||||
from inference.shard import Shard
|
||||
from inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
|
||||
import numpy as np
|
||||
|
||||
# An inference engine should work the same for any number of Shards, as long as the Shards are continuous.
|
||||
@@ -4,12 +4,12 @@ from typing import List
|
||||
import json, argparse, random, time
|
||||
import tiktoken
|
||||
from tiktoken.load import load_tiktoken_bpe
|
||||
from inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
|
||||
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
|
||||
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
|
||||
from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
|
||||
from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm
|
||||
from inference.shard import Shard
|
||||
from inference.inference_engine import InferenceEngine
|
||||
from exo.inference.shard import Shard
|
||||
from exo.inference.inference_engine import InferenceEngine
|
||||
import numpy as np
|
||||
|
||||
MODEL_PARAMS = {
|
||||
@@ -6,7 +6,7 @@ from typing import List, Dict
|
||||
from ..discovery import Discovery
|
||||
from ..peer_handle import PeerHandle
|
||||
from .grpc_peer_handle import GRPCPeerHandle
|
||||
from topology.device_capabilities import DeviceCapabilities, device_capabilities
|
||||
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities
|
||||
|
||||
class GRPCDiscovery(Discovery):
|
||||
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):
|
||||
@@ -7,9 +7,9 @@ from . import node_service_pb2
|
||||
from . import node_service_pb2_grpc
|
||||
|
||||
from ..peer_handle import PeerHandle
|
||||
from inference.shard import Shard
|
||||
from topology.topology import Topology
|
||||
from topology.device_capabilities import DeviceCapabilities
|
||||
from exo.inference.shard import Shard
|
||||
from exo.topology.topology import Topology
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
|
||||
class GRPCPeerHandle(PeerHandle):
|
||||
def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):
|
||||
@@ -4,9 +4,9 @@ import numpy as np
|
||||
|
||||
from . import node_service_pb2
|
||||
from . import node_service_pb2_grpc
|
||||
from inference.shard import Shard
|
||||
from exo.inference.shard import Shard
|
||||
|
||||
from orchestration import Node
|
||||
from exo.orchestration import Node
|
||||
|
||||
import uuid
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
from inference.shard import Shard
|
||||
from topology.device_capabilities import DeviceCapabilities
|
||||
from topology.topology import Topology
|
||||
from exo.inference.shard import Shard
|
||||
from exo.topology.device_capabilities import DeviceCapabilities
|
||||
from exo.topology.topology import Topology
|
||||
|
||||
class PeerHandle(ABC):
|
||||
@abstractmethod
|
||||
@@ -1,8 +1,8 @@
|
||||
from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
from inference.shard import Shard
|
||||
from topology.topology import Topology
|
||||
from exo.inference.shard import Shard
|
||||
from exo.topology.topology import Topology
|
||||
|
||||
class Node(ABC):
|
||||
@abstractmethod
|
||||
@@ -1,12 +1,12 @@
|
||||
from typing import List, Dict, Optional, Callable, Tuple
|
||||
import numpy as np
|
||||
from networking import Discovery, PeerHandle, Server
|
||||
from inference.inference_engine import InferenceEngine, Shard
|
||||
from exo.networking import Discovery, PeerHandle, Server
|
||||
from exo.inference.inference_engine import InferenceEngine, Shard
|
||||
from .node import Node
|
||||
from topology.topology import Topology
|
||||
from topology.device_capabilities import device_capabilities
|
||||
from topology.partitioning_strategy import PartitioningStrategy
|
||||
from topology.partitioning_strategy import Partition
|
||||
from exo.topology.topology import Topology
|
||||
from exo.topology.device_capabilities import device_capabilities
|
||||
from exo.topology.partitioning_strategy import PartitioningStrategy
|
||||
from exo.topology.partitioning_strategy import Partition
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import Mock, AsyncMock
|
||||
import numpy as np
|
||||
|
||||
from .standard_node import StandardNode
|
||||
from networking.peer_handle import PeerHandle
|
||||
from exo.networking.peer_handle import PeerHandle
|
||||
|
||||
class TestNode(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
0
exo/topology/__init__.py
Normal file
0
exo/topology/__init__.py
Normal file
@@ -1,8 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
from dataclasses import dataclass
|
||||
from inference.shard import Shard
|
||||
from networking.peer_handle import PeerHandle
|
||||
from exo.inference.shard import Shard
|
||||
from exo.networking.peer_handle import PeerHandle
|
||||
from .topology import Topology
|
||||
|
||||
# Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import List
|
||||
from .partitioning_strategy import PartitioningStrategy
|
||||
from inference.shard import Shard
|
||||
from exo.inference.shard import Shard
|
||||
from .topology import Topology
|
||||
from .partitioning_strategy import Partition
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from topology.device_capabilities import mac_device_capabilities, DeviceCapabilities
|
||||
from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities
|
||||
|
||||
class TestMacDeviceCapabilities(unittest.TestCase):
|
||||
@patch('subprocess.check_output')
|
||||
12
main.py
12
main.py
@@ -3,12 +3,12 @@ import asyncio
|
||||
import signal
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from orchestration.standard_node import StandardNode
|
||||
from networking.grpc.grpc_server import GRPCServer
|
||||
from inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
|
||||
from inference.shard import Shard
|
||||
from networking.grpc.grpc_discovery import GRPCDiscovery
|
||||
from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
|
||||
from exo.orchestration.standard_node import StandardNode
|
||||
from exo.networking.grpc.grpc_server import GRPCServer
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
|
||||
from exo.inference.shard import Shard
|
||||
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
|
||||
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
|
||||
|
||||
# parse args
|
||||
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
|
||||
|
||||
@@ -4,12 +4,12 @@ import signal
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from typing import List
|
||||
from orchestration.standard_node import StandardNode
|
||||
from networking.grpc.grpc_server import GRPCServer
|
||||
from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from inference.shard import Shard
|
||||
from networking.grpc.grpc_discovery import GRPCDiscovery
|
||||
from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
|
||||
from exo.orchestration.standard_node import StandardNode
|
||||
from exo.networking.grpc.grpc_server import GRPCServer
|
||||
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
|
||||
from exo.inference.shard import Shard
|
||||
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
|
||||
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
|
||||
|
||||
# parse args
|
||||
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")
|
||||
|
||||
Reference in New Issue
Block a user