move everything under exo module

This commit is contained in:
Alex Cheema
2024-07-14 21:00:37 -07:00
parent c851644a43
commit 5bbde22a23
42 changed files with 56 additions and 56 deletions

View File

@@ -2,11 +2,11 @@
# They are prompting the cluster to generate a response to a question.
# The cluster is given the question, and the user is given the response.
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
from topology.device_capabilities import DeviceCapabilities
from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
from exo.inference.shard import Shard
from exo.networking.peer_handle import PeerHandle
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import DeviceCapabilities
from typing import List
import asyncio
import argparse

View File

@@ -2,11 +2,11 @@
# They are prompting the cluster to generate a response to a question.
# The cluster is given the question, and the user is given the response.
from inference.mlx.sharded_utils import get_model_path, load_tokenizer
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from networking.grpc.grpc_peer_handle import GRPCPeerHandle
from topology.device_capabilities import DeviceCapabilities
from exo.inference.mlx.sharded_utils import get_model_path, load_tokenizer
from exo.inference.shard import Shard
from exo.networking.peer_handle import PeerHandle
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.device_capabilities import DeviceCapabilities
from typing import List
import asyncio
import argparse

View File

View File

@@ -1,7 +1,7 @@
import mlx.core as mx
from inference.mlx.sharded_model import StatefulShardedModel
from inference.mlx.sharded_utils import load_shard
from inference.shard import Shard
from exo.inference.mlx.sharded_model import StatefulShardedModel
from exo.inference.mlx.sharded_utils import load_shard
from exo.inference.shard import Shard
shard_full = Shard("llama", 0, 31, 32)
shard1 = Shard("llama", 0, 12, 32)

View File

@@ -1,5 +1,5 @@
from inference.shard import Shard
from inference.mlx.sharded_model import StatefulShardedModel
from exo.inference.shard import Shard
from exo.inference.mlx.sharded_model import StatefulShardedModel
import mlx.core as mx
import mlx.nn as nn
from typing import Optional

View File

@@ -1,7 +1,7 @@
from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from inference.inference_engine import InferenceEngine
from inference.shard import Shard
from inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
from exo.inference.tinygrad.inference import TinygradDynamicShardInferenceEngine
import numpy as np
# An inference engine should work the same for any number of Shards, as long as the Shards are contiguous.

View File

@@ -4,12 +4,12 @@ from typing import List
import json, argparse, random, time
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
from exo.inference.tinygrad.models.llama import Transformer, convert_from_huggingface, fix_bf16
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad import Tensor, dtypes, nn, Context, Device, GlobalCounters
from tinygrad.helpers import Profiling, Timing, DEBUG, colored, fetch, tqdm
from inference.shard import Shard
from inference.inference_engine import InferenceEngine
from exo.inference.shard import Shard
from exo.inference.inference_engine import InferenceEngine
import numpy as np
MODEL_PARAMS = {

View File

@@ -6,7 +6,7 @@ from typing import List, Dict
from ..discovery import Discovery
from ..peer_handle import PeerHandle
from .grpc_peer_handle import GRPCPeerHandle
from topology.device_capabilities import DeviceCapabilities, device_capabilities
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities
class GRPCDiscovery(Discovery):
def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None):

View File

@@ -7,9 +7,9 @@ from . import node_service_pb2
from . import node_service_pb2_grpc
from ..peer_handle import PeerHandle
from inference.shard import Shard
from topology.topology import Topology
from topology.device_capabilities import DeviceCapabilities
from exo.inference.shard import Shard
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities
class GRPCPeerHandle(PeerHandle):
def __init__(self, id: str, address: str, device_capabilities: DeviceCapabilities):

View File

@@ -4,9 +4,9 @@ import numpy as np
from . import node_service_pb2
from . import node_service_pb2_grpc
from inference.shard import Shard
from exo.inference.shard import Shard
from orchestration import Node
from exo.orchestration import Node
import uuid

