diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b067265e..e8c5e073 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -60,12 +60,14 @@ repos: openllm-client/src/openllm_client/pb.*| .github/.*| cz.py | + bench.py | hatch_build.py )$ additional_dependencies: - click==8.1.3 - peft - bentoml==1.1.1 + - build==0.10.0 - transformers>=4.31.0 - pandas-stubs - types-psutil diff --git a/bench.py b/bench.py new file mode 100755 index 00000000..dd2c670f --- /dev/null +++ b/bench.py @@ -0,0 +1,139 @@ +from __future__ import annotations +import asyncio +import json + +import aiohttp + +import openllm + +async def send_request(url, prompt, session, model, **attrs): + headers = {'accept': 'application/json', 'Content-Type': 'application/json'} + config = openllm.AutoConfig.for_model(model).model_construct_env(**attrs).model_dump() + data = {'prompt': prompt, 'llm_config': config, 'adapter_name': None} + async with session.post(url, headers=headers, data=json.dumps(data)) as response: + result = await response.text() + print('-' * 10 + '\n\n prompt:', prompt, '\nGeneration:', result, '\n\n' + '-' * 10) + +async def main(): + url = 'http://localhost:3000/v1/generate_stream' + # len=100 + prompts = [ + 'What is the meaning of life?', + 'Explain the concept of quantum entanglement.', + 'Describe the process of photosynthesis.', + 'What are the benefits of regular exercise?', + 'How does the internet work?', + 'Discuss the impact of climate change on ecosystems.', + 'Explain the principles of supply and demand in economics.', + 'What is the history of the Roman Empire?', + 'Describe the structure of a cell.', + 'Discuss the pros and cons of renewable energy sources.', + 'Explain the theory of relativity.', + 'What is the role of DNA in genetics?', + 'Describe the art movement of the Renaissance.', + 'Discuss the causes of World War I.', + 'What are the major functions of the human brain?', + 'Explain the process of evolution by natural selection.', + 'Describe the cultural significance of the Great Wall of China.', + 'What is the impact of social media on society?', + 'Discuss the life and works of Shakespeare.', + 'Explain the concept of artificial intelligence.', + 'What are the different types of chemical reactions?', + "Describe the structure of the Earth's atmosphere.", + 'Discuss the history of the civil rights movement.', + 'What are the economic implications of globalization?', + 'Explain the principles of good nutrition.', + 'Describe the major functions of the immune system.', + 'Discuss the impact of colonialism on Africa.', + 'What is the process of cellular respiration?', + 'Explain the importance of biodiversity.', + 'Discuss the causes and consequences of the Industrial Revolution.', + 'What are the fundamental principles of democracy?', + 'Describe the major components of a computer.', + 'Explain the concept of human rights.', + 'What is the role of enzymes in biological reactions?', + 'Discuss the history of space exploration.', + 'What are the ethical considerations in medical research?', + 'Describe the cultural significance of the Pyramids of Egypt.', + 'Explain the principles of classical physics.', + 'What is the impact of climate change on weather patterns?', + 'Discuss the major events of the American Revolution.', + 'What are the effects of pollution on the environment?', + 'Describe the process of protein synthesis.', + 'Explain the concept of sustainable agriculture.', + 'What is the history of the European Union?', + 'Discuss the impact of the Renaissance on art and culture.', + 'What are the key principles of marketing?', + 'Explain the structure of the periodic table.', + 'Describe the major types of renewable energy.', + 'Discuss the causes and consequences of the French Revolution.', + 'What is the role of the United Nations in international relations?', + 'Explain the principles of game theory in economics.', + 'What are the stages of human development?', + 'Describe the cultural significance of the Taj Mahal.', + 'Discuss the major themes in the works of Ernest Hemingway.', + 'What is the impact of automation on the workforce?', + 'Explain the concept of genetic engineering.', + 'What are the different types of chemical bonds?', + "Describe the layers of the Earth's atmosphere.", + "Discuss the history of the women's suffrage movement.", + 'What are the economic factors influencing consumer behavior?', + 'Explain the principles of conflict resolution.', + 'What is the role of neurotransmitters in the nervous system?', + 'Discuss the impact of colonialism on India.', + 'What is the process of mitosis?', + 'Explain the importance of water conservation.', + 'Describe the cultural significance of the Acropolis in Athens.', + 'Discuss the major philosophical ideas of Plato.', + 'What are the principles of investment in finance?', + 'Explain the structure of a virus.', + 'What is the history of the United Nations?', + 'Discuss the impact of technology on modern art.', + 'What are the key concepts in cognitive psychology?', + 'Describe the major types of non-renewable energy sources.', + 'Explain the causes and consequences of the Russian Revolution.', + 'What is the role of the World Health Organization in global health?', + 'Discuss the principles of ethics in business.', + 'What are the stages of the water cycle?', + 'Explain the concept of social justice.', + 'What is the impact of deforestation on climate change?', + 'Describe the process of meiosis.', + 'Discuss the cultural significance of the Sistine Chapel ceiling.', + 'What are the major themes in the novels of Jane Austen?', + 'Explain the role of branding in marketing.', + 'What is the history of the Internet?', + 'Discuss the impact of artificial intelligence on society.', + 'What are the principles of statistical analysis in research?', + 'Explain the structure of an atom.', + 'What is the significance of the Theory of Evolution by Charles Darwin?', + 'Describe the major types of renewable energy.', + 'Discuss the causes and consequences of the American Civil War.', + 'What is the role of the International Monetary Fund in global economics?', + 'Explain the principles of environmental conservation.', + 'What are the stages of the rock cycle?', + 'Describe the concept of cultural relativism.', + 'Discuss the major contributions of Leonardo da Vinci to art and science.', + 'What is the impact of globalization on cultural diversity?', + 'Explain the process of genetic inheritance.', + 'What are the different forms of government in the world?', + 'Describe the major types of pollution.', + 'Discuss the history of the labor movement.', + 'What are the principles of sustainable urban planning?', + 'Explain the role of hormones in the endocrine system.', + 'What is the cultural significance of the Great Barrier Reef?', + 'Discuss the major ideas of Friedrich Nietzsche.', + 'What is the impact of social media on political movements?', + 'Explain the concept of quantum computing.', + 'What are the principles of international diplomacy?', + 'Describe the major types of ocean ecosystems.', + 'Discuss the causes and consequences of the Cold War.', + 'What is the role of the World Trade Organization in global trade?', + 'Explain the principles of behavioral psychology.', + 'What are the stages of the nitrogen cycle?', + 'Describe the concept of cultural appropriation.', + 'Discuss the major works of Vincent van Gogh.', + ] + async with aiohttp.ClientSession() as session: + await asyncio.gather(*[send_request(url, prompt, session, 'llama', max_new_tokens=4096, top_p=0.21) for _, prompt in enumerate(prompts)]) + +if __name__ == '__main__': asyncio.run(main()) diff --git a/changelog.d/349.feat.md b/changelog.d/349.feat.md new file mode 100644 index 00000000..14d0bb97 --- /dev/null +++ b/changelog.d/349.feat.md @@ -0,0 +1,3 @@ +Added support for continuous batching via vLLM + +Currently benchmark shows that 100 concurrent requests shows around 1218 TPS on 1 A100 running meta-llama/Llama-2-13b-chat-hf diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py index ea2b59b3..7cbcc2c4 100644 --- a/openllm-core/src/openllm_core/_configuration.py +++ b/openllm-core/src/openllm_core/_configuration.py @@ -465,7 +465,6 @@ class ModelSettings(t.TypedDict, total=False): # meta url: str - requires_gpu: bool trust_remote_code: bool service_name: NotRequired[str] requirements: t.Optional[ListStr] @@ -523,7 +522,6 @@ class _ModelSettingsAttr: 'cpu': 'pt', 'nvidia.com/gpu': 'pt' }, name_type='dasherize', - requires_gpu=False, url='', model_type='causal_lm', trust_remote_code=False, @@ -541,7 +539,6 @@ class _ModelSettingsAttr: architecture: str default_backend: t.Dict[LiteralResourceSpec, LiteralBackend] url: str - requires_gpu: bool trust_remote_code: bool service_name: str requirements: t.Optional[ListStr] @@ -737,8 +734,6 @@ class _ConfigAttr: '''The default backend to run LLM based on available accelerator. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`. It is a dictionary of key as the accelerator spec in k8s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM backend ('flax', 'tf', 'pt', 'vllm', 'ggml', 'mlc')''' __openllm_url__: str = Field(None) '''The resolved url for this LLMConfig.''' - __openllm_requires_gpu__: bool = Field(None) - '''Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.''' __openllm_trust_remote_code__: bool = Field(None) '''Whether to always trust remote code''' __openllm_service_name__: str = Field(None) @@ -932,7 +927,6 @@ class LLMConfig(_ConfigAttr): __config__ = { "name_type": "lowercase", "trust_remote_code": True, - "requires_gpu": True, "timeout": 3600000, "url": "https://falconllm.tii.ae/", "requirements": ["einops", "xformers", "safetensors"], @@ -1108,8 +1102,6 @@ class LLMConfig(_ConfigAttr): @overload def __getitem__(self, item: t.Literal['url']) -> str: ... @overload - def __getitem__(self, item: t.Literal['requires_gpu']) -> bool: ... - @overload def __getitem__(self, item: t.Literal['trust_remote_code']) -> bool: ... @overload def __getitem__(self, item: t.Literal['service_name']) -> str: ... diff --git a/openllm-core/src/openllm_core/_typing_compat.py b/openllm-core/src/openllm_core/_typing_compat.py index 1054a233..b85e9abd 100644 --- a/openllm-core/src/openllm_core/_typing_compat.py +++ b/openllm-core/src/openllm_core/_typing_compat.py @@ -25,7 +25,7 @@ if t.TYPE_CHECKING: from .utils.lazy import VersionInfo -M = t.TypeVar('M', bound='t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.LLMEngine, peft.PeftModel]') +M = t.TypeVar('M', bound='t.Union[transformers.PreTrainedModel, transformers.Pipeline, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel, vllm.AsyncLLMEngine, peft.PeftModel]') T = t.TypeVar('T', bound='t.Union[transformers.PreTrainedTokenizerFast, transformers.PreTrainedTokenizer, transformers.PreTrainedTokenizerBase]') def get_literal_args(typ: t.Any) -> tuple[str, ...]: @@ -94,7 +94,9 @@ class LLMRunnable(bentoml.Runnable, t.Generic[M, T]): embeddings: RunnableMethod[LLMRunnable[M, T], [list[str]], EmbeddingsOutput] generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]] generate_one: RunnableMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]] - generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]] + generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.Iterator[t.Any]] + vllm_generate: RunnableMethod[LLMRunnable[M, T], [str], list[t.Any]] + vllm_generate_iterator: RunnableMethod[LLMRunnable[M, T], [str], t.AsyncGenerator[str, None]] class LLMRunner(bentoml.Runner, t.Generic[M, T]): __doc__: str @@ -111,7 +113,9 @@ class LLMRunner(bentoml.Runner, t.Generic[M, T]): embeddings: RunnerMethod[LLMRunnable[M, T], [list[str]], t.Sequence[EmbeddingsOutput]] generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]] generate_one: RunnerMethod[LLMRunnable[M, T], [str, list[str]], t.Sequence[dict[t.Literal['generated_text'], str]]] - generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Generator[str, None, str]] + generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.Iterator[t.Any]] + vllm_generate: RunnerMethod[LLMRunnable[M, T], [str], list[t.Any]] + vllm_generate_iterator: RunnerMethod[LLMRunnable[M, T], [str], t.AsyncGenerator[str, None]] def __init__(self, runnable_class: type[LLMRunnable[M, T]], diff --git a/openllm-core/src/openllm_core/config/configuration_baichuan.py b/openllm-core/src/openllm_core/config/configuration_baichuan.py index c16b8bbb..c6b97732 100644 --- a/openllm-core/src/openllm_core/config/configuration_baichuan.py +++ b/openllm-core/src/openllm_core/config/configuration_baichuan.py @@ -40,7 +40,6 @@ class BaichuanConfig(openllm_core.LLMConfig): 'name_type': 'lowercase', 'trust_remote_code': True, 'timeout': 3600000, - 'requires_gpu': True, 'url': 'https://github.com/baichuan-inc/Baichuan-7B', 'requirements': ['cpm-kernels', 'sentencepiece'], 'architecture': 'BaiChuanForCausalLM', diff --git a/openllm-core/src/openllm_core/config/configuration_chatglm.py b/openllm-core/src/openllm_core/config/configuration_chatglm.py index b4450e2e..e5c5e671 100644 --- a/openllm-core/src/openllm_core/config/configuration_chatglm.py +++ b/openllm-core/src/openllm_core/config/configuration_chatglm.py @@ -44,7 +44,6 @@ class ChatGLMConfig(openllm_core.LLMConfig): 'name_type': 'lowercase', 'trust_remote_code': True, 'timeout': 3600000, - 'requires_gpu': True, 'url': 'https://github.com/THUDM/ChatGLM-6B', 'requirements': ['cpm-kernels', 'sentencepiece'], 'architecture': 'ChatGLMForConditionalGeneration', diff --git a/openllm-core/src/openllm_core/config/configuration_falcon.py b/openllm-core/src/openllm_core/config/configuration_falcon.py index 933dc5fe..bb292486 100644 --- a/openllm-core/src/openllm_core/config/configuration_falcon.py +++ b/openllm-core/src/openllm_core/config/configuration_falcon.py @@ -41,7 +41,6 @@ class FalconConfig(openllm_core.LLMConfig): __config__ = { 'name_type': 'lowercase', 'trust_remote_code': True, - 'requires_gpu': True, 'timeout': int(36e6), 'url': 'https://falconllm.tii.ae/', 'requirements': ['einops', 'xformers'], diff --git a/openllm-core/src/openllm_core/config/configuration_gpt_neox.py b/openllm-core/src/openllm_core/config/configuration_gpt_neox.py index 73de02a8..31e1d98a 100644 --- a/openllm-core/src/openllm_core/config/configuration_gpt_neox.py +++ b/openllm-core/src/openllm_core/config/configuration_gpt_neox.py @@ -45,7 +45,6 @@ class GPTNeoXConfig(openllm_core.LLMConfig): __config__ = { 'model_name': 'gpt_neox', 'start_name': 'gpt-neox', - 'requires_gpu': True, 'architecture': 'GPTNeoXForCausalLM', 'url': 'https://github.com/EleutherAI/gpt-neox', 'default_id': 'eleutherai/gpt-neox-20b', diff --git a/openllm-core/src/openllm_core/config/configuration_opt.py b/openllm-core/src/openllm_core/config/configuration_opt.py index f8b1afba..9bdfe152 100644 --- a/openllm-core/src/openllm_core/config/configuration_opt.py +++ b/openllm-core/src/openllm_core/config/configuration_opt.py @@ -10,7 +10,7 @@ START_OPT_COMMAND_DOCSTRING = '''\ Run a LLMServer for OPT model. \b -> See more information about falcon at [facebook/opt-66b](https://huggingface.co/facebook/opt-66b) +> See more information about OPT at [facebook/opt-66b](https://huggingface.co/facebook/opt-66b) \b ## Usage diff --git a/openllm-core/src/openllm_core/config/configuration_starcoder.py b/openllm-core/src/openllm_core/config/configuration_starcoder.py index b95abc38..05ba414b 100644 --- a/openllm-core/src/openllm_core/config/configuration_starcoder.py +++ b/openllm-core/src/openllm_core/config/configuration_starcoder.py @@ -36,11 +36,9 @@ class StarCoderConfig(openllm_core.LLMConfig): """ __config__ = { 'name_type': 'lowercase', - 'requires_gpu': True, 'url': 'https://github.com/bigcode-project/starcoder', 'architecture': 'GPTBigCodeForCausalLM', 'requirements': ['bitsandbytes'], - 'workers_per_resource': 0.5, 'default_id': 'bigcode/starcoder', 'model_ids': ['bigcode/starcoder', 'bigcode/starcoderbase'] } diff --git a/openllm-python/pyproject.toml b/openllm-python/pyproject.toml index 40b3b50e..47d78447 100644 --- a/openllm-python/pyproject.toml +++ b/openllm-python/pyproject.toml @@ -115,7 +115,7 @@ openai = ["openai", "tiktoken"] opt = ["flax>=0.7", "jax", "jaxlib", "tensorflow", "keras"] playground = ["jupyter", "notebook", "ipython", "jupytext", "nbformat"] starcoder = ["bitsandbytes"] -vllm = ["vllm>=0.1.6", "ray"] +vllm = ["vllm>=0.1.7", "ray"] [tool.hatch.version] fallback-version = "0.0.0" diff --git a/openllm-python/src/openllm/_assign.py b/openllm-python/src/openllm/_assign.py index db696ac6..c08fe4ab 100644 --- a/openllm-python/src/openllm/_assign.py +++ b/openllm-python/src/openllm/_assign.py @@ -9,7 +9,6 @@ import openllm from openllm.exceptions import OpenLLMException from openllm_core._configuration import _object_getattribute from openllm_core._configuration import _setattr_class -from openllm_core._schema import unmarshal_vllm_outputs from openllm_core._typing_compat import DictStrAny from openllm_core._typing_compat import ListStr from openllm_core._typing_compat import M @@ -22,6 +21,7 @@ from openllm_core.utils import LazyLoader from openllm_core.utils import codegen from openllm_core.utils import device_count from openllm_core.utils import first_not_none +from openllm_core.utils import get_debug_mode from openllm_core.utils import is_torch_available if t.TYPE_CHECKING: @@ -38,35 +38,37 @@ else: def import_model(fn: import_model_protocol[bentoml.Model, M, T]) -> t.Callable[[LLM[M, T]], bentoml.Model]: @functools.wraps(fn) def inner(self: LLM[M, T], *decls: t.Any, trust_remote_code: bool | None = None, **attrs: t.Any) -> bentoml.Model: - trust_remote_code = first_not_none(trust_remote_code, default=self.trust_remote_code) (model_decls, model_attrs), _ = self.llm_parameters decls = (*model_decls, *decls) attrs = {**model_attrs, **attrs} - return fn(self, *decls, trust_remote_code=trust_remote_code, **attrs) + return fn(self, *decls, trust_remote_code=first_not_none(trust_remote_code, default=self.trust_remote_code), **attrs) return inner -def load_model(fn: load_model_protocol[M, T]) -> t.Callable[[LLM[M, T]], M | vllm.LLMEngine]: +def load_model(fn: load_model_protocol[M, T]) -> t.Callable[[LLM[M, T]], M | vllm.AsyncLLMEngine]: @functools.wraps(fn) - def inner(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.LLMEngine: + def inner(self: LLM[M, T], *decls: t.Any, **attrs: t.Any) -> M | vllm.AsyncLLMEngine: if self.__llm_backend__ == 'vllm': num_gpus, dev = 1, device_count() if dev >= 2: num_gpus = min(dev // 2 * 2, dev) - # TODO: Do some more processing with token_id once we support token streaming try: - return vllm.LLMEngine.from_engine_args( - vllm.EngineArgs(model=self._bentomodel.path, - tokenizer=self._bentomodel.path if self.tokenizer_id == 'local' else self.tokenizer_id, - tokenizer_mode='auto', - tensor_parallel_size=num_gpus, - dtype='auto', - worker_use_ray=False)) + return vllm.AsyncLLMEngine.from_engine_args( + vllm.AsyncEngineArgs(model=self._bentomodel.path, + tokenizer=self._bentomodel.path if self.tokenizer_id == 'local' else self.tokenizer_id, + tokenizer_mode='auto', + tensor_parallel_size=num_gpus, + dtype='auto', + disable_log_requests=not get_debug_mode(), + worker_use_ray=False, + engine_use_ray=False)) except Exception as err: traceback.print_exc() raise OpenLLMException(f'Failed to initialise vLLMEngine due to the following error:\n{err}') from None else: (model_decls, model_attrs), _ = self.llm_parameters - return fn(self, *(*model_decls, *decls), **{**model_attrs, **attrs}) + decls = (*model_decls, *decls) + attrs = {**model_attrs, **attrs} + return fn(self, *decls, **attrs) return inner @@ -108,12 +110,6 @@ def make_llm_attributes(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]], N func_call = f"_impl_{cls.__name__}_{func}={cached_func_name} if {cached_func_name} is not _cached_LLMSerialisation_get('{func}') else __serialisation_{func}" lines.extend([f'{cached_func_name}=cls.{func}', func_call, _setattr_class(func, f'{impl_name}(_impl_{cls.__name__}_{func})')]) - # assign vLLM implementation - if cls.__llm_backend__ == 'vllm': - vllm_func = {f'_vllm_{it}': fn for it, fn in zip(('generate', 'generate_iterator', 'postprocess_generate'), (vllm_generate, vllm_generate_iterator, vllm_postprocess_generate))} - globs.update(vllm_func) - lines.extend([_setattr_class(it[6:], it) for it in vllm_func]) - interface_anns = codegen.get_annotations(LLMInterface) # cached attribute initialisation @@ -131,46 +127,3 @@ def make_llm_attributes(cls: type[LLM[M, T]]) -> t.Callable[[type[LLM[M, T]]], N lines.extend([_setattr_class(dunder_support(fn), f"cls.{fn} is not _cached_LLMFunction_get('{fn}')") for fn in bool_attr]) return codegen.generate_function(cls, '__assign_llm_attr', lines, args=('cls', *args), globs=globs, annotations={'cls': 't.Type[LLM]', 'return': None}) - -def vllm_postprocess_generate(self: LLM['vllm.LLMEngine', T], prompt: str, generation_result: list[dict[str, t.Any]], **_: t.Any) -> str: - return generation_result[0]['outputs'][0]['text'] - -def vllm_generate_iterator(self: LLM['vllm.LLMEngine', T], - prompt: str, - /, - *, - echo: bool = False, - stop: str | t.Iterable[str] | None = None, - stop_token_ids: list[int] | None = None, - **attrs: t.Any) -> t.Iterator[dict[str, t.Any]]: - request_id: str | None = attrs.pop('request_id', None) - if request_id is None: raise ValueError('request_id must not be None.') - if stop_token_ids is None: stop_token_ids = [] - stop_token_ids.append(self.tokenizer.eos_token_id) - stop_: set[str] = set() - if isinstance(stop, str) and stop != '': stop_.add(stop) - elif isinstance(stop, list) and stop != []: stop_.update(stop) - for tid in stop_token_ids: - if tid: stop_.add(self.tokenizer.decode(tid)) - - if self.config['temperature'] <= 1e-5: top_p = 1.0 - else: top_p = self.config['top_p'] - config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs) - self.model.add_request(request_id=request_id, prompt=prompt, sampling_params=config.to_sampling_config()) - while self.model.has_unfinished_requests(): - for request_output in self.model.step(): - prompt = request_output.prompt - if echo: text_outputs = [prompt + output.text for output in request_output.outputs] - else: text_outputs = [output.text for output in request_output.outputs] - yield {'text': text_outputs, 'error_code': 0} - if request_output.finished: break - -def vllm_generate(self: LLM['vllm.LLMEngine', T], prompt: str, **attrs: t.Any) -> list[dict[str, t.Any]]: - request_id: str | None = attrs.pop('request_id', None) - if request_id is None: raise ValueError('request_id must not be None.') - outputs: list[vllm.RequestOutput] = [] - # TODO: support prompt_token_ids - self.model.add_request(request_id=request_id, prompt=prompt, sampling_params=self.config.model_construct_env(**attrs).to_sampling_config()) - while self.model.has_unfinished_requests(): - outputs.extend([r for r in self.model.step() if r.finished]) - return [unmarshal_vllm_outputs(i) for i in outputs] diff --git a/openllm-python/src/openllm/_llm.py b/openllm-python/src/openllm/_llm.py index f3697ba6..7aeb0a9d 100644 --- a/openllm-python/src/openllm/_llm.py +++ b/openllm-python/src/openllm/_llm.py @@ -1,6 +1,7 @@ # mypy: disable-error-code="name-defined,attr-defined" from __future__ import annotations import abc +import asyncio import gc import inspect import logging @@ -22,6 +23,7 @@ import openllm_core from bentoml._internal.models.model import ModelSignature from openllm_core._configuration import FineTuneConfig from openllm_core._configuration import LLMConfig +from openllm_core._prompt import process_prompt from openllm_core._schema import EmbeddingsOutput from openllm_core._typing_compat import AdaptersMapping from openllm_core._typing_compat import AdaptersTuple @@ -46,7 +48,7 @@ from openllm_core.utils import ReprMixin from openllm_core.utils import apply from openllm_core.utils import bentoml_cattr from openllm_core.utils import codegen -from openllm_core.utils import device_count +from openllm_core.utils import ensure_exec_coro from openllm_core.utils import first_not_none from openllm_core.utils import generate_hash_from_file from openllm_core.utils import is_peft_available @@ -58,7 +60,6 @@ from openllm_core.utils import validate_is_path from ._assign import make_llm_attributes from ._quantisation import infer_quantisation_config from .exceptions import ForbiddenAttributeError -from .exceptions import GpuNotAvailableError from .exceptions import OpenLLMException from .utils import infer_auto_class @@ -67,6 +68,7 @@ if t.TYPE_CHECKING: import peft import torch import transformers + import vllm from openllm_core._configuration import PeftType from openllm_core.utils.representation import ReprArgs @@ -742,16 +744,13 @@ class LLM(LLMInterface[M, T], ReprMixin): @property def model(self) -> M: - # Run check for GPU - if self.config['requires_gpu'] and device_count() < 1: - raise GpuNotAvailableError(f'{self} only supports running with GPU (None available).') from None # NOTE: the signature of load_model here is the wrapper under _wrapped_load_model if self.__llm_model__ is None: model = self.load_model(*self._model_decls, **self._model_attrs) # If OOM, then it is probably you don't have enough VRAM to run this model. if self.__llm_backend__ == 'pt' and is_torch_available(): loaded_in_kbit = getattr(model, 'is_loaded_in_8bit', False) or getattr(model, 'is_loaded_in_4bit', False) or getattr(model, 'is_quantized', False) - if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit: + if torch.cuda.is_available() and torch.cuda.device_count() == 1 and not loaded_in_kbit and not isinstance(model, transformers.Pipeline): try: model = model.to('cuda') except Exception as err: @@ -971,92 +970,90 @@ class LLM(LLMInterface[M, T], ReprMixin): from ._generation import prepare_logits_processor len_prompt = len(prompt) + config = self.config.model_construct_env(**attrs) if stop_token_ids is None: stop_token_ids = [] stop_token_ids.append(self.tokenizer.eos_token_id) - logits_processor = prepare_logits_processor(self.config) + logits_processor = prepare_logits_processor(config) - input_ids = self.tokenizer(prompt).input_ids + with torch.inference_mode(): + input_ids = self.tokenizer(prompt).input_ids - if context_length is None: context_length = get_context_length(self.model.config) - max_src_len = context_length - self.config['max_new_tokens'] - 1 + if context_length is None: context_length = get_context_length(self.model.config) + max_src_len = context_length - config['max_new_tokens'] - 1 - input_ids = input_ids[-max_src_len:] - output_ids = list(input_ids) - input_echo_len = len(input_ids) + input_ids = input_ids[-max_src_len:] + output_ids = list(input_ids) + input_echo_len = len(input_ids) - past_key_values = out = token = None - for i in range(self.config['max_new_tokens']): - torch.cuda.synchronize() - if i == 0: # prefill - out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True) - logits = out.logits - past_key_values = out.past_key_values - else: # decoding - out = self.model(input_ids=torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values) + past_key_values = out = token = None + finish_reason = None + for i in range(config['max_new_tokens']): + torch.cuda.synchronize() + if i == 0: # prefill + out = self.model(torch.as_tensor([input_ids], device=self.device), use_cache=True) + else: # decoding + out = self.model(torch.as_tensor([[token]], device=self.device), use_cache=True, past_key_values=past_key_values) logits = out.logits past_key_values = out.past_key_values + torch.cuda.synchronize() - if logits_processor: - if self.config['repetition_penalty'] > 1.0: - tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=self.device) + if logits_processor: + if config['repetition_penalty'] > 1.0: + tmp_output_ids: t.Any = torch.as_tensor([output_ids], device=self.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] else: - tmp_output_ids = None - last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] - else: - last_token_logits = logits[0, -1, :] + last_token_logits = logits[0, -1, :] - # Switch to CPU by avoiding some bugs in mps backend. - if self.device.type == 'mps': last_token_logits = last_token_logits.float().to('cpu') + # Switch to CPU by avoiding some bugs in mps backend. + if self.device.type == 'mps': last_token_logits = last_token_logits.float().to('cpu') - if self.config['temperature'] < 1e-5 or self.config['top_p'] < 1e-8: - token = int(torch.argmax(last_token_logits)) # greedy - else: - probs = torch.softmax(last_token_logits, dim=-1) - token = int(torch.multinomial(probs, num_samples=1)) - output_ids.append(token) - torch.cuda.synchronize() - - if token in stop_token_ids: stopped = True - else: stopped = False - - # Yield the output tokens - if i % stream_interval == 0 or i == self.config['max_new_tokens'] - 1 or stopped: - if echo: - tmp_output_ids = output_ids - rfind_start = len_prompt + if config['temperature'] < 1e-5 or config['top_p'] < 1e-8: + token = int(torch.argmax(last_token_logits)) # greedy else: - tmp_output_ids = output_ids[input_echo_len:] - rfind_start = 0 - output = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True) + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + token = int(indices.tolist()[0]) + output_ids.append(token) - partially_stopped = False - if stop: - if isinstance(stop, str): - pos = output.rfind(stop, rfind_start) - if pos != -1: output, stopped = output[:pos], True - else: partially_stopped = is_partial_stop(output, stop) - elif isinstance(stop, t.Iterable): - for each_stop in stop: - pos = output.rfind(each_stop, rfind_start) - if pos != -1: - output, stopped = output[:pos], True - break - else: - partially_stopped = is_partial_stop(output, each_stop) - if partially_stopped: break - else: raise ValueError('Invalid stop field type.') + stopped = token in stop_token_ids - # Prevent yielding partial stop sequence - if not partially_stopped: - yield {'text': output, 'usage': {'prompt_tokens': input_echo_len, 'completion_tokens': i, 'total_tokens': input_echo_len + i}, 'finish_reason': None} - if stopped: break + # Yield the output tokens + if i % stream_interval == 0 or i == config['max_new_tokens'] - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len_prompt + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + output = self.tokenizer.decode(tmp_output_ids, skip_special_tokens=True, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True) - # Finish stream event, which contains finish reason - if i == self.config['max_new_tokens'] - 1: finish_reason = 'length' - elif stopped: finish_reason = 'stop' - else: finish_reason = None - yield {'text': output, 'usage': {'prompt_tokens': input_echo_len, 'completion_tokens': i, 'total_tokens': input_echo_len + i}, 'finish_reason': finish_reason} + partially_stopped = False + if stop: + if isinstance(stop, str): + pos = output.rfind(stop, rfind_start) + if pos != -1: output, stopped = output[:pos], True + else: partially_stopped = is_partial_stop(output, stop) + elif isinstance(stop, t.Iterable): + for each_stop in stop: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output, stopped = output[:pos], True + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: break + else: raise ValueError('Invalid stop field type.') + + # Prevent yielding partial stop sequence + if not partially_stopped: + yield {'text': output, 'usage': {'prompt_tokens': input_echo_len, 'completion_tokens': i, 'total_tokens': input_echo_len + i}, 'finish_reason': None} + if stopped: break + else: finish_reason = 'length' # finish stream events + if stopped: finish_reason = 'stop' + yield {'text': output, 'usage': {'prompt_tokens': input_echo_len, 'completion_tokens': i, 'total_tokens': input_echo_len + i}, 'finish_reason': finish_reason} # Clean del past_key_values, out gc.collect() @@ -1192,7 +1189,6 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate def generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: adapter_name = attrs.pop('adapter_name', None) if adapter_name is not None: __self.set_adapter(adapter_name) - if __self.backend == 'vllm': attrs.setdefault('request_id', openllm_core.utils.gen_random_uuid()) return self.generate(prompt, **attrs) @bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore @@ -1207,8 +1203,7 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate if adapter_name is not None: __self.set_adapter(adapter_name) pre = 0 for outputs in self.generate_iterator(prompt, request_id=openllm_core.utils.gen_random_uuid(), **attrs): - output_text = outputs['text'][0] if __self.backend == 'vllm' else outputs['text'] - output_text = output_text.strip().split(' ') + output_text = outputs['text'].strip().split(' ') now = len(output_text) - 1 if now > pre: yield ' '.join(output_text[pre:now]) + ' ' @@ -1216,13 +1211,81 @@ def llm_runnable_class(self: LLM[M, T], embeddings_sig: ModelSignature, generate yield ' '.join(output_text[pre:]) + ' ' return ' '.join(output_text) + ' ' - return types.new_class( - self.__class__.__name__ + 'Runnable', (_Runnable,), {}, - lambda ns: ns.update({ - 'SUPPORTED_RESOURCES': ('nvidia.com/gpu', 'amd.com/gpu') if self.config['requires_gpu'] else ('nvidia.com/gpu', 'amd.com/gpu', 'cpu'), - '__module__': self.__module__, - '__doc__': self.config['env'].start_docstring - })) + @bentoml.Runnable.method(**method_signature(generate_sig)) # type: ignore + def vllm_generate(__self: _Runnable, prompt: str, **attrs: t.Any) -> list[t.Any]: + stop: str | t.Iterable[str] | None = attrs.pop('stop', None) + stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None) + adapter_name = attrs.pop('adapter_name', None) + if adapter_name is not None: __self.set_adapter(adapter_name) + request_id: str | None = attrs.pop('request_id', None) + if request_id is None: raise ValueError('request_id must not be None.') + + if stop_token_ids is None: stop_token_ids = [] + stop_token_ids.append(self.tokenizer.eos_token_id) + stop_: set[str] = set() + if isinstance(stop, str) and stop != '': stop_.add(stop) + elif isinstance(stop, list) and stop != []: stop_.update(stop) + for tid in stop_token_ids: + if tid: stop_.add(self.tokenizer.decode(tid)) + + if self.config['temperature'] <= 1e-5: top_p = 1.0 + else: top_p = self.config['top_p'] + config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs) + sampling_params = config.to_sampling_config() + + async def loop() -> list[str]: + async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id): + pass + return [output.text for output in request_output.outputs] + + try: + return asyncio.run(loop()) + except RuntimeError: + try: + return ensure_exec_coro(loop()) + except Exception: + raise + + @bentoml.Runnable.method(**method_signature(generate_iterator_sig)) # type: ignore + async def vllm_generate_iterator(__self: _Runnable, prompt: str, **attrs: t.Any) -> t.AsyncGenerator[str, None]: + # TODO: System prompt support + pre = 0 + prompt = process_prompt(prompt, None, False) + echo = attrs.pop('echo', False) + stop: str | t.Iterable[str] | None = attrs.pop('stop', None) + stop_token_ids: list[int] | None = attrs.pop('stop_token_ids', None) + adapter_name = attrs.pop('adapter_name', None) + if adapter_name is not None: __self.set_adapter(adapter_name) + request_id: str | None = attrs.pop('request_id', None) + if request_id is None: raise ValueError('request_id must not be None.') + + if stop_token_ids is None: stop_token_ids = [] + stop_token_ids.append(self.tokenizer.eos_token_id) + stop_: set[str] = set() + if isinstance(stop, str) and stop != '': stop_.add(stop) + elif isinstance(stop, list) and stop != []: stop_.update(stop) + for tid in stop_token_ids: + if tid: stop_.add(self.tokenizer.decode(tid)) + + if self.config['temperature'] <= 1e-5: top_p = 1.0 + else: top_p = self.config['top_p'] + config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs) + sampling_params = config.to_sampling_config() + async for request_output in t.cast('vllm.AsyncLLMEngine', self.model).generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id): + if echo: text_outputs = [prompt + output.text for output in request_output.outputs] + else: text_outputs = [output.text for output in request_output.outputs] + output_text = text_outputs[0] + output_text = output_text.strip().split(' ') + now = len(output_text) - 1 + if now > pre: + yield ' '.join(output_text[pre:now]) + ' ' + pre = now + yield ' '.join(output_text[pre:]) + ' ' + + return types.new_class(self.__class__.__name__ + 'Runnable', (_Runnable,), {}, + lambda ns: ns.update({ + 'SUPPORTED_RESOURCES': ('nvidia.com/gpu', 'amd.com/gpu', 'cpu'), '__module__': self.__module__, '__doc__': self.config['env'].start_docstring + })) def llm_runner_class(self: LLM[M, T]) -> type[LLMRunner[M, T]]: def available_adapters(_: LLMRunner[M, T]) -> PeftAdapterOutput: diff --git a/openllm-python/src/openllm/_service.py b/openllm-python/src/openllm/_service.py index fc2be132..c19e4adc 100644 --- a/openllm-python/src/openllm/_service.py +++ b/openllm-python/src/openllm/_service.py @@ -47,14 +47,24 @@ _JsonInput = bentoml.io.JSON.from_sample({'prompt': '', 'llm_config': llm_config async def generate_v1(input_dict: dict[str, t.Any]) -> openllm.GenerationOutput: qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict) config = qa_inputs.llm_config.model_dump() - responses = await runner.generate.async_run(qa_inputs.prompt, **{'adapter_name': qa_inputs.adapter_name, **config}) + if runner.backend == 'vllm': + responses = await runner.vllm_generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, request_id=openllm_core.utils.gen_random_uuid(), **config) + else: + responses = await runner.generate.async_run(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, **config) return openllm.GenerationOutput(responses=responses, configuration=config) @svc.api(route='/v1/generate_stream', input=_JsonInput, output=bentoml.io.Text(content_type='text/event-stream')) async def generate_stream_v1(input_dict: dict[str, t.Any]) -> t.AsyncGenerator[str, None]: echo = input_dict.pop('echo', False) qa_inputs = openllm.GenerationInput.from_llm_config(llm_config)(**input_dict) - return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, **qa_inputs.llm_config.model_dump()) + if runner.backend == 'vllm': + return runner.vllm_generate_iterator.async_stream(qa_inputs.prompt, + adapter_name=qa_inputs.adapter_name, + echo=echo, + request_id=openllm_core.utils.gen_random_uuid(), + **qa_inputs.llm_config.model_dump()) + else: + return runner.generate_iterator.async_stream(qa_inputs.prompt, adapter_name=qa_inputs.adapter_name, echo=echo, **qa_inputs.llm_config.model_dump()) @svc.api(route='/v1/metadata', input=bentoml.io.Text(), diff --git a/openllm-python/src/openllm/cli/_factory.py b/openllm-python/src/openllm/cli/_factory.py index 085e0ed1..3e1cd289 100644 --- a/openllm-python/src/openllm/cli/_factory.py +++ b/openllm-python/src/openllm/cli/_factory.py @@ -112,15 +112,7 @@ Available official model_id(s): [default: {llm_config['default_id']}] \b {orjson.dumps(llm_config['model_ids'], option=orjson.OPT_INDENT_2).decode()} -''', - ) - - if llm_config['requires_gpu'] and openllm.utils.device_count() < 1: - # NOTE: The model requires GPU, therefore we will return a dummy command - command_attrs.update({ - 'short_help': '(Disabled because there is no GPU available)', 'help': f'{model} is currently not available to run on your local machine because it requires GPU for inference.' - }) - return noop_command(group, llm_config, _serve_grpc, **command_attrs) +''') @group.command(**command_attrs) @start_decorator(llm_config, serve_grpc=_serve_grpc) @@ -230,19 +222,6 @@ Available official model_id(s): [default: {llm_config['default_id']}] return start_cmd -def noop_command(group: click.Group, llm_config: LLMConfig, _serve_grpc: bool, **command_attrs: t.Any) -> click.Command: - context_settings = command_attrs.pop('context_settings', {}) - context_settings.update({'ignore_unknown_options': True, 'allow_extra_args': True}) - command_attrs['context_settings'] = context_settings - # NOTE: The model requires GPU, therefore we will return a dummy command - @group.command(**command_attrs) - def noop(**_: t.Any) -> LLMConfig: - termui.echo('No GPU available, therefore this command is disabled', fg='red') - openllm.utils.analytics.track_start_init(llm_config) - return llm_config - - return noop - def start_decorator(llm_config: LLMConfig, serve_grpc: bool = False) -> t.Callable[[FC], t.Callable[[FC], FC]]: def wrapper(fn: FC) -> t.Callable[[FC], FC]: composed = openllm.utils.compose( diff --git a/openllm-python/src/openllm/cli/entrypoint.py b/openllm-python/src/openllm/cli/entrypoint.py index 669cdacb..a1dcd82b 100644 --- a/openllm-python/src/openllm/cli/entrypoint.py +++ b/openllm-python/src/openllm/cli/entrypoint.py @@ -651,8 +651,6 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo json_data[m] = { 'architecture': config['architecture'], 'model_id': config['model_ids'], - 'cpu': not config['requires_gpu'], - 'gpu': True, 'backend': backend, 'installation': f'"openllm[{m}]"' if m in OPTIONAL_DEPENDENCIES or config['requirements'] else 'openllm', } @@ -680,13 +678,11 @@ def models_command(ctx: click.Context, output: LiteralOutput, show_available: bo import tabulate tabulate.PRESERVE_WHITESPACE = True - # llm, architecture, url, model_id, installation, cpu, gpu, backend - data: list[str | tuple[str, str, list[str], str, LiteralString, LiteralString, tuple[LiteralBackend, ...]]] = [] + # llm, architecture, url, model_id, installation, backend + data: list[str | tuple[str, str, list[str], str, tuple[LiteralBackend, ...]]] = [] for m, v in json_data.items(): - data.extend([(m, v['architecture'], v['model_id'], v['installation'], '❌' if not v['cpu'] else '✅', '✅', v['backend'],)]) - column_widths = [ - int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 12), int(termui.COLUMNS / 4), - ] + data.extend([(m, v['architecture'], v['model_id'], v['installation'], v['backend'])]) + column_widths = [int(termui.COLUMNS / 12), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4), int(termui.COLUMNS / 6), int(termui.COLUMNS / 4)] if len(data) == 0 and len(failed_initialized) > 0: termui.echo('Exception found while parsing models:\n', fg='yellow') @@ -868,6 +864,7 @@ def query_command( res = client.query(prompt, return_response='raw', **{**client.configuration, **_memoized}) if output == 'pretty': response = client.config.postprocess_generate(prompt, res['responses']) + if isinstance(response, dict) and 'text' in response: response = response['text'] termui.echo('\n\n==Responses==\n', fg='white') termui.echo(response, fg=generated_fg) elif output == 'json': diff --git a/openllm-python/src/openllm/models/gpt_neox/modeling_gpt_neox.py b/openllm-python/src/openllm/models/gpt_neox/modeling_gpt_neox.py index d219f9fd..3f0a5607 100644 --- a/openllm-python/src/openllm/models/gpt_neox/modeling_gpt_neox.py +++ b/openllm-python/src/openllm/models/gpt_neox/modeling_gpt_neox.py @@ -14,19 +14,3 @@ class GPTNeoX(openllm.LLM['transformers.GPTNeoXForCausalLM', 'transformers.GPTNe def import_kwargs(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: import torch return {'device_map': 'auto' if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None}, {} - - def load_model(self, *args: t.Any, **attrs: t.Any) -> transformers.GPTNeoXForCausalLM: - import transformers - model = transformers.AutoModelForCausalLM.from_pretrained(self._bentomodel.path, *args, **attrs) - if self.config.use_half_precision: model.half() - return model - - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - import torch - with torch.inference_mode(): - return self.tokenizer.batch_decode( - self.model.generate(self.tokenizer(prompt, return_tensors='pt').to(self.device).input_ids, - do_sample=True, - generation_config=self.config.model_construct_env(**attrs).to_generation_config(), - pad_token_id=self.tokenizer.eos_token_id, - stopping_criteria=openllm.StoppingCriteriaList([openllm.StopOnTokens()]))) diff --git a/openllm-python/src/openllm/models/starcoder/modeling_starcoder.py b/openllm-python/src/openllm/models/starcoder/modeling_starcoder.py index c7dd3b9a..b50442b0 100644 --- a/openllm-python/src/openllm/models/starcoder/modeling_starcoder.py +++ b/openllm-python/src/openllm/models/starcoder/modeling_starcoder.py @@ -1,5 +1,4 @@ from __future__ import annotations -import logging import typing as t import bentoml @@ -31,16 +30,3 @@ class StarCoder(openllm.LLM['transformers.GPTBigCodeForCausalLM', 'transformers. return bentoml.transformers.save_model(self.tag, model, custom_objects={'tokenizer': tokenizer}, labels=generate_labels(self)) finally: torch.cuda.empty_cache() - - def generate(self, prompt: str, **attrs: t.Any) -> list[str]: - import torch - with torch.inference_mode(): - # eos_token_id=self.tokenizer.convert_tokens_to_ids("<|end|>"), # NOTE: this is for finetuning starcoder - # NOTE: support fine-tuning starcoder - result_tensor = self.model.generate(self.tokenizer.encode(prompt, return_tensors='pt').to(self.device), - do_sample=True, - pad_token_id=self.tokenizer.eos_token_id, - generation_config=self.config.model_construct_env(**attrs).to_generation_config()) - # TODO: We will probably want to return the tokenizer here so that we can manually process this - # return (skip_special_tokens=False, clean_up_tokenization_spaces=False)) - return self.tokenizer.batch_decode(result_tensor[0], skip_special_tokens=True, clean_up_tokenization_spaces=True) diff --git a/openllm-python/tests/_strategies/_configuration.py b/openllm-python/tests/_strategies/_configuration.py index 5ad3a60c..35a58d90 100644 --- a/openllm-python/tests/_strategies/_configuration.py +++ b/openllm-python/tests/_strategies/_configuration.py @@ -20,7 +20,6 @@ def model_settings(draw: st.DrawFn): 'model_ids': st.lists(st.text(), min_size=1), 'architecture': st.text(min_size=1), 'url': st.text(), - 'requires_gpu': st.booleans(), 'trust_remote_code': st.booleans(), 'requirements': st.none() | st.lists(st.text(), min_size=1), 'default_backend': st.dictionaries(st.sampled_from(['cpu', 'nvidia.com/gpu']), st.sampled_from(['vllm', 'pt', 'tf', 'flax'])), diff --git a/tools/dependencies.py b/tools/dependencies.py index 22e11e5e..21a85ddc 100755 --- a/tools/dependencies.py +++ b/tools/dependencies.py @@ -133,7 +133,7 @@ AGENTS_DEPS = ['transformers[agents]>=4.30', 'diffusers', 'soundfile'] PLAYGROUND_DEPS = ['jupyter', 'notebook', 'ipython', 'jupytext', 'nbformat'] GGML_DEPS = ['ctransformers'] GPTQ_DEPS = ['auto-gptq[triton]>=0.4.2', 'optimum>=1.12.0'] -VLLM_DEPS = ['vllm>=0.1.6', 'ray'] +VLLM_DEPS = ['vllm>=0.1.7', 'ray'] _base_requirements: dict[str, t.Any] = { inflection.dasherize(name): config_cls.__openllm_requirements__ for name, config_cls in openllm.CONFIG_MAPPING.items() if config_cls.__openllm_requirements__ diff --git a/tools/update-config-stubs.py b/tools/update-config-stubs.py index b22d7735..3c2b0e3c 100755 --- a/tools/update-config-stubs.py +++ b/tools/update-config-stubs.py @@ -49,7 +49,6 @@ _value_docstring = { ```''', 'default_backend': '''The default backend to run LLM based on available accelerator. By default, it will be PyTorch (pt) for most models. For some models, such as Llama, it will use `vllm` or `flax`. It is a dictionary of key as the accelerator spec in k8s ('cpu', 'nvidia.com/gpu', 'amd.com/gpu', 'cloud-tpus.google.com/v2', ...) and the values as supported OpenLLM backend ('flax', 'tf', 'pt', 'vllm', 'ggml', 'mlc')''', 'url': 'The resolved url for this LLMConfig.', - 'requires_gpu': 'Determines if this model is only available on GPU. By default it supports GPU and fallback to CPU.', 'trust_remote_code': 'Whether to always trust remote code', 'service_name': "Generated service name for this LLMConfig. By default, it is \"generated_{model_name}_service.py\"", 'requirements': 'The default PyPI requirements needed to run this given LLM. By default, we will depend on bentoml, torch, transformers.',