From fc963c42ce83b2beea2cb5bf747a5cb7b59ded74 Mon Sep 17 00:00:00 2001 From: Aaron Pham <29749331+aarnphm@users.noreply.github.com> Date: Sun, 16 Jul 2023 01:52:21 -0400 Subject: [PATCH] fix: build isolation (#116) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- DEVELOPMENT.md | 6 ++ hatch.toml | 37 ++++---- pyproject.toml | 10 +-- src/openllm/_llm.py | 5 +- src/openllm/_package.py | 7 +- src/openllm/_strategies.py | 89 +++++++++++-------- src/openllm/cli.py | 71 +++++++++++---- .../models/stablelm/modeling_stablelm.py | 15 ++-- src/openllm/testing.py | 2 +- .../flan_t5_test/test_flan_t5[container].json | 33 +++++++ .../flan_t5_test/test_flan_t5[local].json | 33 +++++++ .../opt_test/test_opt_125m[local].json | 34 +++++++ tools/dependencies.py | 5 -- typings/cuda/cuda.pyi | 1 + 14 files changed, 255 insertions(+), 93 deletions(-) create mode 100644 tests/models/__snapshots__/flan_t5_test/test_flan_t5[container].json create mode 100644 tests/models/__snapshots__/flan_t5_test/test_flan_t5[local].json create mode 100644 tests/models/__snapshots__/opt_test/test_opt_125m[local].json diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index ae5d9c44..da0de791 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -147,6 +147,12 @@ Run snapshot testing for model outputs: hatch run tests:models ``` +To update the snapshot, do the following: + +```bash +hatch run tests:snapshot-models +``` + ## Releasing a New Version To release a new version, use `./tools/run-release-action`. It requires `gh`, diff --git a/hatch.toml b/hatch.toml index 7fa7bdd2..84789f25 100644 --- a/hatch.toml +++ b/hatch.toml @@ -8,21 +8,7 @@ dependencies = [ "tomlkit", # NOTE: Using under ./tools/update-readme.py "markdown-it-py", - # NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy - "coverage[toml]>=6.5", - "filelock>=3.7.1", - "pytest", - "pytest-cov", - "pytest-mock", - "pytest-randomly", - "pytest-rerunfailures", - "pytest-asyncio>=0.21.0", - "pytest-xdist[psutil]", - "trustme", - "hypothesis", - "syrupy", ] -features = ['flan-t5'] [envs.default.scripts] changelog = "towncrier build --version main --draft" quality = [ @@ -37,15 +23,32 @@ setup = "pre-commit install" typing = "pre-commit run typecheck --all-files" watch-typing = "pyright {args:src/openllm} -w" [envs.tests] -extra-dependencies = [ +dependencies = [ # NOTE: interact with docker for container tests. 
"docker", + # NOTE: Tests strategies with Hypothesis and pytest, and snapshot testing with syrupy + "coverage[toml]>=6.5", + "filelock>=3.7.1", + "pytest", + "pytest-cov", + "pytest-mock", + "pytest-randomly", + "pytest-rerunfailures", + "pytest-asyncio>=0.21.0", + "pytest-xdist[psutil]", + "trustme", + "hypothesis", + "syrupy", ] +features = ['flan-t5', 'baichuan'] +skip-install = false +template = 'tests' [envs.tests.scripts] -_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml" +_run_script = "pytest --cov --cov-report={env:COVERAGE_REPORT:term-missing} --cov-config=pyproject.toml -vv" distributed = "_run_script --reruns 5 --reruns-delay 3 --ignore tests/models -n 3 -r aR {args:tests}" -models = "_run_script -r aR {args:tests/models}" +models = "_run_script -s {args:tests/models}" python = "_run_script --reruns 5 --reruns-delay 3 --ignore tests/models -r aR {args:tests}" +snapshot-models = "_run_script -s --snapshot-update {args:tests/models}" [envs.tests.overrides] env.GITHUB_ACTIONS.env-vars = "COVERAGE_REPORT=" [envs.coverage] diff --git a/pyproject.toml b/pyproject.toml index e235df1b..dff28e25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,14 +73,14 @@ all = [ "openllm[mpt]", "openllm[starcoder]", "openllm[baichuan]", - "openllm[flan-t5]", - "openllm[openai]", + "openllm[ggml]", + "openllm[opt]", "openllm[gptq]", - "openllm[fine-tune]", + "openllm[flan-t5]", "openllm[agents]", "openllm[playground]", - "openllm[opt]", - "openllm[ggml]", + "openllm[openai]", + "openllm[fine-tune]", "openllm[vllm]", ] baichuan = ["cpm-kernels", "sentencepiece"] diff --git a/src/openllm/_llm.py b/src/openllm/_llm.py index 6574d1bf..d2a2d4d1 100644 --- a/src/openllm/_llm.py +++ b/src/openllm/_llm.py @@ -253,7 +253,10 @@ def resolve_peft_config_type(adapter_map: dict[str, str | None] | None): _reserved_namespace = {"config_class", "model", "tokenizer", "import_kwargs"} M = t.TypeVar("M", bound="transformers.PreTrainedModel") -T = t.TypeVar("T", bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer]") +T = t.TypeVar( + "T", + bound="t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]", +) def _default_post_init(self: LLM[t.Any, t.Any]): diff --git a/src/openllm/_package.py b/src/openllm/_package.py index 701f889d..d1e7cd3e 100644 --- a/src/openllm/_package.py +++ b/src/openllm/_package.py @@ -17,6 +17,7 @@ These utilities will stay internal, and its API can be changed or updated withou """ from __future__ import annotations import importlib.metadata +import inspect import logging import os import typing as t @@ -349,4 +350,8 @@ def create_bento( bento._fs.writetext(service_fs_path, script) - return bento.save(bento_store=_bento_store, model_store=_model_store) + signatures = inspect.signature(bento.save).parameters + if "model_store" in signatures: + return bento.save(bento_store=_bento_store, model_store=_model_store) + # backward arguments. `model_store` is added recently + return bento.save(bento_store=_bento_store) diff --git a/src/openllm/_strategies.py b/src/openllm/_strategies.py index 76ad3327..ec447dca 100644 --- a/src/openllm/_strategies.py +++ b/src/openllm/_strategies.py @@ -13,7 +13,6 @@ # limitations under the License. 
from __future__ import annotations -import functools import inspect import logging import math @@ -31,14 +30,11 @@ from bentoml._internal.resource import system_resources from bentoml._internal.runner.strategy import THREAD_ENVS from bentoml._internal.runner.strategy import Strategy -from .utils import LazyLoader from .utils import LazyType from .utils import ReprMixin if t.TYPE_CHECKING: - import torch - import bentoml ListIntStr = list[int | str] @@ -48,7 +44,6 @@ if t.TYPE_CHECKING: else: DynResource = Resource[t.List[str]] - torch = LazyLoader("torch", globals(), "torch") ListIntStr = list # NOTE: We need to do this so that overload can register @@ -135,26 +130,50 @@ def _from_system(cls: type[DynResource]) -> list[str]: It relies on torch.cuda implementation and in turns respect CUDA_VISIBLE_DEVICES. """ - if cls.resource_id == "amd.com/gpu": - if not psutil.LINUX: - warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL) - return [] - - # ROCm does not currently have the rocm_smi wheel. - # So we need to use the ctypes bindings directly. - # we don't want to use CLI because parsing is a pain. - sys.path.append("/opt/rocm/libexec/rocm_smi") - try: - # refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py - from rsmiBindings import rocmsmi as rocmsmi - except (ModuleNotFoundError, ImportError): - # In this case the binary is not found, returning empty list - return [] - finally: - sys.path.remove("/opt/rocm/libexec/rocm_smi") visible_devices = _parse_visible_devices() if visible_devices is None: - return [str(i) for i in range(torch.cuda.device_count())] if torch.cuda.is_available() else [] + if cls.resource_id == "amd.com/gpu": + if not psutil.LINUX: + warnings.warn("AMD GPUs is currently only supported on Linux.", stacklevel=_STACK_LEVEL) + return [] + + # ROCm does not currently have the rocm_smi wheel. + # So we need to use the ctypes bindings directly. + # we don't want to use CLI because parsing is a pain. 
+ sys.path.append("/opt/rocm/libexec/rocm_smi") + try: + from ctypes import byref + from ctypes import c_uint32 + + # refers to https://github.com/RadeonOpenCompute/rocm_smi_lib/blob/master/python_smi_tools/rsmiBindings.py + from rsmiBindings import rocmsmi + from rsmiBindings import rsmi_status_t + + device_count = c_uint32(0) + ret = rocmsmi.rsmi_num_monitor_devices(byref(device_count)) + if ret == rsmi_status_t.RSMI_STATUS_SUCCESS: + return [str(i) for i in range(device_count.value)] + return [] + except (ModuleNotFoundError, ImportError): + # In this case the binary is not found, returning empty list + return [] + finally: + sys.path.remove("/opt/rocm/libexec/rocm_smi") + else: + try: + from cuda import cuda + + err, *_ = cuda.cuInit(0) + if err != cuda.CUresult.CUDA_SUCCESS: + logger.warning("Failed to initialise CUDA", stacklevel=_STACK_LEVEL) + return [] + err, device_count = cuda.cuDeviceGetCount() + if err != cuda.CUresult.CUDA_SUCCESS: + logger.warning("Failed to get available devices under system.", stacklevel=_STACK_LEVEL) + return [] + return [str(i) for i in range(device_count)] + except (ImportError, RuntimeError): + return [] return visible_devices @@ -199,26 +218,17 @@ def _from_spec(cls: type[DynResource], spec: t.Any) -> list[str]: ) -@functools.lru_cache -def _raw_uuid_nvml() -> list[str] | None: +def _raw_device_uuid_nvml() -> list[str] | None: """Return list of device UUID as reported by NVML or None if NVML discovery/initialization failed.""" - try: - from cuda import cuda - except ImportError: - if sys.platform == "darwin": - raise RuntimeError("GPU is not available on Darwin system.") from None - raise RuntimeError( - "Failed to initialise CUDA runtime binding. Make sure that 'cuda-python' is setup correctly." - ) from None - from ctypes import CDLL from ctypes import byref + from ctypes import c_int from ctypes import c_void_p from ctypes import create_string_buffer try: nvml_h = CDLL("libnvidia-ml.so.1") - except OSError: + except Exception: warnings.warn("Failed to find nvidia binding", stacklevel=_STACK_LEVEL) return @@ -226,12 +236,13 @@ def _raw_uuid_nvml() -> list[str] | None: if rc != 0: warnings.warn("Can't initialize NVML", stacklevel=_STACK_LEVEL) return - err, dev_count = cuda.cuDeviceGetCount() - if err != cuda.CUresult.CUDA_SUCCESS: + dev_count = c_int(-1) + rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) + if rc != 0: warnings.warn("Failed to get available device from system.", stacklevel=_STACK_LEVEL) return uuids: list[str] = [] - for idx in range(dev_count): + for idx in range(dev_count.value): dev_id = c_void_p() rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id)) if rc != 0: @@ -267,7 +278,7 @@ def _validate(cls: type[DynResource], val: list[t.Any]): # correctly parse handle for el in val: if el.startswith("GPU-") or el.startswith("MIG-"): - uuids = _raw_uuid_nvml() + uuids = _raw_device_uuid_nvml() if uuids is None: raise ValueError("Failed to parse available GPUs UUID") if el not in uuids: diff --git a/src/openllm/cli.py b/src/openllm/cli.py index 3d4b5925..24ea7dd6 100644 --- a/src/openllm/cli.py +++ b/src/openllm/cli.py @@ -95,6 +95,8 @@ from .utils import set_quiet_mode if t.TYPE_CHECKING: import torch + from bentoml._internal.bento import BentoStore + from ._types import AnyCallable from ._types import ClickFunctionWrapper from ._types import DictStrAny @@ -1399,6 +1401,7 @@ def _start( ) +@inject def _build( model_name: str, /, @@ -1414,8 +1417,10 @@ def _build( runtime: t.Literal["ggml", "transformers"] = "transformers", 
dockerfile_template: str | None = None, overwrite: bool = False, - format: t.Literal["bento", "container"] = "bento", + push: bool = False, + containerize: bool = False, additional_args: list[str] | None = None, + bento_store: BentoStore = Provide[BentoMLContainer.bento_store], ) -> bentoml.Bento: """Package a LLM into a Bento. @@ -1455,14 +1460,17 @@ def _build( dockerfile_template: The dockerfile template to use for building BentoLLM. See https://docs.bentoml.com/en/latest/guides/containerization.html#dockerfile-template. overwrite: Whether to overwrite the existing BentoLLM. By default, this is set to ``False``. - format: The output format to build this LLM. By default it will build the BentoLLM. 'container' is equivalent of 'openllm build && bentoml containerize ' + push: Whether to push the result bento to BentoCloud. Make sure to login with 'bentoml cloud login' first. + containerize: Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'. + Note that 'containerize' and 'push' are mutually exclusive additional_args: Additional arguments to pass to ``openllm build``. + bento_store: Optional BentoStore for saving this BentoLLM. Default to the default BentoML local store. Returns: ``bentoml.Bento | str``: BentoLLM instance. This can be used to serve the LLM or can be pushed to BentoCloud. If 'format="container"', then it returns the default 'container_name:container_tag' """ - args: ListStr = [model_name, "--runtime", runtime, "--format", format] + args: ListStr = [sys.executable, "-m", "openllm", "build", model_name, "--machine", "--runtime", runtime] if quantize and bettertransformer: raise OpenLLMException("'quantize' and 'bettertransformer' are currently mutually exclusive.") @@ -1472,6 +1480,13 @@ def _build( if bettertransformer: args.append("--bettertransformer") + if containerize and push: + raise OpenLLMException("'containerize' and 'push' are currently mutually exclusive.") + if push: + args.extend(["--push"]) + if containerize: + args.extend(["--containerize"]) + if model_id: args.extend(["--model-id", model_id]) if build_ctx: @@ -1491,7 +1506,19 @@ def _build( if additional_args: args.extend(additional_args) - return build_command.main(args=args, standalone_mode=False) + try: + output = subprocess.check_output(args, env=os.environ.copy(), cwd=build_ctx or os.getcwd()) + except subprocess.CalledProcessError as e: + logger.error("Exception caught while building %s", model_name, exc_info=e) + if e.stderr: + raise OpenLLMException(e.stderr.decode("utf-8")) from None + raise OpenLLMException(str(e)) from None + # NOTE: This usually only concern BentoML devs. 
+ pattern = r"^__tag__:[^:\n]+:[^:\n]+" + matched = re.search(pattern, output.decode("utf-8").strip(), re.MULTILINE) + assert matched is not None, f"Failed to find tag from output: {output}" + _, _, tag = matched.group(0).partition(":") + return bentoml.get(tag, _bento_store=bento_store) def _import_model( @@ -1564,12 +1591,13 @@ start, start_grpc, build, import_model, list_models = ( ) @model_id_option(click) @output_option +@click.option("--machine", is_flag=True, default=False, hidden=True) @click.option("--overwrite", is_flag=True, help="Overwrite existing Bento for given LLM if it already exists.") @workers_per_resource_option(click, build=True) -@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name="Optimisation options.") +@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name="Optimisation options") @quantize_option(cog.optgroup, build=True) @bettertransformer_option(cog.optgroup) -@cog.optgroup.option( +@click.option( "--runtime", type=click.Choice(["ggml", "transformers"]), default="transformers", @@ -1604,14 +1632,15 @@ start, start_grpc, build, import_model, list_models = ( type=click.File(), help="Optional custom dockerfile template to be used with this BentoLLM.", ) -@click.option( - "--format", - default="bento", - type=click.Choice(["bento", "container"]), - help="The output format for 'openllm build'. By default this will build a BentoLLM. 'container' is the shortcut of 'openllm build && bentoml containerize'.", - hidden=not get_debug_mode(), +@cog.optgroup.group(cls=cog.MutuallyExclusiveOptionGroup, name="Utilities options") +@cog.optgroup.option( + "--containerize", + default=False, + is_flag=True, + type=click.BOOL, + help="Whether to containerize the Bento after building. '--containerize' is the shortcut of 'openllm build && bentoml containerize'.", ) -@click.option( +@cog.optgroup.option( "--push", default=False, is_flag=True, @@ -1632,9 +1661,10 @@ def build_command( workers_per_resource: float | None, adapter_id: tuple[str, ...], build_ctx: str | None, + machine: bool, model_version: str | None, dockerfile_template: t.TextIO | None, - format: t.Literal["bento", "container"], + containerize: bool, push: bool, **attrs: t.Any, ): @@ -1665,6 +1695,9 @@ def build_command( # we are just doing the parsing here. 
adapter_map[_adapter_id] = adapter_name[0] if len(adapter_name) > 0 else None + if machine: + output = "porcelain" + if enable_features: enable_features = tuple(itertools.chain.from_iterable((s.split(",") for s in enable_features))) @@ -1759,7 +1792,11 @@ def build_command( if current_adapter_map_envvar is not None: os.environ["OPENLLM_ADAPTER_MAP"] = current_adapter_map_envvar - if output == "pretty": + if machine: + # NOTE: We will prefix the tag with __tag__ and we can use regex to correctly + # get the tag from 'bentoml.bentos.build|build_bentofile' + _echo(f"__tag__:{bento.tag}", fg="white") + elif output == "pretty": if not get_quiet_mode(): _echo("\n" + OPENLLM_FIGLET, fg="white") if not _previously_built: @@ -1792,12 +1829,10 @@ def build_command( else: _echo(bento.tag) - if format == "container" and push: - ctx.fail("'--format=container' and '--push' are mutually exclusive.") if push: client = BentoMLContainer.bentocloud_client.get() client.push_bento(bento) - elif format == "container": + elif containerize: backend = os.getenv("BENTOML_CONTAINERIZE_BACKEND", "docker") _echo(f"Building {bento} into a LLMContainer using backend '{backend}'", fg="magenta") if not bentoml.container.health(backend): diff --git a/src/openllm/models/stablelm/modeling_stablelm.py b/src/openllm/models/stablelm/modeling_stablelm.py index 5dc2051c..d439deb2 100644 --- a/src/openllm/models/stablelm/modeling_stablelm.py +++ b/src/openllm/models/stablelm/modeling_stablelm.py @@ -25,9 +25,11 @@ from ..._prompt import default_formatter if t.TYPE_CHECKING: import transformers # noqa import torch + import torch.amp else: transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers") torch = openllm.utils.LazyLoader("torch", globals(), "torch") + torch.amp = openllm.utils.LazyLoader("torch.amp", globals(), "torch.amp") logger = logging.getLogger(__name__) @@ -42,10 +44,7 @@ class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTN @property def import_kwargs(self): - model_kwds = { - "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32, - "device_map": "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None, - } + model_kwds = {"torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32} tokenizer_kwds: dict[str, t.Any] = {} return model_kwds, tokenizer_kwds @@ -103,5 +102,9 @@ class StableLM(openllm.LLM["transformers.GPTNeoXForCausalLM", "transformers.GPTN inputs = t.cast("torch.Tensor", self.tokenizer(prompt, return_tensors="pt")).to(self.device) with torch.inference_mode(): - tokens = self.model.generate(**inputs, **generation_kwargs) - return [self.tokenizer.decode(tokens[0], skip_special_tokens=True)] + if torch.cuda.is_available(): + with torch.amp.autocast("cuda", torch.float16): + tokens = self.model.generate(**inputs, **generation_kwargs) + else: + tokens = self.model.generate(**inputs, **generation_kwargs) + return [self.tokenizer.decode(tokens[0], skip_special_tokens=True)] diff --git a/src/openllm/testing.py b/src/openllm/testing.py index 44d69334..1f6caa2a 100644 --- a/src/openllm/testing.py +++ b/src/openllm/testing.py @@ -89,7 +89,7 @@ def prepare( implementation: LiteralRuntime = "pt", deployment_mode: t.Literal["container", "local"] = "local", clean_context: contextlib.ExitStack | None = None, - cleanup: bool = False, + cleanup: bool = True, ): if clean_context is None: clean_context = contextlib.ExitStack() diff --git a/tests/models/__snapshots__/flan_t5_test/test_flan_t5[container].json 
b/tests/models/__snapshots__/flan_t5_test/test_flan_t5[container].json new file mode 100644 index 00000000..38506cbd --- /dev/null +++ b/tests/models/__snapshots__/flan_t5_test/test_flan_t5[container].json @@ -0,0 +1,33 @@ +{ + "configuration": { + "generation_config": { + "diversity_penalty": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "encoder_repetition_penalty": 1.0, + "epsilon_cutoff": 0.0, + "eta_cutoff": 0.0, + "length_penalty": 1.0, + "max_new_tokens": 10, + "min_length": 0, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "remove_invalid_values": false, + "renormalize_logits": false, + "repetition_penalty": 1.0, + "temperature": 0.9, + "top_k": 50, + "top_p": 0.9, + "typical_p": 1.0, + "use_cache": true + } + }, + "responses": [ + "life is a complete physical life" + ] +} \ No newline at end of file diff --git a/tests/models/__snapshots__/flan_t5_test/test_flan_t5[local].json b/tests/models/__snapshots__/flan_t5_test/test_flan_t5[local].json new file mode 100644 index 00000000..6f1deb95 --- /dev/null +++ b/tests/models/__snapshots__/flan_t5_test/test_flan_t5[local].json @@ -0,0 +1,33 @@ +{ + "configuration": { + "generation_config": { + "diversity_penalty": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "encoder_repetition_penalty": 1.0, + "epsilon_cutoff": 0.0, + "eta_cutoff": 0.0, + "length_penalty": 1.0, + "max_new_tokens": 10, + "min_length": 0, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "remove_invalid_values": false, + "renormalize_logits": false, + "repetition_penalty": 1.0, + "temperature": 0.9, + "top_k": 50, + "top_p": 0.9, + "typical_p": 1.0, + "use_cache": true + } + }, + "responses": [ + "life is a state" + ] +} \ No newline at end of file diff --git a/tests/models/__snapshots__/opt_test/test_opt_125m[local].json b/tests/models/__snapshots__/opt_test/test_opt_125m[local].json new file mode 100644 index 00000000..b17a783d --- /dev/null +++ b/tests/models/__snapshots__/opt_test/test_opt_125m[local].json @@ -0,0 +1,34 @@ +{ + "configuration": { + "format_outputs": false, + "generation_config": { + "diversity_penalty": 0.0, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "encoder_repetition_penalty": 1.0, + "epsilon_cutoff": 0.0, + "eta_cutoff": 0.0, + "length_penalty": 1.0, + "max_new_tokens": 20, + "min_length": 0, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "remove_invalid_values": false, + "renormalize_logits": false, + "repetition_penalty": 1.0, + "temperature": 0.75, + "top_k": 15, + "top_p": 1.0, + "typical_p": 1.0, + "use_cache": true + } + }, + "responses": [ + "What is Deep learning?\n\nDeep learning is a new, highly-advanced, and powerful tool for the deep learning" + ] +} \ No newline at end of file diff --git a/tools/dependencies.py b/tools/dependencies.py index e8f6df3e..7ef75f1e 100755 --- a/tools/dependencies.py +++ b/tools/dependencies.py @@ -18,8 +18,6 @@ from __future__ import annotations import dataclasses import os -import shutil -import subprocess import typing as t import inflection @@ -277,9 +275,6 @@ def main() -> int: f.write("-r 
nightly-requirements.txt\n-e .[all]\n") f.writelines([f"{v.to_str()}\n" for v in _NIGHTLY_MAPPING.values() if v.requires_gpu]) - if shutil.which("taplo"): - return subprocess.check_call(["taplo", "format", os.path.join(ROOT, "pyproject.toml")]) - return 0 diff --git a/typings/cuda/cuda.pyi b/typings/cuda/cuda.pyi index 982d3e5f..07ff7ab5 100644 --- a/typings/cuda/cuda.pyi +++ b/typings/cuda/cuda.pyi @@ -24,3 +24,4 @@ class CUdevice(_CUMixin): ... def cuDeviceGetCount() -> tuple[CUresult, int]: ... def cuDeviceGet(dev: int) -> tuple[CUresult, CUdevice]: ... +def cuInit(flags: int) -> tuple[CUresult]: ...
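
Note on the `typings/cuda/cuda.pyi` stub above: the new `cuInit` signature types the torch-free NVIDIA probe that `_from_system` now performs in `src/openllm/_strategies.py`. A minimal standalone sketch of that probe follows; it assumes the optional `cuda-python` package is installed, and the helper name `_available_cuda_devices` is illustrative only (it is not part of this patch).

```python
# Minimal sketch mirroring the cuda-python calls added in _from_system.
# Assumes the optional 'cuda-python' package; any failure yields an empty list,
# matching the patch's "no GPUs visible" fallback behaviour.
from __future__ import annotations


def _available_cuda_devices() -> list[str]:  # illustrative helper, not in the patch
    try:
        from cuda import cuda  # provided by the 'cuda-python' distribution
    except ImportError:
        return []
    err, *_ = cuda.cuInit(0)  # cuInit returns a 1-tuple: (CUresult,)
    if err != cuda.CUresult.CUDA_SUCCESS:
        return []
    err, device_count = cuda.cuDeviceGetCount()  # returns (CUresult, int)
    if err != cuda.CUresult.CUDA_SUCCESS:
        return []
    return [str(i) for i in range(device_count)]


if __name__ == "__main__":
    print(_available_cuda_devices())
```

Device IDs are returned as strings because the `DynResource = Resource[t.List[str]]` declaration earlier in the patch expects string identifiers.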