View File

@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
from typing import Optional, Tuple
import numpy as np
from inference.shard import Shard
from topology.device_capabilities import DeviceCapabilities
from topology.topology import Topology
from exo.inference.shard import Shard
from exo.topology.device_capabilities import DeviceCapabilities
from exo.topology.topology import Topology
class PeerHandle(ABC):
@abstractmethod

View File

@@ -1,8 +1,8 @@
from typing import Optional, Tuple
import numpy as np
from abc import ABC, abstractmethod
from inference.shard import Shard
from topology.topology import Topology
from exo.inference.shard import Shard
from exo.topology.topology import Topology
class Node(ABC):
@abstractmethod

View File

@@ -1,12 +1,12 @@
from typing import List, Dict, Optional, Callable, Tuple
import numpy as np
from networking import Discovery, PeerHandle, Server
from inference.inference_engine import InferenceEngine, Shard
from exo.networking import Discovery, PeerHandle, Server
from exo.inference.inference_engine import InferenceEngine, Shard
from .node import Node
from topology.topology import Topology
from topology.device_capabilities import device_capabilities
from topology.partitioning_strategy import PartitioningStrategy
from topology.partitioning_strategy import Partition
from exo.topology.topology import Topology
from exo.topology.device_capabilities import device_capabilities
from exo.topology.partitioning_strategy import PartitioningStrategy
from exo.topology.partitioning_strategy import Partition
import asyncio
import uuid

View File

@@ -3,7 +3,7 @@ from unittest.mock import Mock, AsyncMock
import numpy as np
from .standard_node import StandardNode
from networking.peer_handle import PeerHandle
from exo.networking.peer_handle import PeerHandle
class TestNode(unittest.IsolatedAsyncioTestCase):
def setUp(self):

0
exo/topology/__init__.py Normal file
View File

View File

@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from typing import List, Optional
from dataclasses import dataclass
from inference.shard import Shard
from networking.peer_handle import PeerHandle
from exo.inference.shard import Shard
from exo.networking.peer_handle import PeerHandle
from .topology import Topology
# Partitions shard-space into pieces of contiguous shards, represented by floating point range [start, end) between 0 and 1

View File

@@ -1,6 +1,6 @@
from typing import List
from .partitioning_strategy import PartitioningStrategy
from inference.shard import Shard
from exo.inference.shard import Shard
from .topology import Topology
from .partitioning_strategy import Partition

View File

@@ -1,6 +1,6 @@
import unittest
from unittest.mock import patch
from topology.device_capabilities import mac_device_capabilities, DeviceCapabilities
from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities
class TestMacDeviceCapabilities(unittest.TestCase):
@patch('subprocess.check_output')

12
main.py
View File

@@ -3,12 +3,12 @@ import asyncio
import signal
import mlx.core as mx
import mlx.nn as nn
from orchestration.standard_node import StandardNode
from networking.grpc.grpc_server import GRPCServer
from inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
from inference.shard import Shard
from networking.grpc.grpc_discovery import GRPCDiscovery
from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.inference.mlx.sharded_inference_engine import MLXFixedShardInferenceEngine
from exo.inference.shard import Shard
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")

View File

@@ -4,12 +4,12 @@ import signal
import mlx.core as mx
import mlx.nn as nn
from typing import List
from orchestration.standard_node import StandardNode
from networking.grpc.grpc_server import GRPCServer
from inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from inference.shard import Shard
from networking.grpc.grpc_discovery import GRPCDiscovery
from topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.orchestration.standard_node import StandardNode
from exo.networking.grpc.grpc_server import GRPCServer
from exo.inference.mlx.sharded_inference_engine import MLXDynamicShardInferenceEngine
from exo.inference.shard import Shard
from exo.networking.grpc.grpc_discovery import GRPCDiscovery
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
# parse args
parser = argparse.ArgumentParser(description="Initialize GRPC Discovery")